diff --git a/Cargo.lock b/Cargo.lock index 2ac5975de..6ed71ff32 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1111,6 +1111,12 @@ version = "0.1.13" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "56254986775e3233ffa9c4d7d3faaf6d36a2c09d30b20687e9f88bc8bafc16c8" +[[package]] +name = "difflib" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6184e33543162437515c2e2b48714794e37845ec9851711914eec9d308f6ebe8" + [[package]] name = "digest" version = "0.8.1" @@ -1203,6 +1209,12 @@ dependencies = [ "winapi 0.3.9", ] +[[package]] +name = "downcast" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1435fa1053d8b2fbbe9be7e97eca7f33d37b28409959813daefc1446a14247f1" + [[package]] name = "dunce" version = "1.0.4" @@ -1934,6 +1946,15 @@ dependencies = [ "miniz_oxide", ] +[[package]] +name = "float-cmp" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "98de4bbd547a563b716d8dfa9aad1cb19bfab00f4fa09a6a4ed21dbcf44ce9c4" +dependencies = [ + "num-traits", +] + [[package]] name = "fnv" version = "1.0.7" @@ -1964,6 +1985,12 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "fragile" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c2141d6d6c8512188a7891b4b01590a45f6dac67afb4f255c4124dbb86d4eaa" + [[package]] name = "fs2" version = "0.4.3" @@ -2899,15 +2926,6 @@ version = "0.4.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b5e6163cb8c49088c2c36f57875e58ccd8c87c7427f7fbd50ea6710b2f3f2e8f" -[[package]] -name = "matchers" -version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8263075bb86c5a1b1427b5ae862e8889656f126e9f77c484496e8b47cf5c5558" -dependencies = [ - "regex-automata 0.1.10", -] - [[package]] name = "matches" version = "0.1.10" @@ -3080,6 +3098,33 @@ dependencies = [ "xmtp_proto", ] +[[package]] +name = "mockall" +version = "0.11.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c84490118f2ee2d74570d114f3d0493cbf02790df303d2707606c3e14e07c96" +dependencies = [ + "cfg-if 1.0.0", + "downcast", + "fragile", + "lazy_static", + "mockall_derive", + "predicates", + "predicates-tree", +] + +[[package]] +name = "mockall_derive" +version = "0.11.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22ce75669015c4f47b289fd4d4f56e894e4c96003ffdf3ac51313126f94c6cbb" +dependencies = [ + "cfg-if 1.0.0", + "proc-macro2", + "quote", + "syn 1.0.109", +] + [[package]] name = "multimap" version = "0.8.3" @@ -3132,14 +3177,10 @@ dependencies = [ ] [[package]] -name = "nu-ansi-term" -version = "0.46.0" +name = "normalize-line-endings" +version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77a8165726e8236064dbb45459242600304b42a5ea24ee2948e18e023bf7ba84" -dependencies = [ - "overload", - "winapi 0.3.9", -] +checksum = "61807f77802ff30975e01f4f071c8ba10c022052f98b3294119f3e615d13e5be" [[package]] name = "num-bigint" @@ -3490,12 +3531,6 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d" -[[package]] -name = "overload" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" - [[package]] name = "p256" version = "0.13.2" @@ -3904,6 +3939,36 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "925383efa346730478fb4838dbe9137d2a47675ad789c546d150a6e1dd4ab31c" +[[package]] +name = "predicates" +version = "2.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59230a63c37f3e18569bdb90e4a89cbf5bf8b06fea0b84e65ea10cc4df47addd" +dependencies = [ + "difflib", + "float-cmp", + "itertools 0.10.5", + "normalize-line-endings", + "predicates-core", + "regex", +] + +[[package]] +name = "predicates-core" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b794032607612e7abeb4db69adb4e33590fa6cf1149e95fd7cb00e634b92f174" + +[[package]] +name = "predicates-tree" +version = "1.0.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "368ba315fb8c5052ab692e68a0eefec6ec57b23a36959c14496f0b0df2c0cecf" +dependencies = [ + "predicates-core", + "termtree", +] + [[package]] name = "prettyplease" version = "0.1.25" @@ -4277,19 +4342,10 @@ checksum = "380b951a9c5e80ddfd6136919eef32310721aa4aacd4889a8d39124b026ab343" dependencies = [ "aho-corasick", "memchr", - "regex-automata 0.4.3", + "regex-automata", "regex-syntax 0.8.2", ] -[[package]] -name = "regex-automata" -version = "0.1.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c230d73fb8d8c1b9c0b3135c5142a8acee3a0558fb8db5cf1cb65f8d7862132" -dependencies = [ - "regex-syntax 0.6.29", -] - [[package]] name = "regex-automata" version = "0.4.3" @@ -4301,12 +4357,6 @@ dependencies = [ "regex-syntax 0.8.2", ] -[[package]] -name = "regex-syntax" -version = "0.6.29" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1" - [[package]] name = "regex-syntax" version = "0.7.2" @@ -4949,15 +4999,6 @@ dependencies = [ "keccak", ] -[[package]] -name = "sharded-slab" -version = "0.1.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f40ca3c46823713e0d4209592e8d6e826aa57e928f09752619fc696c499637f6" -dependencies = [ - "lazy_static", -] - [[package]] name = "signal-hook-registry" version = "1.4.1" @@ -5291,6 +5332,12 @@ dependencies = [ "phf_codegen", ] +[[package]] +name = "termtree" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3369f5ac52d5eb6ab48c6b4ffdc8efbcad6b89c765749064ba298f2c68a16a76" + [[package]] name = "thiserror" version = "1.0.50" @@ -5311,16 +5358,6 @@ dependencies = [ "syn 2.0.38", ] -[[package]] -name = "thread_local" -version = "1.1.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3fdd6f064ccff2d6567adcb3873ca630700f00b5ad3f060c25b5dcfd9a4ce152" -dependencies = [ - "cfg-if 1.0.0", - "once_cell", -] - [[package]] name = "time" version = "0.1.45" @@ -5664,7 +5701,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0955b8137a1df6f1a2e9a37d8a6656291ff0297c1a97c24e0d8425fe2312f79a" dependencies = [ "once_cell", - "valuable", ] [[package]] @@ -5677,35 +5713,6 @@ dependencies = [ "tracing", ] -[[package]] -name = "tracing-log" -version = "0.1.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f751112709b4e791d8ce53e32c4ed2d353565a795ce84da2285393f41557bdf2" -dependencies = [ - "log", - "once_cell", - "tracing-core", -] - -[[package]] -name = "tracing-subscriber" -version = "0.3.17" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "30a651bc37f915e81f087d86e62a18eec5f79550c7faff886f7090b4ea757c77" -dependencies = [ - "matchers", - "nu-ansi-term", - "once_cell", - "regex", - "sharded-slab", - "smallvec", - "thread_local", - "tracing", - "tracing-core", - "tracing-log", -] - [[package]] name = "try-lock" version = "0.2.4" @@ -5837,12 +5844,6 @@ dependencies = [ "rand 0.8.5", ] -[[package]] -name = "valuable" -version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "830b7e5d4d90034032940e4ace0d9a9a057e7a45cd94e6c007832e39edb82f6d" - [[package]] name = "vcpkg" version = "0.2.15" @@ -6487,6 +6488,7 @@ dependencies = [ "hex", "libsqlite3-sys", "log", + "mockall", "openmls 0.5.0 (git+https://github.com/xmtp/openmls)", "openmls_basic_credential 0.2.0 (git+https://github.com/xmtp/openmls)", "openmls_rust_crypto 0.2.0 (git+https://github.com/xmtp/openmls)", @@ -6497,10 +6499,9 @@ dependencies = [ "serde_json", "tempfile", "thiserror", + "tls_codec", "tokio", "toml 0.7.4", - "tracing", - "tracing-subscriber", "xmtp_api_grpc", "xmtp_cryptography", "xmtp_proto", diff --git a/dev/docker/up b/dev/docker/up index 0be0d8c39..e317f1f10 100755 --- a/dev/docker/up +++ b/dev/docker/up @@ -3,4 +3,4 @@ set -eou pipefail script_dir="$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )" "${script_dir}"/compose pull -"${script_dir}"/compose up -d +"${script_dir}"/compose up -d --build diff --git a/xmtp_mls/Cargo.toml b/xmtp_mls/Cargo.toml index 595fad9df..f25854025 100644 --- a/xmtp_mls/Cargo.toml +++ b/xmtp_mls/Cargo.toml @@ -19,8 +19,8 @@ native = ["libsqlite3-sys/bundled-sqlcipher-vendored-openssl"] [dependencies] anyhow = "1.0.71" async-trait = "0.1.68" -diesel = { version = "2.0.4", features = ["sqlite", "r2d2", "returning_clauses_for_sqlite_3_35"] } -diesel_migrations = { version = "2.0.0", features = ["sqlite"] } +diesel = { version = "2.1.3", features = ["sqlite", "r2d2", "returning_clauses_for_sqlite_3_35"] } +diesel_migrations = { version = "2.1.0", features = ["sqlite"] } ethers = "2.0.4" ethers-core = "2.0.4" futures = "0.3.28" @@ -41,8 +41,10 @@ tracing = "0.1.37" toml = "0.7.4" xmtp_cryptography = { path = "../xmtp_cryptography"} xmtp_proto = { path = "../xmtp_proto", features = ["proto_full"] } +tls_codec = "0.3.0" [dev-dependencies] +mockall = "0.11.4" rand = "0.8.5" tempfile = "3.5.0" tracing-subscriber = { version = "0.3.17", features = ["env-filter"] } diff --git a/xmtp_mls/migrations/2023-10-29-205333_state_machine_init/up.sql b/xmtp_mls/migrations/2023-10-29-205333_state_machine_init/up.sql index c90025c1c..7aadc7b03 100644 --- a/xmtp_mls/migrations/2023-10-29-205333_state_machine_init/up.sql +++ b/xmtp_mls/migrations/2023-10-29-205333_state_machine_init/up.sql @@ -14,8 +14,8 @@ CREATE INDEX groups_membership_state ON groups(membership_state); -- Successfully processed messages meant to be returned to the user CREATE TABLE group_messages ( - "id" BLOB PRIMARY KEY NOT NULL, -- Derived via SHA256(CONCAT(decrypted_message_bytes, conversation_id, timestamp)) + "id" BLOB PRIMARY KEY NOT NULL, "group_id" BLOB NOT NULL, -- Message contents after decryption "decrypted_message_bytes" BLOB NOT NULL, @@ -49,24 +49,10 @@ CREATE TABLE group_intents ( -- INTENT_STATE, "state" INT NOT NULL, -- The hash of the encrypted, concrete, form of the message if it was published. - "message_hash" BLOB, - FOREIGN KEY (group_id) REFERENCES groups(id) -); - -CREATE INDEX group_intents_group_id_id ON group_intents(group_id, id); - -CREATE TABLE outbound_welcome_messages ( - -- Derived via SHA256(CONCAT(group_id, welcome_message, installation_id)) - "id" BLOB PRIMARY KEY NOT NULL, - -- OUTBOUND_WELCOME_STATE - "state" INT NOT NULL, - "installation_id" BLOB NOT NULL, - -- The hash of the commit message which created this welcome - "commit_hash" BLOB NOT NULL, - -- The group this welcome belongs to - "group_id" BLOB NOT NULL, - "welcome_message" BLOB NOT NULL, + "payload_hash" BLOB, + -- (Optional) data needed for the post-commit flow. For example, welcome messages + "post_commit_data" BLOB, FOREIGN KEY (group_id) REFERENCES groups(id) ); -CREATE INDEX outbound_welcome_messages_commit_hash ON outbound_welcome_messages(commit_hash, state); \ No newline at end of file +CREATE INDEX group_intents_group_id_id ON group_intents(group_id, id); \ No newline at end of file diff --git a/xmtp_mls/src/api_client_wrapper.rs b/xmtp_mls/src/api_client_wrapper.rs index cd74ab16b..49f0c5511 100644 --- a/xmtp_mls/src/api_client_wrapper.rs +++ b/xmtp_mls/src/api_client_wrapper.rs @@ -1,29 +1,32 @@ use std::collections::HashMap; use xmtp_proto::{ - api_client::{Error as ApiError, ErrorKind, XmtpApiClient, XmtpMlsClient}, - xmtp::mls::message_contents::{ - group_message::{Version as GroupMessageVersion, V1 as GroupMessageV1}, - welcome_message::{Version as WelcomeMessageVersion, V1 as WelcomeMessageV1}, - WelcomeMessage as WelcomeMessageProto, + api_client::{ + Envelope, Error as ApiError, ErrorKind, PagingInfo, QueryRequest, XmtpApiClient, + XmtpMlsClient, }, xmtp::{ - message_api::v3::{ - publish_welcomes_request::WelcomeMessageRequest, ConsumeKeyPackagesRequest, - KeyPackageUpload, PublishToGroupRequest, PublishWelcomesRequest, - RegisterInstallationRequest, UploadKeyPackagesRequest, + message_api::{v1::Cursor, v3::GetIdentityUpdatesRequest}, + mls::message_contents::{ + group_message::{Version as GroupMessageVersion, V1 as GroupMessageV1}, + welcome_message::{Version as WelcomeMessageVersion, V1 as WelcomeMessageV1}, + WelcomeMessage as WelcomeMessageProto, + }, + }, + xmtp::{ + message_api::{ + v1::SortDirection, + v3::{ + get_identity_updates_response::update::Kind as UpdateKind, + publish_welcomes_request::WelcomeMessageRequest, ConsumeKeyPackagesRequest, + KeyPackageUpload, PublishToGroupRequest, PublishWelcomesRequest, + RegisterInstallationRequest, UploadKeyPackagesRequest, + }, }, mls::message_contents::GroupMessage, }, }; -pub struct WelcomeMessage { - pub(crate) ciphertext: Vec, - pub(crate) installation_id: Vec, -} - -type KeyPackageMap = HashMap, Vec>; - #[derive(Debug)] pub struct ApiClientWrapper { api_client: ApiClient, @@ -37,9 +40,49 @@ where Self { api_client } } + pub async fn read_topic( + &self, + topic: &str, + start_time_ns: u64, + ) -> Result, ApiError> { + let mut cursor: Option = None; + let mut out: Vec = vec![]; + let page_size = 100; + loop { + let mut result = self + .api_client + .query(QueryRequest { + content_topics: vec![topic.to_string()], + start_time_ns, + end_time_ns: 0, + paging_info: Some(PagingInfo { + cursor, + limit: page_size, + direction: SortDirection::Ascending as i32, + }), + }) + .await?; + + let num_envelopes = result.envelopes.len(); + out.append(&mut result.envelopes); + + if num_envelopes < page_size as usize || result.paging_info.is_none() { + break; + } + + cursor = result.paging_info.expect("Empty paging info").cursor; + + if cursor.is_none() { + break; + } + } + + Ok(out) + } + pub async fn register_installation( &self, - last_resort_key_package: &[u8], + last_resort_key_package: Vec, ) -> Result, ApiError> { let res = self .api_client @@ -53,13 +96,13 @@ where Ok(res.installation_id) } - pub async fn upload_key_packages(&self, key_packages: Vec<&[u8]>) -> Result<(), ApiError> { + pub async fn upload_key_packages(&self, key_packages: Vec>) -> Result<(), ApiError> { self.api_client .upload_key_packages(UploadKeyPackagesRequest { key_packages: key_packages .into_iter() .map(|kp| KeyPackageUpload { - key_package_tls_serialized: kp.to_vec(), + key_package_tls_serialized: kp, }) .collect(), }) @@ -70,12 +113,12 @@ where pub async fn consume_key_packages( &self, - installation_ids: Vec<&[u8]>, + installation_ids: Vec>, ) -> Result { let res = self .api_client .consume_key_packages(ConsumeKeyPackagesRequest { - installation_ids: installation_ids.iter().map(|id| id.to_vec()).collect(), + installation_ids: installation_ids.clone(), }) .await?; @@ -124,6 +167,61 @@ where Ok(()) } + pub async fn get_identity_updates( + &self, + start_time_ns: u64, + wallet_addresses: Vec, + ) -> Result { + let result = self + .api_client + .get_identity_updates(GetIdentityUpdatesRequest { + start_time_ns, + wallet_addresses: wallet_addresses.clone(), + }) + .await?; + + if result.updates.len() != wallet_addresses.len() { + println!("mismatched number of results"); + return Err(ApiError::new(ErrorKind::MlsError)); + } + + let mapping: IdentityUpdatesMap = result + .updates + .into_iter() + .zip(wallet_addresses.into_iter()) + .map(|(update, wallet_address)| { + ( + wallet_address, + update + .updates + .into_iter() + .map(|update| match update.kind { + Some(UpdateKind::NewInstallation(new_installation)) => { + IdentityUpdate::NewInstallation(NewInstallation { + timestamp_ns: update.timestamp_ns, + installation_id: new_installation.installation_id, + credential_bytes: new_installation.credential_identity, + }) + } + Some(UpdateKind::RevokedInstallation(revoke_installation)) => { + IdentityUpdate::RevokeInstallation(RevokeInstallation { + timestamp_ns: update.timestamp_ns, + installation_id: revoke_installation.installation_id, + }) + } + None => { + println!("no update kind"); + IdentityUpdate::Invalid + } + }) + .collect(), + ) + }) + .collect(); + + Ok(mapping) + } + pub async fn publish_to_group(&self, group_messages: Vec<&[u8]>) -> Result<(), ApiError> { let to_send: Vec = group_messages .iter() @@ -141,3 +239,379 @@ where Ok(()) } } + +#[derive(Debug, PartialEq)] +pub struct WelcomeMessage { + pub(crate) ciphertext: Vec, + pub(crate) installation_id: Vec, +} + +#[derive(Debug, PartialEq)] +pub struct NewInstallation { + pub installation_id: Vec, + pub credential_bytes: Vec, + pub timestamp_ns: u64, +} + +#[derive(Debug, PartialEq)] +pub struct RevokeInstallation { + pub installation_id: Vec, // TODO: Add proof of revocation + pub timestamp_ns: u64, +} + +#[derive(Debug, PartialEq)] +pub enum IdentityUpdate { + NewInstallation(NewInstallation), + RevokeInstallation(RevokeInstallation), + Invalid, +} + +type KeyPackageMap = HashMap, Vec>; + +type IdentityUpdatesMap = HashMap>; + +#[cfg(test)] +mod tests { + use super::ApiClientWrapper; + use mockall::mock; + use xmtp_proto::api_client::{ + Error, PagingInfo, XmtpApiClient, XmtpApiSubscription, XmtpMlsClient, + }; + use xmtp_proto::xmtp::message_api::v1::IndexCursor; + use xmtp_proto::xmtp::message_api::v3::consume_key_packages_response::KeyPackage; + use xmtp_proto::xmtp::message_api::v3::get_identity_updates_response::update::Kind as UpdateKind; + use xmtp_proto::xmtp::message_api::v3::get_identity_updates_response::{ + NewInstallationUpdate, Update, WalletUpdates, + }; + use xmtp_proto::xmtp::message_api::v3::{ + ConsumeKeyPackagesRequest, ConsumeKeyPackagesResponse, GetIdentityUpdatesRequest, + GetIdentityUpdatesResponse, PublishToGroupRequest, PublishWelcomesRequest, + RegisterInstallationRequest, RegisterInstallationResponse, UploadKeyPackagesRequest, + }; + + use xmtp_proto::xmtp::message_api::v1::{ + cursor::Cursor as InnerCursor, BatchQueryRequest, BatchQueryResponse, Cursor, Envelope, + PublishRequest, PublishResponse, QueryRequest, QueryResponse, SubscribeRequest, + }; + + use async_trait::async_trait; + + fn build_envelopes(num_envelopes: usize, topic: &str) -> Vec { + let mut out: Vec = vec![]; + for i in 0..num_envelopes { + out.push(Envelope { + content_topic: topic.to_string(), + message: vec![i as u8], + timestamp_ns: i as u64, + }) + } + out + } + + mock! { + pub Subscription {} + + impl XmtpApiSubscription for Subscription { + fn is_closed(&self) -> bool; + fn get_messages(&self) -> Vec; + fn close_stream(&mut self); + } + } + + // Create a mock XmtpClient for testing the client wrapper + mock! { + pub ApiClient {} + + #[async_trait] + impl XmtpMlsClient for ApiClient { + async fn register_installation( + &self, + request: RegisterInstallationRequest, + ) -> Result; + async fn upload_key_packages(&self, request: UploadKeyPackagesRequest) -> Result<(), Error>; + async fn consume_key_packages( + &self, + request: ConsumeKeyPackagesRequest, + ) -> Result; + async fn publish_to_group(&self, request: PublishToGroupRequest) -> Result<(), Error>; + async fn publish_welcomes(&self, request: PublishWelcomesRequest) -> Result<(), Error>; + async fn get_identity_updates( + &self, + request: GetIdentityUpdatesRequest, + ) -> Result; + } + + #[async_trait] + impl XmtpApiClient for ApiClient { + // Need to set an associated type and don't currently need streaming + // Can figure out a mocked stream type later + type Subscription = MockSubscription; + + fn set_app_version(&mut self, version: String); + + async fn publish( + &self, + token: String, + request: PublishRequest, + ) -> Result; + + async fn subscribe(&self, request: SubscribeRequest) -> Result<::Subscription, Error>; + + async fn query(&self, request: QueryRequest) -> Result; + + async fn batch_query(&self, request: BatchQueryRequest) -> Result; + } + + + } + + #[tokio::test] + async fn test_register_installation() { + let mut mock_api = MockApiClient::new(); + mock_api.expect_register_installation().returning(move |_| { + Ok(RegisterInstallationResponse { + installation_id: vec![1, 2, 3], + }) + }); + let wrapper = ApiClientWrapper::new(mock_api); + let result = wrapper.register_installation(vec![2, 3, 4]).await.unwrap(); + assert_eq!(result, vec![1, 2, 3]); + } + + #[tokio::test] + async fn test_upload_key_packages() { + let mut mock_api = MockApiClient::new(); + let key_package = vec![1, 2, 3]; + // key_package gets moved below but needs to be used for assertions later + let key_package_clone = key_package.clone(); + mock_api + .expect_upload_key_packages() + .withf(move |req| { + req.key_packages[0] + .key_package_tls_serialized + .eq(&key_package) + }) + .returning(move |_| Ok(())); + let wrapper = ApiClientWrapper::new(mock_api); + let result = wrapper.upload_key_packages(vec![key_package_clone]).await; + assert!(result.is_ok()); + } + + #[tokio::test] + async fn test_consume_key_packages() { + let mut mock_api = MockApiClient::new(); + let installation_ids: Vec> = vec![vec![1, 2, 3], vec![4, 5, 6]]; + mock_api.expect_consume_key_packages().returning(move |_| { + Ok(ConsumeKeyPackagesResponse { + key_packages: vec![ + KeyPackage { + key_package_tls_serialized: vec![7, 8, 9], + }, + KeyPackage { + key_package_tls_serialized: vec![10, 11, 12], + }, + ], + }) + }); + let wrapper = ApiClientWrapper::new(mock_api); + let result = wrapper + .consume_key_packages(installation_ids.clone()) + .await + .unwrap(); + assert_eq!(result.len(), 2); + + for (k, v) in result { + if k.eq(&installation_ids[0]) { + assert_eq!(v, vec![7, 8, 9]); + } else { + assert_eq!(v, vec![10, 11, 12]); + } + } + } + + #[tokio::test] + async fn test_get_identity_updates() { + let mut mock_api = MockApiClient::new(); + let start_time_ns = 12; + let wallet_addresses = vec!["wallet1".to_string(), "wallet2".to_string()]; + // wallet_addresses gets moved below but needs to be used for assertions later + let wallet_addresses_clone = wallet_addresses.clone(); + mock_api + .expect_get_identity_updates() + .withf(move |req| { + req.start_time_ns.eq(&start_time_ns) && req.wallet_addresses.eq(&wallet_addresses) + }) + .returning(move |_| { + Ok(GetIdentityUpdatesResponse { + updates: { + vec![ + WalletUpdates { + updates: vec![Update { + timestamp_ns: 1, + kind: Some(UpdateKind::NewInstallation( + NewInstallationUpdate { + installation_id: vec![1, 2, 3], + credential_identity: vec![4, 5, 6], + }, + )), + }], + }, + WalletUpdates { + updates: vec![Update { + timestamp_ns: 2, + kind: Some(UpdateKind::NewInstallation( + NewInstallationUpdate { + installation_id: vec![7, 8, 9], + credential_identity: vec![10, 11, 12], + }, + )), + }], + }, + ] + }, + }) + }); + + let wrapper = ApiClientWrapper::new(mock_api); + let result = wrapper + .get_identity_updates(start_time_ns, wallet_addresses_clone.clone()) + .await + .unwrap(); + assert_eq!(result.len(), 2); + + for (k, v) in result { + if k.eq(&wallet_addresses_clone[0]) { + assert_eq!(v.len(), 1); + assert_eq!( + v[0], + super::IdentityUpdate::NewInstallation(super::NewInstallation { + installation_id: vec![1, 2, 3], + credential_bytes: vec![4, 5, 6], + timestamp_ns: 1, + }) + ); + } else { + assert_eq!(v.len(), 1); + assert_eq!( + v[0], + super::IdentityUpdate::NewInstallation(super::NewInstallation { + installation_id: vec![7, 8, 9], + credential_bytes: vec![10, 11, 12], + timestamp_ns: 2, + }) + ); + } + } + } + + #[tokio::test] + async fn test_read_topic_single_page() { + let mut mock_api = MockApiClient::new(); + let topic = "topic"; + let start_time_ns = 10; + // Set expectation for first request with no cursor + mock_api.expect_query().returning(move |req| { + assert_eq!(req.content_topics[0], topic); + + Ok(QueryResponse { + paging_info: Some(PagingInfo { + cursor: None, + limit: 100, + direction: 0, + }), + envelopes: build_envelopes(10, topic), + }) + }); + + let wrapper = ApiClientWrapper::new(mock_api); + + let result = wrapper.read_topic(topic, start_time_ns).await.unwrap(); + assert_eq!(result.len(), 10); + } + + #[tokio::test] + async fn test_read_topic_single_page_exactly_100_results() { + let mut mock_api = MockApiClient::new(); + let topic = "topic"; + let start_time_ns = 10; + // Set expectation for first request with no cursor + mock_api.expect_query().times(1).returning(move |req| { + assert_eq!(req.content_topics[0], topic); + + Ok(QueryResponse { + paging_info: Some(PagingInfo { + cursor: None, + limit: 100, + direction: 0, + }), + envelopes: build_envelopes(100, topic), + }) + }); + + let wrapper = ApiClientWrapper::new(mock_api); + + let result = wrapper.read_topic(topic, start_time_ns).await.unwrap(); + assert_eq!(result.len(), 100); + } + + #[tokio::test] + async fn test_read_topic_multi_page() { + let mut mock_api = MockApiClient::new(); + let topic = "topic"; + let start_time_ns = 10; + // Set expectation for first request with no cursor + mock_api + .expect_query() + .withf(move |req| match req.paging_info.clone() { + Some(paging_info) => match paging_info.cursor { + Some(_) => false, + None => true, + }, + None => true, + } && req.start_time_ns == 10) + .returning(move |req| { + assert_eq!(req.content_topics[0], topic); + + Ok(QueryResponse { + paging_info: Some(PagingInfo { + cursor: Some(Cursor { + cursor: Some(InnerCursor::Index(IndexCursor { + digest: vec![], + sender_time_ns: 0, + })), + }), + limit: 100, + direction: 0, + }), + envelopes: build_envelopes(100, topic), + }) + }); + // Set expectation for requests with a cursor + mock_api + .expect_query() + .withf(|req| match req.paging_info.clone() { + Some(paging_info) => match paging_info.cursor { + Some(_) => true, + None => false, + }, + None => false, + }) + .returning(move |req| { + assert_eq!(req.content_topics[0], topic); + + Ok(QueryResponse { + paging_info: Some(PagingInfo { + cursor: None, + limit: 100, + direction: 0, + }), + envelopes: build_envelopes(100, topic), + }) + }); + + let wrapper = ApiClientWrapper::new(mock_api); + + let result = wrapper.read_topic(topic, start_time_ns).await.unwrap(); + assert_eq!(result.len(), 200); + } +} diff --git a/xmtp_mls/src/builder.rs b/xmtp_mls/src/builder.rs index 258474b97..415520a9e 100644 --- a/xmtp_mls/src/builder.rs +++ b/xmtp_mls/src/builder.rs @@ -69,7 +69,7 @@ where } Ok(identity) } - None => Ok(Identity::new(store, CIPHERSUITE, provider, &owner)?), + None => Ok(Identity::new(store, provider, &owner)?), }, #[cfg(test)] IdentityStrategy::ExternalIdentity(identity) => Ok(identity), @@ -189,16 +189,6 @@ mod tests { assert!(!client.installation_public_key().is_empty()); } - #[tokio::test] - async fn test_mls() { - let client = ClientBuilder::new_test_client(generate_local_wallet().into()).await; - let result = client.api_client.register_installation(&[1, 2, 3]).await; - - assert!(result.is_err()); - let error_string = result.err().unwrap().to_string(); - assert!(error_string.contains("invalid identity")); - } - #[tokio::test] async fn identity_persistence_test() { let tmpdb = TempPath::from_path("./db.db3") diff --git a/xmtp_mls/src/client.rs b/xmtp_mls/src/client.rs index 24486d81f..ebf81c04c 100644 --- a/xmtp_mls/src/client.rs +++ b/xmtp_mls/src/client.rs @@ -1,11 +1,18 @@ +use std::collections::HashSet; + +use openmls::prelude::TlsSerializeTrait; use thiserror::Error; +use tls_codec::Error as TlsSerializationError; use xmtp_proto::api_client::{XmtpApiClient, XmtpMlsClient}; use crate::{ - api_client_wrapper::ApiClientWrapper, + api_client_wrapper::{ApiClientWrapper, IdentityUpdate}, + configuration::KEY_PACKAGE_TOP_UP_AMOUNT, identity::Identity, storage::{EncryptedMessageStore, StorageError}, types::Address, + verified_key_package::{KeyPackageVerificationError, VerifiedKeyPackage}, + xmtp_openmls_provider::XmtpOpenMlsProvider, }; #[derive(Clone, Copy, Default, Debug)] @@ -26,6 +33,12 @@ pub enum ClientError { Ddd(#[from] diesel::result::Error), #[error("Query failed: {0}")] QueryError(#[from] xmtp_proto::api_client::Error), + #[error("identity error: {0}")] + Identity(#[from] crate::identity::IdentityError), + #[error("serialization error: {0}")] + Serialization(#[from] TlsSerializationError), + #[error("key package verification: {0}")] + KeyPackageVerification(#[from] KeyPackageVerificationError), #[error("generic:{0}")] Generic(String), } @@ -68,6 +81,92 @@ where } } + // TODO: Remove this and figure out the correct lifetimes to allow long lived provider + fn mls_provider(&self) -> XmtpOpenMlsProvider { + XmtpOpenMlsProvider::new(&self.store) + } + + pub async fn register_identity(&self) -> Result<(), ClientError> { + // TODO: Mark key package as last_resort in creation + let last_resort_kp = self.identity.new_key_package(&self.mls_provider())?; + let last_resort_kp_bytes = last_resort_kp.tls_serialize_detached()?; + + self.api_client + .register_installation(last_resort_kp_bytes) + .await?; + + Ok(()) + } + + pub async fn top_up_key_packages(&self) -> Result<(), ClientError> { + let key_packages: Result>, ClientError> = (0..KEY_PACKAGE_TOP_UP_AMOUNT) + .map(|_| -> Result, ClientError> { + let kp = self.identity.new_key_package(&self.mls_provider())?; + let kp_bytes = kp.tls_serialize_detached()?; + + Ok(kp_bytes) + }) + .collect(); + + self.api_client.upload_key_packages(key_packages?).await?; + + Ok(()) + } + + async fn get_all_active_installation_ids( + &self, + wallet_addresses: Vec, + ) -> Result>, ClientError> { + let update_mapping = self + .api_client + .get_identity_updates(0, wallet_addresses) + .await?; + + let mut installation_ids: Vec> = vec![]; + + for (_, updates) in update_mapping { + let mut tmp: HashSet> = HashSet::new(); + for update in updates { + match update { + IdentityUpdate::Invalid => {} + IdentityUpdate::NewInstallation(new_installation) => { + // TODO: Validate credential + tmp.insert(new_installation.installation_id); + } + IdentityUpdate::RevokeInstallation(revoke_installation) => { + tmp.remove(&revoke_installation.installation_id); + } + } + } + installation_ids.extend(tmp); + } + + Ok(installation_ids) + } + + // Get a flat list of one key package per installation for all the wallet addresses provided. + // Revoked installations will be omitted from the list + pub async fn get_key_packages_for_wallet_addresses( + &self, + wallet_addresses: Vec, + ) -> Result, ClientError> { + let installation_ids = self + .get_all_active_installation_ids(wallet_addresses) + .await?; + + let key_package_results = self + .api_client + .consume_key_packages(installation_ids) + .await?; + + let mls_provider = self.mls_provider(); + + Ok(key_package_results + .values() + .map(|bytes| VerifiedKeyPackage::from_bytes(&mls_provider, bytes.as_slice())) + .collect::>()?) + } + pub fn account_address(&self) -> Address { self.identity.account_address.clone() } @@ -76,3 +175,66 @@ where self.identity.installation_keys.to_public_vec() } } + +#[cfg(test)] +mod tests { + use xmtp_cryptography::utils::generate_local_wallet; + + use crate::{builder::ClientBuilder, InboxOwner}; + + #[tokio::test] + async fn test_mls_error() { + let client = ClientBuilder::new_test_client(generate_local_wallet().into()).await; + let result = client.api_client.register_installation(vec![1, 2, 3]).await; + + assert!(result.is_err()); + let error_string = result.err().unwrap().to_string(); + assert!(error_string.contains("invalid identity")); + } + + #[tokio::test] + async fn test_register_installation() { + let wallet = generate_local_wallet(); + let client = ClientBuilder::new_test_client(wallet.clone().into()).await; + client.register_identity().await.unwrap(); + + // Make sure the installation is actually on the network + let installation_ids = client + .get_all_active_installation_ids(vec![wallet.get_address()]) + .await + .unwrap(); + assert_eq!(installation_ids.len(), 1); + } + + #[tokio::test] + async fn test_top_up_key_packages() { + let wallet = generate_local_wallet(); + let wallet_address = wallet.get_address(); + let client = ClientBuilder::new_test_client(wallet.clone().into()).await; + + client.register_identity().await.unwrap(); + client.top_up_key_packages().await.unwrap(); + + let key_packages = client + .get_key_packages_for_wallet_addresses(vec![wallet_address.clone()]) + .await + .unwrap(); + + assert_eq!(key_packages.len(), 1); + + let key_package = key_packages.first().unwrap(); + assert_eq!(key_package.wallet_address, wallet_address); + + let key_packages_2 = client + .get_key_packages_for_wallet_addresses(vec![wallet_address.clone()]) + .await + .unwrap(); + + assert_eq!(key_packages_2.len(), 1); + + // Ensure we got back different key packages + let key_package_2 = key_packages_2.first().unwrap(); + assert_eq!(key_package_2.wallet_address, wallet_address); + assert!(!(key_package_2.eq(&key_package))); + } +} diff --git a/xmtp_mls/src/configuration.rs b/xmtp_mls/src/configuration.rs index 2c5ce4b7d..577acc965 100644 --- a/xmtp_mls/src/configuration.rs +++ b/xmtp_mls/src/configuration.rs @@ -1,5 +1,10 @@ +use openmls::versions::ProtocolVersion; use openmls_traits::types::Ciphersuite; // TODO confirm ciphersuite choice pub const CIPHERSUITE: Ciphersuite = Ciphersuite::MLS_128_DHKEMX25519_CHACHA20POLY1305_SHA256_Ed25519; + +pub const MLS_PROTOCOL_VERSION: ProtocolVersion = ProtocolVersion::Mls10; + +pub const KEY_PACKAGE_TOP_UP_AMOUNT: u16 = 100; diff --git a/xmtp_mls/src/identity.rs b/xmtp_mls/src/identity.rs index f45e12af3..a9723a091 100644 --- a/xmtp_mls/src/identity.rs +++ b/xmtp_mls/src/identity.rs @@ -5,16 +5,14 @@ use openmls::{ versions::ProtocolVersion, }; use openmls_basic_credential::SignatureKeyPair; -use openmls_traits::{ - types::{Ciphersuite, CryptoError}, - OpenMlsProvider, -}; +use openmls_traits::{types::CryptoError, OpenMlsProvider}; use prost::Message; use thiserror::Error; use xmtp_cryptography::signature::SignatureError; use crate::{ association::{AssociationError, AssociationText, Eip191Association}, + configuration::CIPHERSUITE, storage::{identity::StoredIdentity, EncryptedMessageStore, StorageError}, types::Address, xmtp_openmls_provider::XmtpOpenMlsProvider, @@ -46,11 +44,10 @@ pub struct Identity { impl Identity { pub(crate) fn new( store: &EncryptedMessageStore, - ciphersuite: Ciphersuite, provider: &XmtpOpenMlsProvider, owner: &impl InboxOwner, ) -> Result { - let signature_keys = SignatureKeyPair::new(ciphersuite.signature_algorithm())?; + let signature_keys = SignatureKeyPair::new(CIPHERSUITE.signature_algorithm())?; signature_keys.store(provider.key_store())?; let credential = Identity::create_credential(&signature_keys, owner)?; @@ -59,7 +56,7 @@ impl Identity { // TODO: Make OpenMLS not delete this once used let _last_resort_key_package = KeyPackage::builder().build( CryptoConfig { - ciphersuite, + ciphersuite: CIPHERSUITE, version: ProtocolVersion::default(), }, provider, @@ -82,6 +79,26 @@ impl Identity { Ok(identity) } + pub(crate) fn new_key_package( + &self, + provider: &XmtpOpenMlsProvider, + ) -> Result { + let kp = KeyPackage::builder().build( + CryptoConfig { + ciphersuite: CIPHERSUITE, + version: ProtocolVersion::default(), + }, + provider, + &self.installation_keys, + CredentialWithKey { + credential: self.credential.clone(), + signature_key: self.installation_keys.to_public_vec().into(), + }, + )?; + + Ok(kp) + } + fn create_credential( installation_keys: &SignatureKeyPair, owner: &impl InboxOwner, @@ -106,19 +123,16 @@ impl Identity { mod tests { use xmtp_cryptography::utils::generate_local_wallet; - use crate::{ - configuration::CIPHERSUITE, storage::EncryptedMessageStore, - xmtp_openmls_provider::XmtpOpenMlsProvider, - }; + use crate::{storage::EncryptedMessageStore, xmtp_openmls_provider::XmtpOpenMlsProvider}; use super::Identity; #[test] fn does_not_error() { + let store = EncryptedMessageStore::default(); Identity::new( - &EncryptedMessageStore::default(), - CIPHERSUITE, - &XmtpOpenMlsProvider::new(&EncryptedMessageStore::default()), + &store, + &XmtpOpenMlsProvider::new(&store), &generate_local_wallet(), ) .unwrap(); diff --git a/xmtp_mls/src/lib.rs b/xmtp_mls/src/lib.rs index 4a9ec415f..7d757e3a8 100644 --- a/xmtp_mls/src/lib.rs +++ b/xmtp_mls/src/lib.rs @@ -9,6 +9,7 @@ pub mod owner; mod proto_wrapper; pub mod storage; pub mod types; +pub mod verified_key_package; mod xmtp_openmls_provider; pub use client::{Client, Network}; diff --git a/xmtp_mls/src/storage/encrypted_store/group_intent.rs b/xmtp_mls/src/storage/encrypted_store/group_intent.rs index 02620ba37..7da2954d9 100644 --- a/xmtp_mls/src/storage/encrypted_store/group_intent.rs +++ b/xmtp_mls/src/storage/encrypted_store/group_intent.rs @@ -10,7 +10,8 @@ pub struct StoredGroupIntent { pub state: i32, pub group_id: Vec, pub data: Vec, - pub message_hash: Option>, + pub payload_hash: Option>, + pub post_commit_data: Option>, } #[derive(Insertable, Debug, Clone)] diff --git a/xmtp_mls/src/storage/encrypted_store/group_message.rs b/xmtp_mls/src/storage/encrypted_store/group_message.rs index d3ebd781e..de6fd6b1c 100644 --- a/xmtp_mls/src/storage/encrypted_store/group_message.rs +++ b/xmtp_mls/src/storage/encrypted_store/group_message.rs @@ -22,3 +22,4 @@ pub struct StoredGroupMessage { } impl_fetch_and_store!(StoredGroupMessage, group_messages, Vec); + diff --git a/xmtp_mls/src/storage/encrypted_store/key_store_entry.rs b/xmtp_mls/src/storage/encrypted_store/key_store_entry.rs new file mode 100644 index 000000000..f2cb87e7a --- /dev/null +++ b/xmtp_mls/src/storage/encrypted_store/key_store_entry.rs @@ -0,0 +1,38 @@ +use super::DbConnection; +use super::{schema::openmls_key_store, StorageError}; +use crate::{Delete, Fetch, Store}; +use diesel::prelude::*; + +#[derive(Insertable, Queryable, Debug, Clone)] +#[diesel(table_name = openmls_key_store)] +#[diesel(primary_key(key_bytes))] +pub struct StoredKeyStoreEntry { + pub key_bytes: Vec, + pub value_bytes: Vec, +} + +impl Store for StoredKeyStoreEntry { + fn store(&self, into: &mut DbConnection) -> Result<(), StorageError> { + diesel::insert_into(openmls_key_store::table) + .values(self) + .execute(into)?; + + Ok(()) + } +} + +impl Fetch for DbConnection { + type Key = Vec; + fn fetch(&mut self, key: Vec) -> Result, StorageError> where { + use super::schema::openmls_key_store::dsl::*; + Ok(openmls_key_store.find(key).first(self).optional()?) + } +} + +impl Delete for DbConnection { + type Key = Vec; + fn delete(&mut self, key: Vec) -> Result where { + use super::schema::openmls_key_store::dsl::*; + Ok(diesel::delete(openmls_key_store.filter(key_bytes.eq(key))).execute(self)?) + } +} diff --git a/xmtp_mls/src/storage/encrypted_store/mod.rs b/xmtp_mls/src/storage/encrypted_store/mod.rs index 7ea6b8091..1a33dfdf0 100644 --- a/xmtp_mls/src/storage/encrypted_store/mod.rs +++ b/xmtp_mls/src/storage/encrypted_store/mod.rs @@ -15,8 +15,7 @@ pub mod group; pub mod group_intent; pub mod group_message; pub mod identity; -pub mod key_store; -pub mod outbound_welcome_message; +pub mod key_store_entry; pub mod schema; pub mod topic_refresh_state; diff --git a/xmtp_mls/src/storage/encrypted_store/schema.rs b/xmtp_mls/src/storage/encrypted_store/schema.rs index 339e42ba7..33e251307 100644 --- a/xmtp_mls/src/storage/encrypted_store/schema.rs +++ b/xmtp_mls/src/storage/encrypted_store/schema.rs @@ -7,7 +7,8 @@ diesel::table! { group_id -> Binary, data -> Binary, state -> Integer, - message_hash -> Nullable, + payload_hash -> Nullable, + post_commit_data -> Nullable, } } @@ -17,6 +18,7 @@ diesel::table! { group_id -> Binary, decrypted_message_bytes -> Binary, sent_at_ns -> BigInt, + kind -> Integer, sender_installation_id -> Binary, sender_wallet_address -> Text, } @@ -46,17 +48,6 @@ diesel::table! { } } -diesel::table! { - outbound_welcome_messages (id) { - id -> Binary, - state -> Integer, - installation_id -> Binary, - commit_hash -> Binary, - group_id -> Binary, - welcome_message -> Binary, - } -} - diesel::table! { topic_refresh_state (topic) { topic -> Text, @@ -66,7 +57,6 @@ diesel::table! { diesel::joinable!(group_intents -> groups (group_id)); diesel::joinable!(group_messages -> groups (group_id)); -diesel::joinable!(outbound_welcome_messages -> groups (group_id)); diesel::allow_tables_to_appear_in_same_query!( group_intents, @@ -74,6 +64,5 @@ diesel::allow_tables_to_appear_in_same_query!( groups, identity, openmls_key_store, - outbound_welcome_messages, topic_refresh_state, ); diff --git a/xmtp_mls/src/storage/encrypted_store/topic_refresh_state.rs b/xmtp_mls/src/storage/encrypted_store/topic_refresh_state.rs index 8ec983525..46b8c13e0 100644 --- a/xmtp_mls/src/storage/encrypted_store/topic_refresh_state.rs +++ b/xmtp_mls/src/storage/encrypted_store/topic_refresh_state.rs @@ -1,5 +1,6 @@ use super::schema::topic_refresh_state; use crate::impl_fetch_and_store; + use diesel::prelude::*; #[derive(Insertable, Identifiable, Queryable, Debug, Clone)] @@ -11,3 +12,4 @@ pub struct TopicRefreshState { } impl_fetch_and_store!(TopicRefreshState, topic_refresh_state); + diff --git a/xmtp_mls/src/storage/mod.rs b/xmtp_mls/src/storage/mod.rs index 994ef2da8..b103bcc8f 100644 --- a/xmtp_mls/src/storage/mod.rs +++ b/xmtp_mls/src/storage/mod.rs @@ -4,8 +4,8 @@ mod serialization; pub mod sql_key_store; pub use encrypted_store::{ - group_intent, group_message, identity, key_store, outbound_welcome_message, - topic_refresh_state, DbConnection, EncryptedMessageStore, EncryptionKey, StorageOption, + group_intent, group_message, identity, key_store_entry, topic_refresh_state, DbConnection, + EncryptedMessageStore, EncryptionKey, StorageOption, }; pub use errors::StorageError; diff --git a/xmtp_mls/src/storage/sql_key_store.rs b/xmtp_mls/src/storage/sql_key_store.rs index 9b839b526..fb915bb3e 100644 --- a/xmtp_mls/src/storage/sql_key_store.rs +++ b/xmtp_mls/src/storage/sql_key_store.rs @@ -5,7 +5,7 @@ use std::borrow::Cow; use crate::{Delete, Fetch, Store}; use super::{ - encrypted_store::key_store::StoredKeyStoreEntry, + encrypted_store::key_store_entry::StoredKeyStoreEntry, serialization::{db_deserialize, db_serialize}, EncryptedMessageStore, StorageError, }; diff --git a/xmtp_mls/src/verified_key_package.rs b/xmtp_mls/src/verified_key_package.rs new file mode 100644 index 000000000..ab7f557be --- /dev/null +++ b/xmtp_mls/src/verified_key_package.rs @@ -0,0 +1,80 @@ +use openmls::prelude::{KeyPackage, KeyPackageIn, KeyPackageVerifyError}; +use openmls_traits::OpenMlsProvider; +use prost::{DecodeError, Message}; +use thiserror::Error; +use tls_codec::{Deserialize, Error as TlsSerializationError}; +use xmtp_proto::xmtp::v3::message_contents::Eip191Association as Eip191AssociationProto; + +use crate::{ + association::{AssociationError, Eip191Association}, + configuration::MLS_PROTOCOL_VERSION, + xmtp_openmls_provider::XmtpOpenMlsProvider, +}; + +#[derive(Debug, Error)] +pub enum KeyPackageVerificationError { + #[error("serialization error: {0}")] + Serialization(#[from] TlsSerializationError), + #[error("mls validation: {0}")] + MlsValidation(#[from] KeyPackageVerifyError), + #[error("association: {0}")] + Association(#[from] AssociationError), + #[error("decode: {0}")] + Decode(#[from] DecodeError), + #[error("generic: {0}")] + Generic(String), +} + +#[derive(Debug, Clone, PartialEq)] +pub struct VerifiedKeyPackage { + pub inner: KeyPackage, + pub wallet_address: String, +} + +impl VerifiedKeyPackage { + pub fn new(inner: KeyPackage, wallet_address: String) -> Self { + Self { + inner, + wallet_address, + } + } + + // Validates starting with a KeyPackage (which is already validated by OpenMLS) + pub fn from_key_package(kp: KeyPackage) -> Result { + let leaf_node = kp.leaf_node(); + let identity_bytes = leaf_node.credential().identity(); + let pub_key_bytes = leaf_node.signature_key().as_slice(); + let wallet_address = identity_to_wallet_address(identity_bytes, pub_key_bytes)?; + + Ok(Self::new(kp, wallet_address)) + } + + // Validates starting with a KeyPackageIn as bytes (which is not validated by OpenMLS) + pub fn from_bytes( + mls_provider: &XmtpOpenMlsProvider, + data: &[u8], + ) -> Result { + let kp_in: KeyPackageIn = KeyPackageIn::tls_deserialize_bytes(data)?; + let kp = kp_in.validate(mls_provider.crypto(), MLS_PROTOCOL_VERSION)?; + + Self::from_key_package(kp) + } + + pub fn installation_id(&self) -> Vec { + self.inner.leaf_node().signature_key().as_slice().to_vec() + } +} + +fn identity_to_wallet_address( + identity_bytes: &[u8], + pub_key_bytes: &[u8], +) -> Result { + let proto_value = Eip191AssociationProto::decode(identity_bytes)?; + let association = Eip191Association::from_proto_with_expected_address( + pub_key_bytes, + proto_value.clone(), + proto_value.wallet_address, + )?; + + Ok(association.address()) +}