Skip to content

Commit

Permalink
Add codecs support and one new codec (#352)
Browse files Browse the repository at this point in the history
  • Loading branch information
neekolas authored Nov 29, 2023
1 parent 6034b38 commit 6626895
Show file tree
Hide file tree
Showing 9 changed files with 1,387 additions and 235 deletions.
80 changes: 80 additions & 0 deletions xmtp_mls/src/codecs/membership_change.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
use std::collections::HashMap;

use prost::Message;
use xmtp_proto::xmtp::mls::message_contents::{
ContentTypeId, EncodedContent, GroupMembershipChange,
};

use super::{CodecError, ContentCodec};

pub struct GroupMembershipChangeCodec {}

impl GroupMembershipChangeCodec {
const AUTHORITY_ID: &'static str = "xmtp.org";
const TYPE_ID: &'static str = "group_membership_change";
}

impl ContentCodec<GroupMembershipChange> for GroupMembershipChangeCodec {
fn content_type() -> ContentTypeId {
ContentTypeId {
authority_id: GroupMembershipChangeCodec::AUTHORITY_ID.to_string(),
type_id: GroupMembershipChangeCodec::TYPE_ID.to_string(),
version_major: 1,
version_minor: 0,
}
}

fn encode(data: GroupMembershipChange) -> Result<EncodedContent, CodecError> {
let mut buf = Vec::new();
data.encode(&mut buf)
.map_err(|e| CodecError::Encode(e.to_string()))?;

Ok(EncodedContent {
r#type: Some(GroupMembershipChangeCodec::content_type()),
parameters: HashMap::new(),
fallback: None,
compression: None,
content: buf,
})
}

fn decode(content: EncodedContent) -> Result<GroupMembershipChange, CodecError> {
let decoded = GroupMembershipChange::decode(content.content.as_slice())
.map_err(|e| CodecError::Decode(e.to_string()))?;

Ok(decoded)
}
}

#[cfg(test)]
mod tests {
use xmtp_proto::xmtp::mls::message_contents::Member;

use crate::utils::test::{rand_string, rand_vec};

use super::*;

#[test]
fn test_encode_decode() {
let new_member = Member {
installation_ids: vec![rand_vec()],
wallet_address: rand_string(),
};
let data = GroupMembershipChange {
members_added: vec![new_member.clone()],
members_removed: vec![],
installations_added: vec![],
installations_removed: vec![],
};

let encoded = GroupMembershipChangeCodec::encode(data).unwrap();
assert_eq!(
encoded.clone().r#type.unwrap().type_id,
"group_membership_change"
);
assert!(encoded.content.len() > 0);

let decoded = GroupMembershipChangeCodec::decode(encoded).unwrap();
assert_eq!(decoded.members_added[0], new_member);
}
}
19 changes: 19 additions & 0 deletions xmtp_mls/src/codecs/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
pub mod membership_change;
pub mod text;

use thiserror::Error;
use xmtp_proto::xmtp::mls::message_contents::{ContentTypeId, EncodedContent};

#[derive(Debug, Error)]
pub enum CodecError {
#[error("encode error {0}")]
Encode(String),
#[error("decode error {0}")]
Decode(String),
}

pub trait ContentCodec<T> {
fn content_type() -> ContentTypeId;
fn encode(content: T) -> Result<EncodedContent, CodecError>;
fn decode(content: EncodedContent) -> Result<T, CodecError>;
}
69 changes: 69 additions & 0 deletions xmtp_mls/src/codecs/text.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
use std::collections::HashMap;

use xmtp_proto::xmtp::mls::message_contents::{ContentTypeId, EncodedContent};

use super::{CodecError, ContentCodec};

pub struct TextCodec {}

impl TextCodec {
const AUTHORITY_ID: &'static str = "xmtp.org";
const TYPE_ID: &'static str = "text";
const ENCODING_KEY: &'static str = "encoding";
const ENCODING_UTF8: &'static str = "UTF-8";
}

impl ContentCodec<String> for TextCodec {
fn content_type() -> ContentTypeId {
ContentTypeId {
authority_id: TextCodec::AUTHORITY_ID.to_string(),
type_id: TextCodec::TYPE_ID.to_string(),
version_major: 1,
version_minor: 0,
}
}

fn encode(text: String) -> Result<EncodedContent, CodecError> {
Ok(EncodedContent {
r#type: Some(TextCodec::content_type()),
parameters: HashMap::from([(
TextCodec::ENCODING_KEY.to_string(),
TextCodec::ENCODING_UTF8.to_string(),
)]),
fallback: None,
compression: None,
content: text.into_bytes(),
})
}

fn decode(content: EncodedContent) -> Result<String, CodecError> {
let encoding = content
.parameters
.get(TextCodec::ENCODING_KEY)
.map_or(TextCodec::ENCODING_UTF8, String::as_str);
if encoding != TextCodec::ENCODING_UTF8 {
return Err(CodecError::Decode(format!(
"Unsupported text encoding {}",
encoding
)));
}
let text = std::str::from_utf8(&content.content)
.map_err(|utf8_err| CodecError::Decode(utf8_err.to_string()))?;
Ok(text.to_string())
}
}

#[cfg(test)]
mod tests {
use crate::codecs::{text::TextCodec, ContentCodec};

#[test]
fn can_encode_and_decode_text() {
let text = "Hello, world!";
let encoded_content =
TextCodec::encode(text.to_string()).expect("Should encode successfully");
let decoded_content =
TextCodec::decode(encoded_content).expect("Should decode successfully");
assert!(decoded_content == text);
}
}
25 changes: 12 additions & 13 deletions xmtp_mls/src/groups/intents.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,11 @@ use prost::{DecodeError, Message};
use thiserror::Error;
use tls_codec::Serialize;
use xmtp_proto::xmtp::mls::database::{
add_members_publish_data::{Version as AddMembersVersion, V1 as AddMembersV1},
add_members_data::{Version as AddMembersVersion, V1 as AddMembersV1},
post_commit_action::{Kind as PostCommitActionKind, SendWelcomes as SendWelcomesProto},
remove_members_publish_data::{Version as RemoveMembersVersion, V1 as RemoveMembersV1},
send_message_publish_data::{Version as SendMessageVersion, V1 as SendMessageV1},
AddMembersPublishData, PostCommitAction as PostCommitActionProto, RemoveMembersPublishData,
SendMessagePublishData,
remove_members_data::{Version as RemoveMembersVersion, V1 as RemoveMembersV1},
send_message_data::{Version as SendMessageVersion, V1 as SendMessageV1},
AddMembersData, PostCommitAction as PostCommitActionProto, RemoveMembersData, SendMessageData,
};

use crate::{
Expand Down Expand Up @@ -40,7 +39,7 @@ impl SendMessageIntentData {

pub(crate) fn to_bytes(&self) -> Vec<u8> {
let mut buf = Vec::new();
SendMessagePublishData {
SendMessageData {
version: Some(SendMessageVersion::V1(SendMessageV1 {
payload_bytes: self.message.clone(),
})),
Expand All @@ -52,7 +51,7 @@ impl SendMessageIntentData {
}

pub(crate) fn from_bytes(data: &[u8]) -> Result<Self, IntentError> {
let msg = SendMessagePublishData::decode(data)?;
let msg = SendMessageData::decode(data)?;
let payload_bytes = match msg.version {
Some(SendMessageVersion::V1(v1)) => v1.payload_bytes,
None => return Err(IntentError::Generic("missing payload".to_string())),
Expand Down Expand Up @@ -86,9 +85,9 @@ impl AddMembersIntentData {
.map(|kp| kp.inner.tls_serialize_detached())
.collect();

AddMembersPublishData {
AddMembersData {
version: Some(AddMembersVersion::V1(AddMembersV1 {
key_packages_bytes_tls_serialized: key_package_bytes_result?,
key_packages_bytes: key_package_bytes_result?,
})),
}
.encode(&mut buf)
Expand All @@ -101,9 +100,9 @@ impl AddMembersIntentData {
data: &[u8],
provider: &XmtpOpenMlsProvider,
) -> Result<Self, IntentError> {
let msg = AddMembersPublishData::decode(data)?;
let msg = AddMembersData::decode(data)?;
let key_package_bytes = match msg.version {
Some(AddMembersVersion::V1(v1)) => v1.key_packages_bytes_tls_serialized,
Some(AddMembersVersion::V1(v1)) => v1.key_packages_bytes,
None => return Err(IntentError::Generic("missing payload".to_string())),
};
let key_packages: Result<Vec<VerifiedKeyPackage>, KeyPackageVerificationError> =
Expand Down Expand Up @@ -138,7 +137,7 @@ impl RemoveMembersIntentData {
pub(crate) fn to_bytes(&self) -> Vec<u8> {
let mut buf = Vec::new();

RemoveMembersPublishData {
RemoveMembersData {
version: Some(RemoveMembersVersion::V1(RemoveMembersV1 {
installation_ids: self.installation_ids.clone(),
})),
Expand All @@ -150,7 +149,7 @@ impl RemoveMembersIntentData {
}

pub(crate) fn from_bytes(data: &[u8]) -> Result<Self, IntentError> {
let msg = RemoveMembersPublishData::decode(data)?;
let msg = RemoveMembersData::decode(data)?;
let installation_ids = match msg.version {
Some(RemoveMembersVersion::V1(v1)) => v1.installation_ids,
None => return Err(IntentError::Generic("missing payload".to_string())),
Expand Down
1 change: 1 addition & 0 deletions xmtp_mls/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ pub mod api_client_wrapper;
pub mod association;
pub mod builder;
pub mod client;
pub mod codecs;
mod configuration;
pub mod groups;
pub mod identity;
Expand Down
Loading

0 comments on commit 6626895

Please sign in to comment.