diff --git a/xmtp/src/storage/encrypted_store/mod.rs b/xmtp/src/storage/encrypted_store/mod.rs index ec51fac07..002550d6e 100644 --- a/xmtp/src/storage/encrypted_store/mod.rs +++ b/xmtp/src/storage/encrypted_store/mod.rs @@ -9,7 +9,6 @@ //! if there are any outstanding database migrations and perform them as needed. When updating the //! table definitions `schema.rs` must also be updated. To generate the correct schemas you can run //! `diesel print-schema` or use `cargo run update-schema` which will update the files for you. -//! pub mod models; pub mod schema; diff --git a/xmtp_mls/src/builder.rs b/xmtp_mls/src/builder.rs index d40fa89c2..12bed8855 100644 --- a/xmtp_mls/src/builder.rs +++ b/xmtp_mls/src/builder.rs @@ -55,7 +55,7 @@ where provider: &XmtpOpenMlsProvider, ) -> Result { let identity_option: Option = - store.conn()?.fetch(())?.map(|i: StoredIdentity| i.into()); + store.conn()?.fetch(&())?.map(|i: StoredIdentity| i.into()); debug!("Existing identity in store: {:?}", identity_option); match self { IdentityStrategy::CachedOnly => { diff --git a/xmtp_mls/src/lib.rs b/xmtp_mls/src/lib.rs index 3c34623fc..294dbde8d 100644 --- a/xmtp_mls/src/lib.rs +++ b/xmtp_mls/src/lib.rs @@ -28,7 +28,7 @@ pub trait Store { pub trait Fetch { type Key; - fn fetch(&mut self, key: Self::Key) -> Result, StorageError>; + fn fetch(&mut self, key: &Self::Key) -> Result, StorageError>; } pub trait Delete { @@ -47,4 +47,54 @@ mod tests { tracing_subscriber::fmt::init(); }) } + + /// wrapper over assert!(matches!()) for Errors + /// assert_err!(fun(), StorageError::Explosion) + /// + /// or the message variant, + /// assert_err!(fun(), StorageError::Explosion, "the storage did not explode"); + #[macro_export] + macro_rules! assert_err { + ( $x:expr , $y:pat $(,)? ) => { + assert!(matches!($x, Err($y))); + }; + + ( $x:expr, $y:pat $(,)?, $($msg:tt)+) => {{ + assert!(matches!($x, Err($y)), $($msg)+) + }} + } + + /// wrapper over assert! macros for Ok's + /// + /// Make sure something is Ok(_) without caring about return value. + /// assert_ok!(fun()); + /// + /// Against an expected value, e.g Ok(true) + /// assert_ok!(fun(), true); + /// + /// or the message variant, + /// assert_ok!(fun(), Ok(_), "the storage is not ok"); + #[macro_export] + macro_rules! assert_ok { + + ( $e:expr ) => { + assert_ok!($e,) + }; + + ( $e:expr, ) => {{ + use std::result::Result::*; + match $e { + Ok(v) => v, + Err(e) => panic!("assertion failed: Err({:?})", e), + } + }}; + + ( $x:expr , $y:expr $(,)? ) => { + assert_eq!($x, Ok($y.into())); + }; + + ( $x:expr, $y:expr $(,)?, $($msg:tt)+) => {{ + assert_eq!($x, Ok($y.into()), $($msg)+); + }} + } } diff --git a/xmtp_mls/src/storage/encrypted_store/group.rs b/xmtp_mls/src/storage/encrypted_store/group.rs index fb96bd803..3caaaef47 100644 --- a/xmtp_mls/src/storage/encrypted_store/group.rs +++ b/xmtp_mls/src/storage/encrypted_store/group.rs @@ -10,9 +10,8 @@ use diesel::{ sqlite::Sqlite, }; -use super::schema::groups; -use crate::impl_fetch; -use crate::impl_store; +use super::{schema::groups, DbConnection, EncryptedMessageStore}; +use crate::{impl_fetch, impl_store, StorageError}; /// The Group ID type. pub type ID = Vec; @@ -43,6 +42,24 @@ impl StoredGroup { } } +impl EncryptedMessageStore { + /// Updates group membership state + pub fn update_group_membership>( + &self, + conn: &mut DbConnection, + id: GroupId, + state: GroupMembershipState, + ) -> Result<(), StorageError> { + use super::schema::groups::dsl; + + diesel::update(dsl::groups.find(id.as_ref())) + .set(dsl::membership_state.eq(state)) + .execute(conn)?; + + Ok(()) + } +} + #[repr(i32)] #[derive(Debug, Clone, Copy, PartialEq, Eq, AsExpression, FromSqlRow)] #[diesel(sql_type = Integer)] @@ -81,18 +98,30 @@ where } #[cfg(test)] -mod tests { +pub(crate) mod tests { use super::*; use crate::{ - storage::encrypted_store::{schema::groups::dsl::groups, tests::with_store}, + assert_ok, + storage::encrypted_store::{ + schema::groups::dsl::groups, + tests::{rand_time, rand_vec, with_store}, + }, Fetch, Store, }; + /// Generate a test group + pub fn generate_group(state: Option) -> StoredGroup { + StoredGroup { + id: rand_vec(), + created_at_ns: rand_time(), + membership_state: state.unwrap_or(GroupMembershipState::Allowed), + } + } + #[test] fn it_stores_group() { - with_store(|store| { - let mut conn = store.conn().unwrap(); - let test_group = StoredGroup::new(vec![0x0], 100, GroupMembershipState::Allowed); + with_store(|_, mut conn| { + let test_group = generate_group(None); test_group.store(&mut conn).unwrap(); assert_eq!(groups.first::(&mut conn).unwrap(), test_group); @@ -101,29 +130,30 @@ mod tests { #[test] fn it_fetches_group() { - with_store(|store| { - let mut conn = store.conn().unwrap(); - let test_group = StoredGroup::new(vec![0x0], 100, GroupMembershipState::Allowed); + with_store(|_, mut conn| { + let test_group = generate_group(None); + diesel::insert_into(groups) .values(test_group.clone()) .execute(&mut conn) .unwrap(); - let fetched_group = conn.fetch(vec![0x0]).ok().flatten().unwrap(); - assert_eq!(test_group, fetched_group); + + let fetched_group = Fetch::::fetch(&mut conn, &test_group.id); + assert_ok!(fetched_group, Some(test_group)); }) } #[test] fn it_updates_group_membership_state() { - with_store(|store| { - let id = vec![0x0]; - let mut conn = store.conn().unwrap(); - let test_group = StoredGroup::new(id.clone(), 100, GroupMembershipState::Pending); + with_store(|store, mut conn| { + let test_group = generate_group(Some(GroupMembershipState::Pending)); test_group.store(&mut conn).unwrap(); - let updated_group = store - .update_group_membership(&mut conn, id, GroupMembershipState::Rejected) + store + .update_group_membership(&mut conn, &test_group.id, GroupMembershipState::Rejected) .unwrap(); + + let updated_group: StoredGroup = conn.fetch(&test_group.id).ok().flatten().unwrap(); assert_eq!( updated_group, StoredGroup { diff --git a/xmtp_mls/src/storage/encrypted_store/group_intent.rs b/xmtp_mls/src/storage/encrypted_store/group_intent.rs index f4cdf85c6..512584a83 100644 --- a/xmtp_mls/src/storage/encrypted_store/group_intent.rs +++ b/xmtp_mls/src/storage/encrypted_store/group_intent.rs @@ -1,19 +1,20 @@ -use super::schema::group_intents::dsl; -use super::{group, schema::group_intents}; -use super::{DbConnection, EncryptedMessageStore}; -use crate::storage::StorageError; -use crate::{impl_fetch, impl_store}; - -use diesel::prelude::*; use diesel::{ backend::Backend, deserialize::{self, FromSql, FromSqlRow}, expression::AsExpression, + prelude::*, serialize::{self, IsNull, Output, ToSql}, sql_types::Integer, sqlite::Sqlite, }; +use super::{ + group, + schema::{group_intents, group_intents::dsl}, + DbConnection, EncryptedMessageStore, +}; +use crate::{impl_fetch, impl_store, storage::StorageError}; + pub type ID = i32; #[repr(i32)] @@ -98,7 +99,8 @@ impl EncryptedMessageStore { Ok(query.load::(conn)?) } - // Set the intent with the given ID to `Published` and set the payload hash. Optionally add `post_commit_data` + // Set the intent with the given ID to `Published` and set the payload hash. Optionally add + // `post_commit_data` pub fn set_group_intent_published( &self, conn: &mut DbConnection, @@ -108,7 +110,8 @@ impl EncryptedMessageStore { ) -> Result<(), StorageError> { let res = diesel::update(dsl::group_intents) .filter(dsl::id.eq(intent_id)) - // State machine requires that the only valid state transition to Published is from ToPublish + // State machine requires that the only valid state transition to Published is from + // ToPublish .filter(dsl::state.eq(IntentState::ToPublish)) .set(( dsl::state.eq(IntentState::Published), @@ -132,7 +135,8 @@ impl EncryptedMessageStore { ) -> Result<(), StorageError> { let res = diesel::update(dsl::group_intents) .filter(dsl::id.eq(intent_id)) - // State machine requires that the only valid state transition to Committed is from Published + // State machine requires that the only valid state transition to Committed is from + // Published .filter(dsl::state.eq(IntentState::Published)) .set(dsl::state.eq(IntentState::Committed)) .execute(conn)?; @@ -144,7 +148,8 @@ impl EncryptedMessageStore { } } - // Set the intent with the given ID to `ToPublish`. Wipe any values for `payload_hash` and `post_commit_data` + // Set the intent with the given ID to `ToPublish`. Wipe any values for `payload_hash` and + // `post_commit_data` pub fn set_group_intent_to_publish( &self, conn: &mut DbConnection, @@ -152,7 +157,8 @@ impl EncryptedMessageStore { ) -> Result<(), StorageError> { let res = diesel::update(dsl::group_intents) .filter(dsl::id.eq(intent_id)) - // State machine requires that the only valid state transition to ToPublish is from Published + // State machine requires that the only valid state transition to ToPublish is from + // Published .filter(dsl::state.eq(IntentState::Published)) .set(( dsl::state.eq(IntentState::ToPublish), @@ -169,7 +175,8 @@ impl EncryptedMessageStore { } } - // Simple lookup of intents by payload hash, meant to be used when processing messages off the network + // Simple lookup of intents by payload hash, meant to be used when processing messages off the + // network pub fn find_group_intent_by_payload_hash( &self, conn: &mut DbConnection, @@ -236,9 +243,13 @@ where #[cfg(test)] mod tests { use super::*; - use crate::storage::encrypted_store::group::{GroupMembershipState, StoredGroup}; - use crate::storage::encrypted_store::tests::{rand_vec, with_store}; - use crate::{Fetch, Store}; + use crate::{ + storage::encrypted_store::{ + group::{GroupMembershipState, StoredGroup}, + tests::{rand_vec, with_store}, + }, + Fetch, Store, + }; fn insert_group(conn: &mut DbConnection, group_id: Vec) { let group = StoredGroup::new(group_id, 100, GroupMembershipState::Allowed); @@ -246,7 +257,8 @@ mod tests { } impl NewGroupIntent { - // Real group intents must always start as ToPublish. But for tests we allow forcing the state + // Real group intents must always start as ToPublish. But for tests we allow forcing the + // state pub fn new_test( kind: IntentKind, group_id: Vec, @@ -278,9 +290,7 @@ mod tests { let to_insert = NewGroupIntent::new_test(kind, group_id.clone(), data.clone(), state); - with_store(|store| { - let mut conn = store.conn().unwrap(); - + with_store(|store, mut conn| { // Group needs to exist or FK constraint will fail insert_group(&mut conn, group_id.clone()); @@ -302,7 +312,7 @@ mod tests { let id = results[0].id; - let fetched: StoredGroupIntent = conn.fetch(id).unwrap().unwrap(); + let fetched: StoredGroupIntent = conn.fetch(&id).unwrap().unwrap(); assert_eq!(fetched.id, id); }) @@ -333,9 +343,7 @@ mod tests { ), ]; - with_store(|store| { - let mut conn = store.conn().unwrap(); - + with_store(|store, mut conn| { // Group needs to exist or FK constraint will fail insert_group(&mut conn, group_id.clone()); @@ -402,9 +410,7 @@ mod tests { fn find_by_payload_hash() { let group_id = rand_vec(); - with_store(|store| { - let mut conn = store.conn().unwrap(); - + with_store(|store, mut conn| { insert_group(&mut conn, group_id.clone()); // Store the intent @@ -440,9 +446,7 @@ mod tests { fn test_happy_path_state_transitions() { let group_id = rand_vec(); - with_store(|store| { - let mut conn = store.conn().unwrap(); - + with_store(|store, mut conn| { insert_group(&mut conn, group_id.clone()); // Store the intent @@ -464,7 +468,7 @@ mod tests { ) .unwrap(); - intent = conn.fetch(intent.id).unwrap().unwrap(); + intent = conn.fetch(&intent.id).unwrap().unwrap(); assert_eq!(intent.state, IntentState::Published); assert_eq!(intent.payload_hash, Some(payload_hash.clone())); assert_eq!(intent.post_commit_data, Some(post_commit_data.clone())); @@ -473,7 +477,7 @@ mod tests { .set_group_intent_committed(&mut conn, intent.id) .unwrap(); // Refresh from the DB - intent = conn.fetch(intent.id).unwrap().unwrap(); + intent = conn.fetch(&intent.id).unwrap().unwrap(); assert_eq!(intent.state, IntentState::Committed); // Make sure we haven't lost the payload hash assert_eq!(intent.payload_hash, Some(payload_hash.clone())); @@ -484,9 +488,7 @@ mod tests { fn test_republish_state_transition() { let group_id = rand_vec(); - with_store(|store| { - let mut conn = store.conn().unwrap(); - + with_store(|store, mut conn| { insert_group(&mut conn, group_id.clone()); // Store the intent @@ -508,7 +510,7 @@ mod tests { ) .unwrap(); - intent = conn.fetch(intent.id).unwrap().unwrap(); + intent = conn.fetch(&intent.id).unwrap().unwrap(); assert_eq!(intent.state, IntentState::Published); assert_eq!(intent.payload_hash, Some(payload_hash.clone())); @@ -516,7 +518,7 @@ mod tests { store .set_group_intent_to_publish(&mut conn, intent.id) .unwrap(); - intent = conn.fetch(intent.id).unwrap().unwrap(); + intent = conn.fetch(&intent.id).unwrap().unwrap(); assert_eq!(intent.state, IntentState::ToPublish); assert!(intent.payload_hash.is_none()); assert!(intent.post_commit_data.is_none()); @@ -527,9 +529,7 @@ mod tests { fn test_invalid_state_transition() { let group_id = rand_vec(); - with_store(|store| { - let mut conn = store.conn().unwrap(); - + with_store(|store, mut conn| { insert_group(&mut conn, group_id.clone()); // Store the intent diff --git a/xmtp_mls/src/storage/encrypted_store/group_message.rs b/xmtp_mls/src/storage/encrypted_store/group_message.rs index b9beb6788..1b387d295 100644 --- a/xmtp_mls/src/storage/encrypted_store/group_message.rs +++ b/xmtp_mls/src/storage/encrypted_store/group_message.rs @@ -1,10 +1,17 @@ -use diesel::prelude::*; +use diesel::{ + backend::Backend, + deserialize::{self, FromSql, FromSqlRow}, + expression::AsExpression, + prelude::*, + serialize::{self, IsNull, Output, ToSql}, + sql_types::Integer, + sqlite::Sqlite, +}; -use super::schema::group_messages; -use crate::impl_fetch; -use crate::impl_store; +use super::{schema::group_messages, DbConnection, EncryptedMessageStore}; +use crate::{impl_fetch, impl_store, StorageError}; -#[derive(Insertable, Identifiable, Queryable, Debug, Clone)] +#[derive(Insertable, Identifiable, Queryable, Debug, Clone, PartialEq, Eq)] #[diesel(table_name = group_messages)] #[diesel(primary_key(id))] /// Successfully processed messages to be returned to the User. @@ -18,12 +25,282 @@ pub struct StoredGroupMessage { /// Time in nanoseconds the message was sent. pub sent_at_ns: i64, /// Group Message Kind Enum - pub kind: i32, + pub kind: GroupMessageKind, /// The ID of the App Installation this message was sent from. pub sender_installation_id: Vec, /// Network wallet address of the Sender pub sender_wallet_address: String, } +#[repr(i32)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, AsExpression, FromSqlRow)] +#[diesel(sql_type = Integer)] +pub enum GroupMessageKind { + Application = 1, + MemberAdded = 2, + MemberRemoved = 3, +} + +impl ToSql for GroupMessageKind +where + i32: ToSql, +{ + fn to_sql<'b>(&'b self, out: &mut Output<'b, '_, Sqlite>) -> serialize::Result { + out.set_value(*self as i32); + Ok(IsNull::No) + } +} + +impl FromSql for GroupMessageKind +where + i32: FromSql, +{ + fn from_sql(bytes: ::RawValue<'_>) -> deserialize::Result { + match i32::from_sql(bytes)? { + 1 => Ok(GroupMessageKind::Application), + 2 => Ok(GroupMessageKind::MemberAdded), + 3 => Ok(GroupMessageKind::MemberRemoved), + x => Err(format!("Unrecognized variant {}", x).into()), + } + } +} + impl_fetch!(StoredGroupMessage, group_messages, Vec); impl_store!(StoredGroupMessage, group_messages); + +impl EncryptedMessageStore { + /// Query for group messages + pub fn get_group_messages>( + &self, + conn: &mut DbConnection, + group_id: GroupId, + sent_after: Option, + sent_before: Option, + kind: Option, + ) -> Result, StorageError> { + use super::schema::group_messages::dsl; + + let mut query = dsl::group_messages + .filter(dsl::group_id.eq(group_id.as_ref())) + .into_boxed(); + + if let Some(sent_after) = sent_after { + query = query.filter(dsl::sent_at_ns.gt(sent_after)); + } + + if let Some(sent_before) = sent_before { + query = query.filter(dsl::sent_at_ns.lt(sent_before)); + } + + if let Some(kind) = kind { + query = query.filter(dsl::kind.eq(kind)); + } + Ok(query.load::(conn)?) + } + + /// Get a particular group message + pub fn get_group_message>( + &self, + id: MessageId, + conn: &mut DbConnection, + ) -> Result, StorageError> { + use super::schema::group_messages::dsl; + Ok(dsl::group_messages + .filter(dsl::id.eq(id.as_ref())) + .first(conn) + .optional()?) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + assert_err, assert_ok, + storage::encrypted_store::{ + group::tests::generate_group, + tests::{rand_time, rand_vec, with_store}, + }, + Store, + }; + + fn generate_message( + kind: Option, + group_id: Option<&[u8]>, + sent_at_ns: Option, + ) -> StoredGroupMessage { + StoredGroupMessage { + id: rand_vec(), + group_id: group_id.map(<[u8]>::to_vec).unwrap_or(rand_vec()), + decrypted_message_bytes: rand_vec(), + sent_at_ns: sent_at_ns.unwrap_or(rand_time()), + sender_installation_id: rand_vec(), + sender_wallet_address: "0x0".to_string(), + kind: kind.unwrap_or(GroupMessageKind::Application), + } + } + + #[test] + fn it_does_not_error_on_empty_messages() { + with_store(|store, mut conn| { + let id = vec![0x0]; + assert_ok!(store.get_group_message(&id, &mut conn), None); + }) + } + + #[test] + fn it_gets_messages() { + with_store(|store, mut conn| { + let group = generate_group(None); + let message = generate_message(None, Some(&group.id), None); + group.store(&mut conn).unwrap(); + let id = message.id.clone(); + + message.store(&mut conn).unwrap(); + + let stored_message = store.get_group_message(&id, &mut conn); + assert_ok!(stored_message, Some(message)); + }) + } + + #[test] + fn it_cannot_insert_message_without_group() { + use diesel::result::{DatabaseErrorKind::ForeignKeyViolation, Error::DatabaseError}; + + with_store(|_, mut conn| { + let message = generate_message(None, None, None); + assert_err!( + message.store(&mut conn), + StorageError::DieselResult(DatabaseError(ForeignKeyViolation, _)) + ); + }) + } + + #[test] + fn it_gets_many_messages() { + use crate::storage::encrypted_store::schema::group_messages::dsl; + + with_store(|store, mut conn| { + let group = generate_group(None); + group.store(&mut conn).unwrap(); + + for _ in 0..50 { + let msg = generate_message(None, Some(&group.id), None); + assert_ok!(msg.store(&mut conn)); + } + + let count: i64 = dsl::group_messages + .select(diesel::dsl::count_star()) + .first(&mut conn) + .unwrap(); + assert_eq!(count, 50); + + let messages = store + .get_group_messages(&mut conn, &group.id, None, None, None) + .unwrap(); + assert_eq!(messages.len(), 50); + }) + } + + #[test] + fn it_gets_messages_by_time() { + with_store(|store, mut conn| { + let group = generate_group(None); + group.store(&mut conn).unwrap(); + + let messages = vec![ + generate_message(None, Some(&group.id), Some(1_000)), + generate_message(None, Some(&group.id), Some(10_000)), + generate_message(None, Some(&group.id), Some(100_000)), + generate_message(None, Some(&group.id), Some(1_000_000)), + ]; + assert_ok!(messages.store(&mut conn)); + let message = store + .get_group_messages(&mut conn, &group.id, Some(1_000), Some(100_000), None) + .unwrap(); + assert_eq!(message.len(), 1); + assert_eq!(message.first().unwrap().sent_at_ns, 10_000); + + let messages = store + .get_group_messages(&mut conn, &group.id, None, Some(100_000), None) + .unwrap(); + assert_eq!(messages.len(), 2); + + let messages = store + .get_group_messages(&mut conn, &group.id, Some(10_000), None, None) + .unwrap(); + assert_eq!(messages.len(), 2); + }) + } + + #[test] + fn it_gets_messages_by_kind() { + with_store(|store, mut conn| { + let group = generate_group(None); + group.store(&mut conn).unwrap(); + + // just a bunch of random messages so we have something to filter through + for i in 0..30 { + match i % 3 { + 0 => { + let msg = generate_message( + Some(GroupMessageKind::Application), + Some(&group.id), + None, + ); + msg.store(&mut conn).unwrap(); + } + 1 => { + let msg = generate_message( + Some(GroupMessageKind::MemberRemoved), + Some(&group.id), + None, + ); + msg.store(&mut conn).unwrap(); + } + 2 | _ => { + let msg = generate_message( + Some(GroupMessageKind::MemberAdded), + Some(&group.id), + None, + ); + msg.store(&mut conn).unwrap(); + } + } + } + + let application_messages = store + .get_group_messages( + &mut conn, + &group.id, + None, + None, + Some(GroupMessageKind::Application), + ) + .unwrap(); + assert_eq!(application_messages.len(), 10); + + let member_removed = store + .get_group_messages( + &mut conn, + &group.id, + None, + None, + Some(GroupMessageKind::MemberAdded), + ) + .unwrap(); + assert_eq!(member_removed.len(), 10); + + let member_added = store + .get_group_messages( + &mut conn, + &group.id, + None, + None, + Some(GroupMessageKind::MemberRemoved), + ) + .unwrap(); + assert_eq!(member_added.len(), 10); + }) + } +} diff --git a/xmtp_mls/src/storage/encrypted_store/mod.rs b/xmtp_mls/src/storage/encrypted_store/mod.rs index 9bd33f374..269174954 100644 --- a/xmtp_mls/src/storage/encrypted_store/mod.rs +++ b/xmtp_mls/src/storage/encrypted_store/mod.rs @@ -9,7 +9,6 @@ //! if there are any outstanding database migrations and perform them as needed. When updating the //! table definitions `schema.rs` must also be updated. To generate the correct schemas you can run //! `diesel print-schema` or use `cargo run update-schema` which will update the files for you. -//! pub mod group; pub mod group_intent; @@ -28,12 +27,12 @@ use diesel::{ result::{DatabaseErrorKind, Error}, }; use diesel_migrations::{embed_migrations, EmbeddedMigrations, MigrationHarness}; -use group::{GroupMembershipState, StoredGroup}; use log::warn; use rand::RngCore; use xmtp_cryptography::utils as crypto_utils; use super::StorageError; +use crate::Store; pub const MIGRATIONS: EmbeddedMigrations = embed_migrations!("./migrations/"); @@ -161,21 +160,6 @@ impl EncryptedMessageStore { crypto_utils::rng().fill_bytes(&mut key[..]); key } - - /// Updates group membership state - pub fn update_group_membership( - &self, - conn: &mut DbConnection, - id: group::ID, - state: GroupMembershipState, - ) -> Result { - use self::schema::groups::dsl::{groups, membership_state}; - - diesel::update(groups.find(id)) - .set(membership_state.eq(state)) - .get_result::(conn) - .map_err(Into::into) - } } #[allow(dead_code)] @@ -195,7 +179,7 @@ macro_rules! impl_fetch { ($model:ty, $table:ident) => { impl $crate::Fetch<$model> for $crate::storage::encrypted_store::DbConnection { type Key = (); - fn fetch(&mut self, _key: Self::Key) -> Result, $crate::StorageError> { + fn fetch(&mut self, _key: &Self::Key) -> Result, $crate::StorageError> { use $crate::storage::encrypted_store::schema::$table::dsl::*; Ok($table.first(self).optional()?) } @@ -205,7 +189,7 @@ macro_rules! impl_fetch { ($model:ty, $table:ident, $key:ty) => { impl $crate::Fetch<$model> for $crate::storage::encrypted_store::DbConnection { type Key = $key; - fn fetch(&mut self, key: Self::Key) -> Result, $crate::StorageError> { + fn fetch(&mut self, key: &Self::Key) -> Result, $crate::StorageError> { use $crate::storage::encrypted_store::schema::$table::dsl::*; Ok($table.find(key).first(self).optional()?) } @@ -231,6 +215,18 @@ macro_rules! impl_store { }; } +impl Store for Vec +where + T: Store, +{ + fn store(&self, into: &mut DbConnection) -> Result<(), StorageError> { + for item in self { + item.store(into)?; + } + Ok(()) + } +} + #[cfg(test)] mod tests { use std::{boxed::Box, fs}; @@ -251,10 +247,15 @@ mod tests { rand::thread_rng().gen::<[u8; 16]>().to_vec() } + pub(crate) fn rand_time() -> i64 { + let mut rng = rand::thread_rng(); + rng.gen_range(0..1_000_000_000) + } + /// Test harness that loads an Ephemeral store. pub fn with_store(fun: F) -> R where - F: FnOnce(EncryptedMessageStore) -> R, + F: FnOnce(EncryptedMessageStore, super::DbConnection) -> R, { crate::tests::setup(); let store = EncryptedMessageStore::new( @@ -262,7 +263,8 @@ mod tests { EncryptedMessageStore::generate_enc_key(), ) .unwrap(); - fun(store) + let conn = store.conn().expect("acquiring a Connection failed"); + fun(store, conn) } #[test] @@ -279,7 +281,7 @@ mod tests { .store(conn) .unwrap(); - let fetched_identity: StoredIdentity = conn.fetch(()).unwrap().unwrap(); + let fetched_identity: StoredIdentity = conn.fetch(&()).unwrap().unwrap(); assert_eq!(fetched_identity.account_address, account_address); } @@ -299,7 +301,7 @@ mod tests { .store(conn) .unwrap(); - let fetched_identity: StoredIdentity = conn.fetch(()).unwrap().unwrap(); + let fetched_identity: StoredIdentity = conn.fetch(&()).unwrap().unwrap(); assert_eq!(fetched_identity.account_address, account_address); } diff --git a/xmtp_mls/src/storage/encrypted_store/topic_refresh_state.rs b/xmtp_mls/src/storage/encrypted_store/topic_refresh_state.rs index 7c181589e..ee4f42ab5 100644 --- a/xmtp_mls/src/storage/encrypted_store/topic_refresh_state.rs +++ b/xmtp_mls/src/storage/encrypted_store/topic_refresh_state.rs @@ -1,8 +1,7 @@ use diesel::prelude::*; use super::schema::topic_refresh_state; -use crate::impl_fetch; -use crate::impl_store; +use crate::{impl_fetch, impl_store}; #[derive(Insertable, Identifiable, Queryable, Debug, Clone)] #[diesel(table_name = topic_refresh_state)] diff --git a/xmtp_mls/src/storage/sql_key_store.rs b/xmtp_mls/src/storage/sql_key_store.rs index 170d3fd0f..68b1f6812 100644 --- a/xmtp_mls/src/storage/sql_key_store.rs +++ b/xmtp_mls/src/storage/sql_key_store.rs @@ -60,7 +60,7 @@ impl OpenMlsKeyStore for SqlKeyStore<'_> { return None; } let mut conn = conn_result.unwrap(); - let fetch_result = conn.fetch(k.to_vec()); + let fetch_result = conn.fetch(&k.to_vec()); if let Err(e) = fetch_result { error!("Failed to fetch key: {:?}", e); return None;