From c2c28b4273aecbe80301a08f3f91410b1ff10045 Mon Sep 17 00:00:00 2001 From: cameronvoell Date: Wed, 18 Dec 2024 16:21:04 -0800 Subject: [PATCH] Saving content types to db --- Cargo.lock | 2 + xmtp_content_types/Cargo.toml | 2 + xmtp_content_types/src/lib.rs | 67 ++++++++++++++++++- xmtp_mls/src/groups/mls_sync.rs | 37 ++++++++-- xmtp_mls/src/groups/mod.rs | 46 ++++++++++++- .../storage/encrypted_store/group_message.rs | 38 ++++++++--- 6 files changed, 173 insertions(+), 19 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index cf5f4a245..0908ef7f8 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -7319,8 +7319,10 @@ dependencies = [ name = "xmtp_content_types" version = "0.1.0" dependencies = [ + "diesel", "prost", "rand", + "serde", "thiserror 2.0.6", "tonic", "xmtp_common", diff --git a/xmtp_content_types/Cargo.toml b/xmtp_content_types/Cargo.toml index 2b7c506d1..52a4599d5 100644 --- a/xmtp_content_types/Cargo.toml +++ b/xmtp_content_types/Cargo.toml @@ -8,6 +8,8 @@ license.workspace = true thiserror = { workspace = true } prost = { workspace = true, features = ["prost-derive"] } rand = { workspace = true } +diesel = { workspace = true } +serde = { workspace = true } # XMTP/Local xmtp_proto = { workspace = true, features = ["convert"] } diff --git a/xmtp_content_types/src/lib.rs b/xmtp_content_types/src/lib.rs index 04a7a3fb9..ce89ed9f7 100644 --- a/xmtp_content_types/src/lib.rs +++ b/xmtp_content_types/src/lib.rs @@ -2,13 +2,74 @@ pub mod group_updated; pub mod membership_change; pub mod text; +use diesel::{ + backend::Backend, + deserialize::{self, FromSql, FromSqlRow}, + expression::AsExpression, + serialize::{self, IsNull, Output, ToSql}, + sql_types::Integer, + sqlite::Sqlite, +}; +use serde::{Deserialize, Serialize}; use thiserror::Error; use xmtp_proto::xmtp::mls::message_contents::{ContentTypeId, EncodedContent}; +/// ContentType and their corresponding string representation +/// are derived from the `ContentTypeId` enum in the xmtp-proto crate +/// that each content type in this crate establishes for itself +#[repr(i32)] +#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, FromSqlRow, AsExpression)] +#[diesel(sql_type = diesel::sql_types::Integer)] pub enum ContentType { - GroupMembershipChange, - GroupUpdated, - Text, + Unknown = 0, + Text = 1, + GroupMembershipChange = 2, + GroupUpdated = 3, +} + +impl ContentType { + pub fn from_string(type_id: &str) -> Self { + match type_id { + "text" => Self::Text, + "group_membership_change" => Self::GroupMembershipChange, + "group_updated" => Self::GroupUpdated, + _ => Self::Unknown, + } + } + + pub fn to_string(&self) -> &'static str { + match self { + Self::Unknown => "unknown", + Self::Text => "text", + Self::GroupMembershipChange => "group_membership_change", + Self::GroupUpdated => "group_updated", + } + } +} + +impl ToSql for ContentType +where + i32: ToSql, +{ + fn to_sql<'b>(&'b self, out: &mut Output<'b, '_, Sqlite>) -> serialize::Result { + out.set_value(*self as i32); + Ok(IsNull::No) + } +} + +impl FromSql for ContentType +where + i32: FromSql, +{ + fn from_sql(bytes: ::RawValue<'_>) -> deserialize::Result { + match i32::from_sql(bytes)? { + 0 => Ok(ContentType::Unknown), + 1 => Ok(ContentType::Text), + 2 => Ok(ContentType::GroupMembershipChange), + 3 => Ok(ContentType::GroupUpdated), + x => Err(format!("Unrecognized variant {}", x).into()), + } + } } #[derive(Debug, Error)] diff --git a/xmtp_mls/src/groups/mls_sync.rs b/xmtp_mls/src/groups/mls_sync.rs index 87109e81a..044e164ac 100644 --- a/xmtp_mls/src/groups/mls_sync.rs +++ b/xmtp_mls/src/groups/mls_sync.rs @@ -44,7 +44,7 @@ use hmac::{Hmac, Mac}; use openmls::{ credentials::BasicCredential, extensions::Extensions, - framing::{ContentType, ProtocolMessage}, + framing::{ContentType as MlsContentType, ProtocolMessage}, group::{GroupEpoch, StagedCommit}, key_packages::KeyPackage, prelude::{ @@ -67,7 +67,7 @@ use std::{ use thiserror::Error; use tracing::debug; use xmtp_common::{retry_async, Retry, RetryableError}; -use xmtp_content_types::{group_updated::GroupUpdatedCodec, CodecError, ContentCodec}; +use xmtp_content_types::{group_updated::GroupUpdatedCodec, CodecError, ContentCodec, ContentType}; use xmtp_id::{InboxId, InboxIdRef}; use xmtp_proto::xmtp::mls::{ api::v1::{ @@ -546,6 +546,7 @@ where })) => { let message_id = calculate_message_id(&self.group_id, &content, &idempotency_key); + let queryable_content_fields = Self::extract_queryable_content_fields(&content); StoredGroupMessage { id: message_id, group_id: self.group_id.clone(), @@ -555,6 +556,10 @@ where sender_installation_id, sender_inbox_id, delivery_status: DeliveryStatus::Published, + content_type: queryable_content_fields.content_type, + version_major: queryable_content_fields.version_major, + version_minor: queryable_content_fields.version_minor, + authority_id: queryable_content_fields.authority_id, } .store_or_ignore(provider.conn_ref())? } @@ -583,6 +588,10 @@ where sender_installation_id, sender_inbox_id: sender_inbox_id.clone(), delivery_status: DeliveryStatus::Published, + content_type: ContentType::Unknown, + version_major: 0, + version_minor: 0, + authority_id: "unknown".to_string(), } .store_or_ignore(provider.conn_ref())?; @@ -612,6 +621,10 @@ where sender_installation_id, sender_inbox_id, delivery_status: DeliveryStatus::Published, + content_type: ContentType::Unknown, + version_major: 0, + version_minor: 0, + authority_id: "unknown".to_string(), } .store_or_ignore(provider.conn_ref())?; @@ -712,7 +725,7 @@ where discriminant(&other), )), }?; - if !allow_epoch_increment && message.content_type() == ContentType::Commit { + if !allow_epoch_increment && message.content_type() == MlsContentType::Commit { return Err(GroupMessageProcessingError::EpochIncrementNotAllowed); } @@ -933,7 +946,19 @@ where encoded_payload_bytes.as_slice(), ×tamp_ns.to_string(), ); - + let content_type = match encoded_payload.r#type { + Some(ct) => ct, + None => { + tracing::warn!("Missing content type in encoded payload, using default values"); + // Default content type values + xmtp_proto::xmtp::mls::message_contents::ContentTypeId { + authority_id: "unknown".to_string(), + type_id: "unknown".to_string(), + version_major: 0, + version_minor: 0, + } + } + }; let msg = StoredGroupMessage { id: message_id, group_id: group_id.to_vec(), @@ -943,6 +968,10 @@ where sender_installation_id, sender_inbox_id, delivery_status: DeliveryStatus::Published, + content_type: ContentType::from_string(&content_type.type_id), + version_major: content_type.version_major as i32, + version_minor: content_type.version_minor as i32, + authority_id: content_type.authority_id.to_string(), }; msg.store_or_ignore(conn)?; diff --git a/xmtp_mls/src/groups/mod.rs b/xmtp_mls/src/groups/mod.rs index 35662bc7d..e23362c08 100644 --- a/xmtp_mls/src/groups/mod.rs +++ b/xmtp_mls/src/groups/mod.rs @@ -35,6 +35,7 @@ use openmls_traits::OpenMlsProvider; use prost::Message; use thiserror::Error; use tokio::sync::Mutex; +use xmtp_content_types::ContentType; use self::device_sync::DeviceSyncError; pub use self::group_permissions::PreconfiguredPolicies; @@ -67,7 +68,7 @@ use xmtp_proto::xmtp::mls::{ }, message_contents::{ plaintext_envelope::{Content, V1}, - PlaintextEnvelope, + EncodedContent, PlaintextEnvelope, }, }; @@ -309,6 +310,14 @@ pub enum UpdateAdminListType { RemoveSuper, } +/// Fields extracted from content of a message that should be stored in the DB +pub struct QueryableContentFields { + pub content_type: ContentType, + pub version_major: i32, + pub version_minor: i32, + pub authority_id: String, +} + /// Represents a group, which can contain anywhere from 1 to MAX_GROUP_SIZE inboxes. /// /// This is a wrapper around OpenMLS's `MlsGroup` that handles our application-level configuration @@ -706,6 +715,36 @@ impl MlsGroup { Ok(message_id) } + /// Helper function to extract queryable content fields from a message + fn extract_queryable_content_fields(message: &[u8]) -> QueryableContentFields { + let default = QueryableContentFields { + content_type: ContentType::Unknown, + version_major: 0, + version_minor: 0, + authority_id: "unknown".to_string(), + }; + + // Return early with default if decoding fails or type is missing + let content_type_id = match EncodedContent::decode(message) + .map_err(|e| tracing::debug!("Failed to decode message as EncodedContent: {}", e)) + .ok() + .and_then(|content| content.r#type) + { + Some(type_id) => type_id, + None => { + tracing::debug!("Message content type is missing"); + return default; + } + }; + + QueryableContentFields { + content_type: ContentType::from_string(&content_type_id.type_id), + version_major: content_type_id.version_major as i32, + version_minor: content_type_id.version_minor as i32, + authority_id: content_type_id.authority_id.to_string(), + } + } + /// Prepare a [`IntentKind::SendMessage`] intent, and [`StoredGroupMessage`] on this users XMTP [`Client`]. /// /// # Arguments @@ -734,6 +773,7 @@ impl MlsGroup { // store this unpublished message locally before sending let message_id = calculate_message_id(&self.group_id, message, &now.to_string()); + let queryable_content_fields = Self::extract_queryable_content_fields(message); let group_message = StoredGroupMessage { id: message_id.clone(), group_id: self.group_id.clone(), @@ -743,6 +783,10 @@ impl MlsGroup { sender_installation_id: self.context().installation_public_key().into(), sender_inbox_id: self.context().inbox_id().to_string(), delivery_status: DeliveryStatus::Unpublished, + content_type: queryable_content_fields.content_type, + version_major: queryable_content_fields.version_major, + version_minor: queryable_content_fields.version_minor, + authority_id: queryable_content_fields.authority_id, }; group_message.store(provider.conn_ref())?; diff --git a/xmtp_mls/src/storage/encrypted_store/group_message.rs b/xmtp_mls/src/storage/encrypted_store/group_message.rs index 743800d79..75ee41213 100644 --- a/xmtp_mls/src/storage/encrypted_store/group_message.rs +++ b/xmtp_mls/src/storage/encrypted_store/group_message.rs @@ -7,6 +7,7 @@ use diesel::{ sql_types::Integer, }; use serde::{Deserialize, Serialize}; +use xmtp_content_types::ContentType; use super::{ db_connection::DbConnection, @@ -38,6 +39,14 @@ pub struct StoredGroupMessage { pub sender_inbox_id: String, /// We optimistically store messages before sending. pub delivery_status: DeliveryStatus, + /// The Content Type of the message + pub content_type: ContentType, + /// The content type version major + pub version_major: i32, + /// The content type version minor + pub version_minor: i32, + /// The ID of the authority defining the content type + pub authority_id: String, } #[derive(Clone, Debug, PartialEq)] @@ -294,6 +303,7 @@ pub(crate) mod tests { kind: Option, group_id: Option<&[u8]>, sent_at_ns: Option, + content_type: Option, ) -> StoredGroupMessage { StoredGroupMessage { id: rand_vec::<24>(), @@ -304,6 +314,10 @@ pub(crate) mod tests { sender_inbox_id: "0x0".to_string(), kind: kind.unwrap_or(GroupMessageKind::Application), delivery_status: DeliveryStatus::Unpublished, + content_type: content_type.unwrap_or(ContentType::Unknown), + version_major: 0, + version_minor: 0, + authority_id: "unknown".to_string(), } } @@ -320,7 +334,7 @@ pub(crate) mod tests { async fn it_gets_messages() { with_connection(|conn| { let group = generate_group(None); - let message = generate_message(None, Some(&group.id), None); + let message = generate_message(None, Some(&group.id), None, None); group.store(conn).unwrap(); let id = message.id.clone(); @@ -337,7 +351,7 @@ pub(crate) mod tests { use diesel::result::{DatabaseErrorKind::ForeignKeyViolation, Error::DatabaseError}; with_connection(|conn| { - let message = generate_message(None, None, None); + let message = generate_message(None, None, None, None); assert_err!( message.store(conn), StorageError::DieselResult(DatabaseError(ForeignKeyViolation, _)) @@ -355,7 +369,7 @@ pub(crate) mod tests { group.store(conn).unwrap(); for idx in 0..50 { - let msg = generate_message(None, Some(&group.id), Some(idx)); + let msg = generate_message(None, Some(&group.id), Some(idx), None); assert_ok!(msg.store(conn)); } @@ -388,10 +402,10 @@ pub(crate) mod tests { group.store(conn).unwrap(); let messages = vec![ - generate_message(None, Some(&group.id), Some(1_000)), - generate_message(None, Some(&group.id), Some(100_000)), - generate_message(None, Some(&group.id), Some(10_000)), - generate_message(None, Some(&group.id), Some(1_000_000)), + generate_message(None, Some(&group.id), Some(1_000), None), + generate_message(None, Some(&group.id), Some(100_000), None), + generate_message(None, Some(&group.id), Some(10_000), None), + generate_message(None, Some(&group.id), Some(1_000_000), None), ]; assert_ok!(messages.store(conn)); let message = conn @@ -432,6 +446,7 @@ pub(crate) mod tests { Some(GroupMessageKind::Application), Some(&group.id), None, + Some(ContentType::Text), ); msg.store(conn).unwrap(); } @@ -440,6 +455,7 @@ pub(crate) mod tests { Some(GroupMessageKind::MembershipChange), Some(&group.id), None, + Some(ContentType::GroupMembershipChange), ); msg.store(conn).unwrap(); } @@ -472,10 +488,10 @@ pub(crate) mod tests { group.store(conn).unwrap(); let messages = vec![ - generate_message(None, Some(&group.id), Some(10_000)), - generate_message(None, Some(&group.id), Some(1_000)), - generate_message(None, Some(&group.id), Some(100_000)), - generate_message(None, Some(&group.id), Some(1_000_000)), + generate_message(None, Some(&group.id), Some(10_000), None), + generate_message(None, Some(&group.id), Some(1_000), None), + generate_message(None, Some(&group.id), Some(100_000), None), + generate_message(None, Some(&group.id), Some(1_000_000), None), ]; assert_ok!(messages.store(conn));