diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index 8d96444ec..e82d21ec5 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -1,26 +1,32 @@ // For format details, see https://aka.ms/devcontainer.json. For config options, see the // README at: https://github.com/devcontainers/templates/tree/main/src/docker-existing-dockerfile { - "name": "Existing Dockerfile", - "build": { - // Sets the run context to one level up instead of the .devcontainer folder. - "context": "..", - // Update the 'dockerFile' property if you aren't using the standard 'Dockerfile' filename. - "dockerfile": "../Dockerfile" - } + "name": "Existing Dockerfile", + "build": { + // Sets the run context to one level up instead of the .devcontainer folder. + "context": "..", + // Update the 'dockerFile' property if you aren't using the standard 'Dockerfile' filename. + "dockerfile": "../Dockerfile" + }, - // Features to add to the dev container. More info: https://containers.dev/features. - // "features": {}, + // Features to add to the dev container. More info: https://containers.dev/features. + // "features": {}, - // Use 'forwardPorts' to make a list of ports inside the container available locally. - // "forwardPorts": [], + // Use 'forwardPorts' to make a list of ports inside the container available locally. + "forwardPorts": [], + "customizations": { + "vscode": { + "extensions": ["tamasfe.even-better-toml", "rust-lang.rust-analyzer"] + } + }, + "runArgs": ["--network=host"] - // Uncomment the next line to run commands after the container is created. - // "postCreateCommand": "cat /etc/os-release", + // Uncomment the next line to run commands after the container is created. + // "postCreateCommand": "cat /etc/os-release", - // Configure tool-specific properties. - // "customizations": {}, + // Configure tool-specific properties. + // "customizations": {}, - // Uncomment to connect as an existing user other than the container default. More info: https://aka.ms/dev-containers-non-root. - // "remoteUser": "devcontainer" + // Uncomment to connect as an existing user other than the container default. More info: https://aka.ms/dev-containers-non-root. + // "remoteUser": "devcontainer" } diff --git a/Cargo.lock b/Cargo.lock index 6969e84a7..37729ccfe 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3099,10 +3099,10 @@ dependencies = [ "ethers", "hex", "log", - "openmls 0.5.0 (git+https://github.com/openmls/openmls)", - "openmls_basic_credential 0.2.0 (git+https://github.com/openmls/openmls)", - "openmls_rust_crypto 0.2.0 (git+https://github.com/openmls/openmls)", - "openmls_traits 0.2.0 (git+https://github.com/openmls/openmls)", + "openmls 0.5.0 (git+https://github.com/xmtp/openmls)", + "openmls_basic_credential 0.2.0 (git+https://github.com/xmtp/openmls)", + "openmls_rust_crypto 0.2.0 (git+https://github.com/xmtp/openmls)", + "openmls_traits 0.2.0 (git+https://github.com/xmtp/openmls)", "prost", "rand 0.8.5", "serde", @@ -3345,14 +3345,14 @@ dependencies = [ [[package]] name = "openmls" version = "0.5.0" -source = "git+https://github.com/openmls/openmls#c86f4b0d3b146477954322de7aea00a52df0c95c" +source = "git+https://github.com/xmtp/openmls?branch=main#d72380028c6c7e5e73f526a75c6f65bdaa93b6b4" dependencies = [ "backtrace", "itertools 0.10.5", "log", - "openmls_basic_credential 0.2.0 (git+https://github.com/openmls/openmls)", - "openmls_rust_crypto 0.2.0 (git+https://github.com/openmls/openmls)", - "openmls_traits 0.2.0 (git+https://github.com/openmls/openmls)", + "openmls_basic_credential 0.2.0 (git+https://github.com/xmtp/openmls?branch=main)", + "openmls_rust_crypto 0.2.0 (git+https://github.com/xmtp/openmls?branch=main)", + "openmls_traits 0.2.0 (git+https://github.com/xmtp/openmls?branch=main)", "rand 0.8.5", "rayon", "rstest", @@ -3387,10 +3387,10 @@ dependencies = [ [[package]] name = "openmls_basic_credential" version = "0.2.0" -source = "git+https://github.com/openmls/openmls#c86f4b0d3b146477954322de7aea00a52df0c95c" +source = "git+https://github.com/xmtp/openmls?branch=main#d72380028c6c7e5e73f526a75c6f65bdaa93b6b4" dependencies = [ "ed25519-dalek", - "openmls_traits 0.2.0 (git+https://github.com/openmls/openmls)", + "openmls_traits 0.2.0 (git+https://github.com/xmtp/openmls?branch=main)", "p256", "rand 0.8.5", "serde", @@ -3413,9 +3413,9 @@ dependencies = [ [[package]] name = "openmls_memory_keystore" version = "0.2.0" -source = "git+https://github.com/openmls/openmls#c86f4b0d3b146477954322de7aea00a52df0c95c" +source = "git+https://github.com/xmtp/openmls?branch=main#d72380028c6c7e5e73f526a75c6f65bdaa93b6b4" dependencies = [ - "openmls_traits 0.2.0 (git+https://github.com/openmls/openmls)", + "openmls_traits 0.2.0 (git+https://github.com/xmtp/openmls?branch=main)", "serde_json", "thiserror", ] @@ -3433,7 +3433,7 @@ dependencies = [ [[package]] name = "openmls_rust_crypto" version = "0.2.0" -source = "git+https://github.com/openmls/openmls#c86f4b0d3b146477954322de7aea00a52df0c95c" +source = "git+https://github.com/xmtp/openmls?branch=main#d72380028c6c7e5e73f526a75c6f65bdaa93b6b4" dependencies = [ "aes-gcm", "chacha20poly1305", @@ -3443,8 +3443,8 @@ dependencies = [ "hpke-rs", "hpke-rs-crypto", "hpke-rs-rust-crypto", - "openmls_memory_keystore 0.2.0 (git+https://github.com/openmls/openmls)", - "openmls_traits 0.2.0 (git+https://github.com/openmls/openmls)", + "openmls_memory_keystore 0.2.0 (git+https://github.com/xmtp/openmls?branch=main)", + "openmls_traits 0.2.0 (git+https://github.com/xmtp/openmls?branch=main)", "p256", "rand 0.8.5", "rand_chacha 0.3.1", @@ -3481,7 +3481,7 @@ dependencies = [ [[package]] name = "openmls_traits" version = "0.2.0" -source = "git+https://github.com/openmls/openmls#c86f4b0d3b146477954322de7aea00a52df0c95c" +source = "git+https://github.com/xmtp/openmls?branch=main#d72380028c6c7e5e73f526a75c6f65bdaa93b6b4" dependencies = [ "serde", "tls_codec", @@ -6554,10 +6554,10 @@ dependencies = [ "libsqlite3-sys", "log", "mockall", - "openmls 0.5.0 (git+https://github.com/xmtp/openmls)", - "openmls_basic_credential 0.2.0 (git+https://github.com/xmtp/openmls)", - "openmls_rust_crypto 0.2.0 (git+https://github.com/xmtp/openmls)", - "openmls_traits 0.2.0 (git+https://github.com/xmtp/openmls)", + "openmls 0.5.0 (git+https://github.com/xmtp/openmls?branch=main)", + "openmls_basic_credential 0.2.0 (git+https://github.com/xmtp/openmls?branch=main)", + "openmls_rust_crypto 0.2.0 (git+https://github.com/xmtp/openmls?branch=main)", + "openmls_traits 0.2.0 (git+https://github.com/xmtp/openmls?branch=main)", "prost", "rand 0.8.5", "serde", diff --git a/mls_validation_service/Cargo.toml b/mls_validation_service/Cargo.toml index 65f2defa9..58f82f423 100644 --- a/mls_validation_service/Cargo.toml +++ b/mls_validation_service/Cargo.toml @@ -11,11 +11,15 @@ path = "src/main.rs" prost = { version = "0.11", features = ["prost-derive"] } tokio = { version = "1.33.0", features = ["macros", "rt-multi-thread", "full"] } tonic = "^0.9" -xmtp_proto = { path = "../xmtp_proto", features = ["proto_full", "grpc", "tonic"] } -openmls = { git= "https://github.com/openmls/openmls", features = ["test-utils"] } -openmls_traits = { git= "https://github.com/openmls/openmls" } -openmls_basic_credential = { git= "https://github.com/openmls/openmls" } -openmls_rust_crypto = { git= "https://github.com/openmls/openmls" } +xmtp_proto = { path = "../xmtp_proto", features = [ + "proto_full", + "grpc", + "tonic", +] } +openmls = { git = "https://github.com/xmtp/openmls", features = ["test-utils"] } +openmls_traits = { git = "https://github.com/xmtp/openmls" } +openmls_basic_credential = { git = "https://github.com/xmtp/openmls" } +openmls_rust_crypto = { git = "https://github.com/xmtp/openmls" } xmtp_mls = { path = "../xmtp_mls" } serde = "1.0.189" hex = "0.4.3" @@ -25,4 +29,4 @@ env_logger = "0.10.0" [dev-dependencies] ethers = "2.0.10" -rand = "0.8.5" \ No newline at end of file +rand = "0.8.5" diff --git a/mls_validation_service/src/handlers.rs b/mls_validation_service/src/handlers.rs index 9de22f366..4d52117e8 100644 --- a/mls_validation_service/src/handlers.rs +++ b/mls_validation_service/src/handlers.rs @@ -5,6 +5,7 @@ use openmls::{ use openmls_rust_crypto::OpenMlsRustCrypto; use openmls_traits::OpenMlsProvider; use tonic::{Request, Response, Status}; +use xmtp_mls::utils::id::serialize_group_id; use xmtp_proto::xmtp::mls_validation::v1::{ validate_group_messages_response::ValidationResponse as ValidateGroupMessageValidationResponse, validate_key_packages_response::ValidationResponse as ValidateKeyPackageValidationResponse, @@ -12,7 +13,7 @@ use xmtp_proto::xmtp::mls_validation::v1::{ ValidateGroupMessagesResponse, ValidateKeyPackagesRequest, ValidateKeyPackagesResponse, }; -use crate::validation_helpers::{hex_encode, identity_to_wallet_address}; +use crate::validation_helpers::identity_to_wallet_address; #[derive(Debug, Default)] pub struct ValidationService {} @@ -93,9 +94,7 @@ fn validate_group_message(message: Vec) -> Result String { - hex::encode(key) -} - pub fn identity_to_wallet_address(identity: &[u8], pub_key: &[u8]) -> Result { let proto_value = Eip191AssociationProto::decode(identity).map_err(|e| format!("{:?}", e))?; let association = Eip191Association::from_proto_with_expected_address( diff --git a/xmtp_mls/Cargo.toml b/xmtp_mls/Cargo.toml index f25854025..144eb5a79 100644 --- a/xmtp_mls/Cargo.toml +++ b/xmtp_mls/Cargo.toml @@ -19,27 +19,33 @@ native = ["libsqlite3-sys/bundled-sqlcipher-vendored-openssl"] [dependencies] anyhow = "1.0.71" async-trait = "0.1.68" -diesel = { version = "2.1.3", features = ["sqlite", "r2d2", "returning_clauses_for_sqlite_3_35"] } +diesel = { version = "2.1.3", features = [ + "sqlite", + "r2d2", + "returning_clauses_for_sqlite_3_35", +] } diesel_migrations = { version = "2.1.0", features = ["sqlite"] } ethers = "2.0.4" ethers-core = "2.0.4" futures = "0.3.28" hex = "0.4.3" -libsqlite3-sys = { version = "0.26.0", optional = true} -openmls = { git= "https://github.com/xmtp/openmls", features = ["test-utils"] } -openmls_traits = { git= "https://github.com/xmtp/openmls" } -openmls_basic_credential = { git= "https://github.com/xmtp/openmls" } -openmls_rust_crypto = { git= "https://github.com/xmtp/openmls" } +libsqlite3-sys = { version = "0.26.0", optional = true } +openmls = { git = "https://github.com/xmtp/openmls", branch = "main", features = [ + "test-utils", +] } +openmls_traits = { git = "https://github.com/xmtp/openmls", branch = "main" } +openmls_basic_credential = { git = "https://github.com/xmtp/openmls", branch = "main" } +openmls_rust_crypto = { git = "https://github.com/xmtp/openmls", branch = "main" } prost = { version = "0.11", features = ["prost-derive"] } rand = "0.8.5" serde = "1.0.160" serde_json = "1.0.96" thiserror = "1.0.40" -tokio = { version = "1.28.1", features = ["macros"] } +tokio = { version = "1.28.1", features = ["macros"] } log = "0.4.17" -tracing = "0.1.37" +tracing = "0.1.37" toml = "0.7.4" -xmtp_cryptography = { path = "../xmtp_cryptography"} +xmtp_cryptography = { path = "../xmtp_cryptography" } xmtp_proto = { path = "../xmtp_proto", features = ["proto_full"] } tls_codec = "0.3.0" diff --git a/xmtp_mls/src/client.rs b/xmtp_mls/src/client.rs index 56349e039..565f9b908 100644 --- a/xmtp_mls/src/client.rs +++ b/xmtp_mls/src/client.rs @@ -1,18 +1,21 @@ use std::collections::HashSet; -use openmls::prelude::TlsSerializeTrait; +use openmls::{ + framing::{MlsMessageIn, MlsMessageInBody}, + messages::Welcome, + prelude::TlsSerializeTrait, +}; use thiserror::Error; -use tls_codec::Error as TlsSerializationError; +use tls_codec::{Deserialize, Error as TlsSerializationError}; use xmtp_proto::api_client::{XmtpApiClient, XmtpMlsClient}; use crate::{ api_client_wrapper::{ApiClientWrapper, IdentityUpdate}, - builder::{ClientBuilder, IdentityStrategy}, - configuration::KEY_PACKAGE_TOP_UP_AMOUNT, groups::MlsGroup, identity::Identity, storage::{group::GroupMembershipState, DbConnection, EncryptedMessageStore, StorageError}, types::Address, + utils::topic::get_welcome_topic, verified_key_package::{KeyPackageVerificationError, VerifiedKeyPackage}, xmtp_openmls_provider::XmtpOpenMlsProvider, }; @@ -41,6 +44,8 @@ pub enum ClientError { Serialization(#[from] TlsSerializationError), #[error("key package verification: {0}")] KeyPackageVerification(#[from] KeyPackageVerificationError), + #[error("message processing: {0}")] + MessageProcessing(#[from] crate::groups::MessageProcessingError), #[error("generic:{0}")] Generic(String), } @@ -59,10 +64,10 @@ impl From<&str> for ClientError { #[derive(Debug)] pub struct Client { - pub api_client: ApiClientWrapper, + pub(crate) api_client: ApiClientWrapper, pub(crate) _network: Network, pub(crate) identity: Identity, - pub store: EncryptedMessageStore, // Temporarily exposed outside crate for CLI client + pub(crate) store: EncryptedMessageStore, } impl<'a, ApiClient> Client @@ -147,23 +152,6 @@ where Ok(()) } - pub async fn top_up_key_packages(&self) -> Result<(), ClientError> { - let mut connection = self.store.conn()?; - let mls_provider = XmtpOpenMlsProvider::new(&mut connection); - let key_packages: Result>, ClientError> = (0..KEY_PACKAGE_TOP_UP_AMOUNT) - .map(|_| -> Result, ClientError> { - let kp = self.identity.new_key_package(&mls_provider)?; - let kp_bytes = kp.tls_serialize_detached()?; - - Ok(kp_bytes) - }) - .collect(); - - self.api_client.upload_key_packages(key_packages?).await?; - - Ok(()) - } - async fn get_all_active_installation_ids( &self, wallet_addresses: Vec, @@ -228,6 +216,43 @@ where .collect::>()?) } + // Download all unread welcome messages and convert to groups. + // Returns any new groups created in the operation + pub async fn sync_welcomes(&self) -> Result>, ClientError> { + let welcome_topic = get_welcome_topic(&self.installation_public_key()); + let mut conn = self.store.conn()?; + let provider = self.mls_provider(); + // TODO: Use the last_message_timestamp_ns field on the TopicRefreshState to only fetch new messages + // Waiting for more atomic update methods + let envelopes = self.api_client.read_topic(&welcome_topic, 0).await?; + + let groups: Vec> = envelopes + .into_iter() + .filter_map(|envelope| { + // TODO: Wrap in a transaction + let welcome = match extract_welcome(&envelope.message) { + Ok(welcome) => welcome, + Err(err) => { + log::error!("failed to extract welcome: {}", err); + return None; + } + }; + + // TODO: Update last_message_timestamp_ns on success or non-retryable error + // TODO: Abort if error is retryable + match MlsGroup::create_from_welcome(self, &mut conn, &provider, welcome) { + Ok(mls_group) => Some(mls_group), + Err(err) => { + log::error!("failed to create group from welcome: {}", err); + None + } + } + }) + .collect(); + + Ok(groups) + } + pub fn account_address(&self) -> Address { self.identity.account_address.clone() } @@ -237,6 +262,17 @@ where } } +fn extract_welcome(welcome_bytes: &Vec) -> Result { + // let welcome_proto = WelcomeMessageProto::decode(&mut welcome_bytes.as_slice())?; + let welcome = MlsMessageIn::tls_deserialize(&mut welcome_bytes.as_slice())?; + match welcome.extract() { + MlsMessageInBody::Welcome(welcome) => Ok(welcome), + _ => Err(ClientError::Generic( + "unexpected message type in welcome".to_string(), + )), + } +} + #[cfg(test)] mod tests { use xmtp_cryptography::utils::generate_local_wallet; @@ -267,38 +303,6 @@ mod tests { assert_eq!(installation_ids.len(), 1); } - #[tokio::test] - async fn test_top_up_key_packages() { - let wallet = generate_local_wallet(); - let wallet_address = wallet.get_address(); - let client = ClientBuilder::new_test_client(wallet.clone().into()).await; - - client.register_identity().await.unwrap(); - client.top_up_key_packages().await.unwrap(); - - let key_packages = client - .get_key_packages_for_wallet_addresses(vec![wallet_address.clone()]) - .await - .unwrap(); - - assert_eq!(key_packages.len(), 1); - - let key_package = key_packages.first().unwrap(); - assert_eq!(key_package.wallet_address, wallet_address); - - let key_packages_2 = client - .get_key_packages_for_wallet_addresses(vec![wallet_address.clone()]) - .await - .unwrap(); - - assert_eq!(key_packages_2.len(), 1); - - // Ensure we got back different key packages - let key_package_2 = key_packages_2.first().unwrap(); - assert_eq!(key_package_2.wallet_address, wallet_address); - assert!(!(key_package_2.eq(key_package))); - } - #[tokio::test] async fn test_find_groups() { let client = ClientBuilder::new_test_client(generate_local_wallet().into()).await; @@ -310,4 +314,41 @@ mod tests { assert_eq!(groups[0].group_id, group_1.group_id); assert_eq!(groups[1].group_id, group_2.group_id); } + + #[tokio::test] + async fn test_sync_welcomes() { + let alice = ClientBuilder::new_test_client(generate_local_wallet().into()).await; + alice.register_identity().await.unwrap(); + let bob = ClientBuilder::new_test_client(generate_local_wallet().into()).await; + bob.register_identity().await.unwrap(); + + let conn = &mut alice.store.conn().unwrap(); + let alice_bob_group = alice.create_group().unwrap(); + alice_bob_group + .add_members_by_installation_id(vec![bob.installation_public_key()]) + .await + .unwrap(); + + // Manually mark as committed + // TODO: Replace with working synchronization once we can add members end to end + let intents = alice + .store + .find_group_intents(conn, alice_bob_group.group_id.clone(), None, None) + .unwrap(); + let intent = intents.first().unwrap(); + // Set the intent to committed manually + alice + .store + .set_group_intent_committed(conn, intent.id) + .unwrap(); + + alice_bob_group.post_commit(conn).await.unwrap(); + + let bob_received_groups = bob.sync_welcomes().await.unwrap(); + assert_eq!(bob_received_groups.len(), 1); + assert_eq!( + bob_received_groups.first().unwrap().group_id, + alice_bob_group.group_id + ); + } } diff --git a/xmtp_mls/src/configuration.rs b/xmtp_mls/src/configuration.rs index 577acc965..f04ec1547 100644 --- a/xmtp_mls/src/configuration.rs +++ b/xmtp_mls/src/configuration.rs @@ -6,5 +6,3 @@ pub const CIPHERSUITE: Ciphersuite = Ciphersuite::MLS_128_DHKEMX25519_CHACHA20POLY1305_SHA256_Ed25519; pub const MLS_PROTOCOL_VERSION: ProtocolVersion = ProtocolVersion::Mls10; - -pub const KEY_PACKAGE_TOP_UP_AMOUNT: u16 = 100; diff --git a/xmtp_mls/src/groups/mod.rs b/xmtp_mls/src/groups/mod.rs index bb443d620..a3a62a6aa 100644 --- a/xmtp_mls/src/groups/mod.rs +++ b/xmtp_mls/src/groups/mod.rs @@ -1,30 +1,38 @@ mod intents; +#[cfg(test)] +use std::println as debug; + use intents::SendMessageIntentData; +#[cfg(not(test))] +use log::debug; use openmls::{ prelude::{ CredentialWithKey, CryptoConfig, GroupId, LeafNodeIndex, MlsGroup as OpenMlsGroup, - MlsGroupConfig, WireFormatPolicy, + MlsGroupConfig, MlsMessageIn, MlsMessageInBody, PrivateMessageIn, ProcessedMessage, + ProcessedMessageContent, Sender, Welcome as MlsWelcome, WireFormatPolicy, }, prelude_test::KeyPackage, }; use openmls_traits::OpenMlsProvider; +use std::mem::{discriminant, Discriminant}; use thiserror::Error; -use tls_codec::Serialize; -use xmtp_proto::api_client::{XmtpApiClient, XmtpMlsClient}; +use tls_codec::{Deserialize, Serialize}; +use xmtp_proto::api_client::{Envelope, XmtpApiClient, XmtpMlsClient}; use self::intents::{AddMembersIntentData, IntentError, PostCommitAction, RemoveMembersIntentData}; use crate::{ api_client_wrapper::WelcomeMessage, client::ClientError, configuration::CIPHERSUITE, + identity::Identity, storage::{ group::{GroupMembershipState, StoredGroup}, group_intent::{IntentKind, IntentState, NewGroupIntent, StoredGroupIntent}, group_message::{GroupMessageKind, StoredGroupMessage}, DbConnection, EncryptedMessageStore, StorageError, }, - utils::{hash::sha256, time::now_ns, topic::get_group_topic}, + utils::{hash::sha256, id::get_message_id, time::now_ns, topic::get_group_topic}, xmtp_openmls_provider::XmtpOpenMlsProvider, Client, Delete, Store, }; @@ -51,14 +59,35 @@ pub enum GroupError { GroupCreate(#[from] openmls::prelude::NewGroupError), #[error("self update: {0}")] SelfUpdate(#[from] openmls::group::SelfUpdateError), + #[error("welcome error: {0}")] + WelcomeError(#[from] openmls::prelude::WelcomeError), #[error("client: {0}")] Client(#[from] ClientError), + #[error("receive errors: {0:?}")] + ReceiveError(Vec), #[error("generic: {0}")] Generic(String), #[error("diesel error {0}")] Diesel(#[from] diesel::result::Error), } +#[derive(Debug, Error)] +pub enum MessageProcessingError { + #[error("[{message_time_ns:?}] invalid sender with credential: {credential:?}")] + InvalidSender { + message_time_ns: u64, + credential: Vec, + }, + #[error("openmls process message error: {0}")] + OpenMlsProcessMessage(#[from] openmls::prelude::ProcessMessageError), + #[error("storage error: {0}")] + Storage(#[from] crate::storage::StorageError), + #[error("tls deserialization: {0}")] + TlsDeserialization(#[from] tls_codec::Error), + #[error("unsupported message type: {0:?}")] + UnsupportedMessageType(Discriminant), +} + pub struct MlsGroup<'c, ApiClient> { pub group_id: Vec, pub created_at_ns: i64, @@ -99,7 +128,6 @@ where &provider, &client.identity.installation_keys, &build_group_config(), - // TODO: Confirm I should be using the installation keys here CredentialWithKey { credential: client.identity.credential.clone(), signature_key: client.identity.installation_keys.to_public_vec().into(), @@ -114,6 +142,24 @@ where Ok(Self::new(client, group_id, stored_group.created_at_ns)) } + pub fn create_from_welcome( + client: &'c Client, + conn: &mut DbConnection, + provider: &XmtpOpenMlsProvider, + welcome: MlsWelcome, + ) -> Result { + let mut mls_group = + OpenMlsGroup::new_from_welcome(provider, &build_group_config(), welcome, None)?; + mls_group.save(provider.key_store())?; + + let group_id = mls_group.group_id().to_vec(); + let stored_group = + StoredGroup::new(group_id.clone(), now_ns(), GroupMembershipState::Pending); + stored_group.store(conn)?; + + Ok(Self::new(client, group_id, stored_group.created_at_ns)) + } + pub fn find_messages( &self, kind: Option, @@ -134,6 +180,124 @@ where Ok(messages) } + fn validate_message_sender( + &self, + openmls_group: &mut OpenMlsGroup, + decrypted_message: &ProcessedMessage, + envelope_timestamp_ns: u64, + ) -> Result<(String, Vec), MessageProcessingError> { + let mut sender_account_address = None; + let mut sender_installation_id = None; + if let Sender::Member(leaf_node_index) = decrypted_message.sender() { + if let Some(member) = openmls_group.member_at(*leaf_node_index) { + if member.credential.eq(decrypted_message.credential()) { + sender_account_address = Identity::get_validated_account_address( + member.credential.identity(), + &member.signature_key, + ) + .ok(); + sender_installation_id = Some(member.signature_key); + } + } + } + + if sender_account_address.is_none() { + return Err(MessageProcessingError::InvalidSender { + message_time_ns: envelope_timestamp_ns, + credential: decrypted_message.credential().identity().to_vec(), + }); + } + Ok(( + sender_account_address.unwrap(), + sender_installation_id.unwrap(), + )) + } + + fn process_private_message( + &self, + openmls_group: &mut OpenMlsGroup, + provider: &XmtpOpenMlsProvider, + message: PrivateMessageIn, + envelope_timestamp_ns: u64, + ) -> Result<(), MessageProcessingError> { + let decrypted_message = openmls_group.process_message(provider, message)?; + let (sender_account_address, sender_installation_id) = + self.validate_message_sender(openmls_group, &decrypted_message, envelope_timestamp_ns)?; + + match decrypted_message.into_content() { + ProcessedMessageContent::ApplicationMessage(application_message) => { + let message_bytes = application_message.into_bytes(); + let message_id = + get_message_id(&message_bytes, &self.group_id, envelope_timestamp_ns); + let message = StoredGroupMessage { + id: message_id, + group_id: self.group_id.clone(), + decrypted_message_bytes: message_bytes, + sent_at_ns: envelope_timestamp_ns as i64, + kind: GroupMessageKind::Application, + sender_installation_id, + sender_wallet_address: sender_account_address, + }; + message.store(&mut self.client.store.conn()?)?; + } + ProcessedMessageContent::ProposalMessage(_proposal_ptr) => { + // intentionally left blank. + } + ProcessedMessageContent::ExternalJoinProposalMessage(_external_proposal_ptr) => { + // intentionally left blank. + } + ProcessedMessageContent::StagedCommitMessage(_commit_ptr) => { + // intentionally left blank. + } + } + Ok(()) + } + + pub fn process_messages(&self, envelopes: Vec) -> Result<(), GroupError> { + let provider = self.client.mls_provider(); + let mut openmls_group = self.load_mls_group(&provider)?; + let receive_errors: Vec = envelopes + .into_iter() + .map(|envelope| -> Result<(), MessageProcessingError> { + let mls_message_in = MlsMessageIn::tls_deserialize_exact(&envelope.message)?; + + match mls_message_in.extract() { + MlsMessageInBody::PrivateMessage(message) => self.process_private_message( + &mut openmls_group, + &provider, + message, + envelope.timestamp_ns, + ), + other => Err(MessageProcessingError::UnsupportedMessageType( + discriminant(&other), + )), + } + }) + .filter(|result| result.is_err()) + .map(|result| result.unwrap_err()) + .collect(); + openmls_group.save(provider.key_store())?; // TODO handle concurrency + + if receive_errors.is_empty() { + Ok(()) + } else { + Err(GroupError::ReceiveError(receive_errors)) + } + } + + pub async fn receive(&self) -> Result<(), GroupError> { + let topic = get_group_topic(&self.group_id); + let envelopes = self + .client + .api_client + .read_topic( + &topic, 0, // TODO: query from last query point + ) + .await?; + debug!("Received {} envelopes", envelopes.len()); + self.process_messages(envelopes) + } + pub async fn send_message(&self, message: &[u8]) -> Result<(), GroupError> { let mut conn = self.client.store.conn()?; let intent_data: Vec = SendMessageIntentData::new(message.to_vec()).into(); @@ -378,7 +542,8 @@ mod tests { use xmtp_cryptography::utils::generate_local_wallet; use crate::{ - builder::ClientBuilder, storage::EncryptedMessageStore, utils::topic::get_welcome_topic, + builder::ClientBuilder, groups::GroupError, storage::EncryptedMessageStore, + utils::topic::get_welcome_topic, }; #[tokio::test] @@ -399,6 +564,21 @@ mod tests { assert_eq!(messages.len(), 1) } + #[tokio::test] + async fn test_receive_self_message() { + let wallet = generate_local_wallet(); + let client = ClientBuilder::new_test_client(wallet.into()).await; + let group = client.create_group().expect("create group"); + group.send_message(b"hello").await.expect("send message"); + + let result = group.receive().await; + if let GroupError::ReceiveError(errors) = result.err().unwrap() { + assert_eq!(errors.len(), 1); + } else { + panic!("expected GroupError::ReceiveError") + } + } + #[tokio::test] async fn test_add_members() { let client = ClientBuilder::new_test_client(generate_local_wallet().into()).await; diff --git a/xmtp_mls/src/identity.rs b/xmtp_mls/src/identity.rs index ad47654c0..b33ebbc25 100644 --- a/xmtp_mls/src/identity.rs +++ b/xmtp_mls/src/identity.rs @@ -1,6 +1,8 @@ use openmls::{ + extensions::LastResortExtension, prelude::{ - Credential, CredentialType, CredentialWithKey, CryptoConfig, KeyPackage, KeyPackageNewError, + Capabilities, Credential, CredentialType, CredentialWithKey, CryptoConfig, Extension, + ExtensionType, Extensions, KeyPackage, KeyPackageNewError, }, versions::ProtocolVersion, }; @@ -32,6 +34,8 @@ pub enum IdentityError { StorageError(#[from] StorageError), #[error("generating key package")] KeyPackageGenerationError(#[from] KeyPackageNewError), + #[error("deserialization")] + Deserialization(#[from] prost::DecodeError), } #[derive(Debug)] @@ -52,27 +56,13 @@ impl<'a> Identity { let credential = Identity::create_credential(&signature_keys, owner)?; - // The builder automatically stores it in the key store - // TODO: Make OpenMLS not delete this once used - let _last_resort_key_package = KeyPackage::builder().build( - CryptoConfig { - ciphersuite: CIPHERSUITE, - version: ProtocolVersion::default(), - }, - provider, - &signature_keys, - CredentialWithKey { - credential: credential.clone(), - signature_key: signature_keys.to_public_vec().into(), - }, - )?; - let identity = Self { account_address: owner.get_address(), installation_keys: signature_keys, credential, }; + identity.new_key_package(provider)?; StoredIdentity::from(&identity).store(&mut store.conn()?)?; // StoredIdentity::from(&identity).store(*provider.conn().borrow_mut())?; @@ -81,22 +71,37 @@ impl<'a> Identity { Ok(identity) } + // ONLY CREATES LAST RESORT KEY PACKAGES + // TODO: Implement key package rotation https://github.com/xmtp/libxmtp/issues/293 pub(crate) fn new_key_package( &self, provider: &XmtpOpenMlsProvider, ) -> Result { - let kp = KeyPackage::builder().build( - CryptoConfig { - ciphersuite: CIPHERSUITE, - version: ProtocolVersion::default(), - }, - provider, - &self.installation_keys, - CredentialWithKey { - credential: self.credential.clone(), - signature_key: self.installation_keys.to_public_vec().into(), - }, - )?; + let last_resort = Extension::LastResort(LastResortExtension::default()); + let extensions = Extensions::single(last_resort); + let capabilities = Capabilities::new( + None, + Some(&[CIPHERSUITE]), + Some(&[ExtensionType::LastResort]), + None, + None, + ); + // TODO: Set expiration + let kp = KeyPackage::builder() + .leaf_node_capabilities(capabilities) + .key_package_extensions(extensions) + .build( + CryptoConfig { + ciphersuite: CIPHERSUITE, + version: ProtocolVersion::default(), + }, + provider, + &self.installation_keys, + CredentialWithKey { + credential: self.credential.clone(), + signature_key: self.installation_keys.to_public_vec().into(), + }, + )?; Ok(kp) } @@ -119,11 +124,26 @@ impl<'a> Identity { // Serialize into credential Ok(Credential::new(association_proto.encode_to_vec(), CredentialType::Basic).unwrap()) } + + pub(crate) fn get_validated_account_address( + credential: &[u8], + installation_public_key: &[u8], + ) -> Result { + let proto = Eip191AssociationProto::decode(credential)?; + let expected_wallet_address = proto.wallet_address.clone(); + let association = Eip191Association::from_proto_with_expected_address( + installation_public_key, + proto, + expected_wallet_address, + )?; + + Ok(association.address()) + } } #[cfg(test)] mod tests { - use std::error::Error; + use openmls::prelude::ExtensionType; use xmtp_cryptography::utils::generate_local_wallet; use super::Identity; @@ -134,11 +154,19 @@ mod tests { let store = EncryptedMessageStore::new_test(); let mut conn = store.conn().unwrap(); let provider = XmtpOpenMlsProvider::new(&mut conn); - let result = Identity::new(&store, &provider, &generate_local_wallet()); - if let Err(e) = result { - println!("{:?}", e); - println!("{:?}", e.source()); - panic!("d"); - } + Identity::new(&store, &provider, &generate_local_wallet()).unwrap(); + } + + #[test] + fn test_key_package_extensions() { + let store = EncryptedMessageStore::new_test(); + let provider = XmtpOpenMlsProvider::new(&store); + let identity = Identity::new(&store, &provider, &generate_local_wallet()).unwrap(); + + let new_key_package = identity.new_key_package(&provider).unwrap(); + assert!(new_key_package + .extensions() + .contains(ExtensionType::LastResort)); + assert!(new_key_package.last_resort()) } } diff --git a/xmtp_mls/src/utils/id.rs b/xmtp_mls/src/utils/id.rs new file mode 100644 index 000000000..9d3c10a8a --- /dev/null +++ b/xmtp_mls/src/utils/id.rs @@ -0,0 +1,19 @@ +use super::hash::sha256; + +pub fn serialize_group_id(group_id: &[u8]) -> String { + // TODO: I wonder if we really want to be base64 encoding this or if we can treat it as a + // slice + hex::encode(group_id) +} + +pub fn get_message_id( + decrypted_message_bytes: &[u8], + group_id: &[u8], + envelope_timestamp_ns: u64, +) -> Vec { + let mut id_vec = Vec::new(); + id_vec.extend_from_slice(group_id); + id_vec.extend_from_slice(&envelope_timestamp_ns.to_be_bytes()); + id_vec.extend_from_slice(decrypted_message_bytes); + sha256(&id_vec) +} diff --git a/xmtp_mls/src/utils/mod.rs b/xmtp_mls/src/utils/mod.rs index 79f3f4814..97be93fa1 100644 --- a/xmtp_mls/src/utils/mod.rs +++ b/xmtp_mls/src/utils/mod.rs @@ -1,4 +1,5 @@ pub mod hash; +pub mod id; #[cfg(test)] pub mod test; pub mod time; diff --git a/xmtp_mls/src/utils/topic.rs b/xmtp_mls/src/utils/topic.rs index 4685984fc..93743f583 100644 --- a/xmtp_mls/src/utils/topic.rs +++ b/xmtp_mls/src/utils/topic.rs @@ -1,5 +1,7 @@ -pub fn get_group_topic(group_id: &Vec) -> String { - format!("/xmtp/3/g-{}/proto", hex::encode(group_id)) +use crate::utils::id::serialize_group_id; + +pub fn get_group_topic(group_id: &[u8]) -> String { + format!("/xmtp/3/g-{}/proto", serialize_group_id(group_id)) } pub fn get_welcome_topic(installation_id: &Vec) -> String { diff --git a/xmtp_mls/src/verified_key_package.rs b/xmtp_mls/src/verified_key_package.rs index ab7f557be..9cf248600 100644 --- a/xmtp_mls/src/verified_key_package.rs +++ b/xmtp_mls/src/verified_key_package.rs @@ -1,13 +1,11 @@ use openmls::prelude::{KeyPackage, KeyPackageIn, KeyPackageVerifyError}; use openmls_traits::OpenMlsProvider; -use prost::{DecodeError, Message}; use thiserror::Error; use tls_codec::{Deserialize, Error as TlsSerializationError}; -use xmtp_proto::xmtp::v3::message_contents::Eip191Association as Eip191AssociationProto; use crate::{ - association::{AssociationError, Eip191Association}, configuration::MLS_PROTOCOL_VERSION, + identity::{Identity, IdentityError}, xmtp_openmls_provider::XmtpOpenMlsProvider, }; @@ -17,10 +15,8 @@ pub enum KeyPackageVerificationError { Serialization(#[from] TlsSerializationError), #[error("mls validation: {0}")] MlsValidation(#[from] KeyPackageVerifyError), - #[error("association: {0}")] - Association(#[from] AssociationError), - #[error("decode: {0}")] - Decode(#[from] DecodeError), + #[error("identity: {0}")] + Identity(#[from] IdentityError), #[error("generic: {0}")] Generic(String), } @@ -66,15 +62,11 @@ impl VerifiedKeyPackage { } fn identity_to_wallet_address( - identity_bytes: &[u8], - pub_key_bytes: &[u8], + credential_bytes: &[u8], + installation_key_bytes: &[u8], ) -> Result { - let proto_value = Eip191AssociationProto::decode(identity_bytes)?; - let association = Eip191Association::from_proto_with_expected_address( - pub_key_bytes, - proto_value.clone(), - proto_value.wallet_address, - )?; - - Ok(association.address()) + Ok(Identity::get_validated_account_address( + credential_bytes, + installation_key_bytes, + )?) }