Skip to content

Commit

Permalink
Saving content types to db
Browse files Browse the repository at this point in the history
  • Loading branch information
cameronvoell committed Dec 19, 2024
1 parent 6c7910e commit c2c28b4
Show file tree
Hide file tree
Showing 6 changed files with 173 additions and 19 deletions.
2 changes: 2 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions xmtp_content_types/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ license.workspace = true
thiserror = { workspace = true }
prost = { workspace = true, features = ["prost-derive"] }
rand = { workspace = true }
diesel = { workspace = true }
serde = { workspace = true }

# XMTP/Local
xmtp_proto = { workspace = true, features = ["convert"] }
Expand Down
67 changes: 64 additions & 3 deletions xmtp_content_types/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,74 @@ pub mod group_updated;
pub mod membership_change;
pub mod text;

use diesel::{
backend::Backend,
deserialize::{self, FromSql, FromSqlRow},
expression::AsExpression,
serialize::{self, IsNull, Output, ToSql},
sql_types::Integer,
sqlite::Sqlite,
};
use serde::{Deserialize, Serialize};
use thiserror::Error;
use xmtp_proto::xmtp::mls::message_contents::{ContentTypeId, EncodedContent};

/// ContentType and their corresponding string representation
/// are derived from the `ContentTypeId` enum in the xmtp-proto crate
/// that each content type in this crate establishes for itself
#[repr(i32)]
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, FromSqlRow, AsExpression)]
#[diesel(sql_type = diesel::sql_types::Integer)]
pub enum ContentType {
GroupMembershipChange,
GroupUpdated,
Text,
Unknown = 0,
Text = 1,
GroupMembershipChange = 2,
GroupUpdated = 3,
}

impl ContentType {
pub fn from_string(type_id: &str) -> Self {
match type_id {
"text" => Self::Text,
"group_membership_change" => Self::GroupMembershipChange,
"group_updated" => Self::GroupUpdated,
_ => Self::Unknown,
}
}

pub fn to_string(&self) -> &'static str {
match self {
Self::Unknown => "unknown",
Self::Text => "text",
Self::GroupMembershipChange => "group_membership_change",
Self::GroupUpdated => "group_updated",
}
}
}

impl ToSql<Integer, Sqlite> for ContentType
where
i32: ToSql<Integer, Sqlite>,
{
fn to_sql<'b>(&'b self, out: &mut Output<'b, '_, Sqlite>) -> serialize::Result {
out.set_value(*self as i32);
Ok(IsNull::No)
}
}

impl FromSql<Integer, Sqlite> for ContentType
where
i32: FromSql<Integer, Sqlite>,
{
fn from_sql(bytes: <Sqlite as Backend>::RawValue<'_>) -> deserialize::Result<Self> {
match i32::from_sql(bytes)? {
0 => Ok(ContentType::Unknown),
1 => Ok(ContentType::Text),
2 => Ok(ContentType::GroupMembershipChange),
3 => Ok(ContentType::GroupUpdated),
x => Err(format!("Unrecognized variant {}", x).into()),
}
}
}

#[derive(Debug, Error)]
Expand Down
37 changes: 33 additions & 4 deletions xmtp_mls/src/groups/mls_sync.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ use hmac::{Hmac, Mac};
use openmls::{
credentials::BasicCredential,
extensions::Extensions,
framing::{ContentType, ProtocolMessage},
framing::{ContentType as MlsContentType, ProtocolMessage},
group::{GroupEpoch, StagedCommit},
key_packages::KeyPackage,
prelude::{
Expand All @@ -67,7 +67,7 @@ use std::{
use thiserror::Error;
use tracing::debug;
use xmtp_common::{retry_async, Retry, RetryableError};
use xmtp_content_types::{group_updated::GroupUpdatedCodec, CodecError, ContentCodec};
use xmtp_content_types::{group_updated::GroupUpdatedCodec, CodecError, ContentCodec, ContentType};
use xmtp_id::{InboxId, InboxIdRef};
use xmtp_proto::xmtp::mls::{
api::v1::{
Expand Down Expand Up @@ -546,6 +546,7 @@ where
})) => {
let message_id =
calculate_message_id(&self.group_id, &content, &idempotency_key);
let queryable_content_fields = Self::extract_queryable_content_fields(&content);
StoredGroupMessage {
id: message_id,
group_id: self.group_id.clone(),
Expand All @@ -555,6 +556,10 @@ where
sender_installation_id,
sender_inbox_id,
delivery_status: DeliveryStatus::Published,
content_type: queryable_content_fields.content_type,
version_major: queryable_content_fields.version_major,
version_minor: queryable_content_fields.version_minor,
authority_id: queryable_content_fields.authority_id,
}
.store_or_ignore(provider.conn_ref())?
}
Expand Down Expand Up @@ -583,6 +588,10 @@ where
sender_installation_id,
sender_inbox_id: sender_inbox_id.clone(),
delivery_status: DeliveryStatus::Published,
content_type: ContentType::Unknown,
version_major: 0,
version_minor: 0,
authority_id: "unknown".to_string(),
}
.store_or_ignore(provider.conn_ref())?;

Expand Down Expand Up @@ -612,6 +621,10 @@ where
sender_installation_id,
sender_inbox_id,
delivery_status: DeliveryStatus::Published,
content_type: ContentType::Unknown,
version_major: 0,
version_minor: 0,
authority_id: "unknown".to_string(),
}
.store_or_ignore(provider.conn_ref())?;

Expand Down Expand Up @@ -712,7 +725,7 @@ where
discriminant(&other),
)),
}?;
if !allow_epoch_increment && message.content_type() == ContentType::Commit {
if !allow_epoch_increment && message.content_type() == MlsContentType::Commit {
return Err(GroupMessageProcessingError::EpochIncrementNotAllowed);
}

Expand Down Expand Up @@ -933,7 +946,19 @@ where
encoded_payload_bytes.as_slice(),
&timestamp_ns.to_string(),
);

let content_type = match encoded_payload.r#type {
Some(ct) => ct,
None => {
tracing::warn!("Missing content type in encoded payload, using default values");
// Default content type values
xmtp_proto::xmtp::mls::message_contents::ContentTypeId {
authority_id: "unknown".to_string(),
type_id: "unknown".to_string(),
version_major: 0,
version_minor: 0,
}
}
};
let msg = StoredGroupMessage {
id: message_id,
group_id: group_id.to_vec(),
Expand All @@ -943,6 +968,10 @@ where
sender_installation_id,
sender_inbox_id,
delivery_status: DeliveryStatus::Published,
content_type: ContentType::from_string(&content_type.type_id),
version_major: content_type.version_major as i32,
version_minor: content_type.version_minor as i32,
authority_id: content_type.authority_id.to_string(),
};

msg.store_or_ignore(conn)?;
Expand Down
46 changes: 45 additions & 1 deletion xmtp_mls/src/groups/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ use openmls_traits::OpenMlsProvider;
use prost::Message;
use thiserror::Error;
use tokio::sync::Mutex;
use xmtp_content_types::ContentType;

use self::device_sync::DeviceSyncError;
pub use self::group_permissions::PreconfiguredPolicies;
Expand Down Expand Up @@ -67,7 +68,7 @@ use xmtp_proto::xmtp::mls::{
},
message_contents::{
plaintext_envelope::{Content, V1},
PlaintextEnvelope,
EncodedContent, PlaintextEnvelope,
},
};

Expand Down Expand Up @@ -309,6 +310,14 @@ pub enum UpdateAdminListType {
RemoveSuper,
}

/// Fields extracted from content of a message that should be stored in the DB
pub struct QueryableContentFields {
pub content_type: ContentType,
pub version_major: i32,
pub version_minor: i32,
pub authority_id: String,
}

/// Represents a group, which can contain anywhere from 1 to MAX_GROUP_SIZE inboxes.
///
/// This is a wrapper around OpenMLS's `MlsGroup` that handles our application-level configuration
Expand Down Expand Up @@ -706,6 +715,36 @@ impl<ScopedClient: ScopedGroupClient> MlsGroup<ScopedClient> {
Ok(message_id)
}

/// Helper function to extract queryable content fields from a message
fn extract_queryable_content_fields(message: &[u8]) -> QueryableContentFields {
let default = QueryableContentFields {
content_type: ContentType::Unknown,
version_major: 0,
version_minor: 0,
authority_id: "unknown".to_string(),
};

// Return early with default if decoding fails or type is missing
let content_type_id = match EncodedContent::decode(message)
.map_err(|e| tracing::debug!("Failed to decode message as EncodedContent: {}", e))
.ok()
.and_then(|content| content.r#type)
{
Some(type_id) => type_id,
None => {
tracing::debug!("Message content type is missing");
return default;
}
};

QueryableContentFields {
content_type: ContentType::from_string(&content_type_id.type_id),
version_major: content_type_id.version_major as i32,
version_minor: content_type_id.version_minor as i32,
authority_id: content_type_id.authority_id.to_string(),
}
}

/// Prepare a [`IntentKind::SendMessage`] intent, and [`StoredGroupMessage`] on this users XMTP [`Client`].
///
/// # Arguments
Expand Down Expand Up @@ -734,6 +773,7 @@ impl<ScopedClient: ScopedGroupClient> MlsGroup<ScopedClient> {

// store this unpublished message locally before sending
let message_id = calculate_message_id(&self.group_id, message, &now.to_string());
let queryable_content_fields = Self::extract_queryable_content_fields(message);
let group_message = StoredGroupMessage {
id: message_id.clone(),
group_id: self.group_id.clone(),
Expand All @@ -743,6 +783,10 @@ impl<ScopedClient: ScopedGroupClient> MlsGroup<ScopedClient> {
sender_installation_id: self.context().installation_public_key().into(),
sender_inbox_id: self.context().inbox_id().to_string(),
delivery_status: DeliveryStatus::Unpublished,
content_type: queryable_content_fields.content_type,
version_major: queryable_content_fields.version_major,
version_minor: queryable_content_fields.version_minor,
authority_id: queryable_content_fields.authority_id,
};
group_message.store(provider.conn_ref())?;

Expand Down
Loading

0 comments on commit c2c28b4

Please sign in to comment.