diff --git a/xmtp_mls/migrations/2023-10-29-205333_state_machine_init/up.sql b/xmtp_mls/migrations/2023-10-29-205333_state_machine_init/up.sql index 7aadc7b03..c9363593d 100644 --- a/xmtp_mls/migrations/2023-10-29-205333_state_machine_init/up.sql +++ b/xmtp_mls/migrations/2023-10-29-205333_state_machine_init/up.sql @@ -49,10 +49,10 @@ CREATE TABLE group_intents ( -- INTENT_STATE, "state" INT NOT NULL, -- The hash of the encrypted, concrete, form of the message if it was published. - "payload_hash" BLOB, + "payload_hash" BLOB UNIQUE, -- (Optional) data needed for the post-commit flow. For example, welcome messages "post_commit_data" BLOB, FOREIGN KEY (group_id) REFERENCES groups(id) ); -CREATE INDEX group_intents_group_id_id ON group_intents(group_id, id); \ No newline at end of file +CREATE INDEX group_intents_group_id_state ON group_intents(group_id, state); \ No newline at end of file diff --git a/xmtp_mls/src/storage/encrypted_store/group_intent.rs b/xmtp_mls/src/storage/encrypted_store/group_intent.rs index 2ed68111b..f4cdf85c6 100644 --- a/xmtp_mls/src/storage/encrypted_store/group_intent.rs +++ b/xmtp_mls/src/storage/encrypted_store/group_intent.rs @@ -1,25 +1,551 @@ +use super::schema::group_intents::dsl; +use super::{group, schema::group_intents}; +use super::{DbConnection, EncryptedMessageStore}; +use crate::storage::StorageError; +use crate::{impl_fetch, impl_store}; + use diesel::prelude::*; +use diesel::{ + backend::Backend, + deserialize::{self, FromSql, FromSqlRow}, + expression::AsExpression, + serialize::{self, IsNull, Output, ToSql}, + sql_types::Integer, + sqlite::Sqlite, +}; + +pub type ID = i32; + +#[repr(i32)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, AsExpression, FromSqlRow)] +#[diesel(sql_type = Integer)] +/// Status of membership in a group, once a user sends a request to join +pub enum IntentKind { + SendMessage = 1, + AddMembers = 2, + RemoveMembers = 3, + KeyUpdate = 4, +} -use super::schema::group_intents; +#[repr(i32)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, AsExpression, FromSqlRow)] +#[diesel(sql_type = Integer)] +/// Status of membership in a group, once a user sends a request to join +pub enum IntentState { + ToPublish = 1, + Published = 2, + Committed = 3, +} -#[derive(Queryable, Identifiable, Debug, Clone)] +#[derive(Queryable, Identifiable, Debug, PartialEq, Clone)] #[diesel(table_name = group_intents)] #[diesel(primary_key(id))] pub struct StoredGroupIntent { - pub id: i32, - pub kind: i32, - pub group_id: Vec, + pub id: ID, + pub kind: IntentKind, + pub group_id: group::ID, pub data: Vec, - pub state: i32, + pub state: IntentState, pub payload_hash: Option>, pub post_commit_data: Option>, } -#[derive(Insertable, Debug, Clone)] +impl_fetch!(StoredGroupIntent, group_intents, ID); + +#[derive(Insertable, Debug, PartialEq, Clone)] #[diesel(table_name = group_intents)] pub struct NewGroupIntent { - pub kind: i32, + pub kind: IntentKind, pub group_id: Vec, pub data: Vec, - pub state: i32, + pub state: IntentState, +} + +impl_store!(NewGroupIntent, group_intents); + +impl NewGroupIntent { + pub fn new(kind: IntentKind, group_id: Vec, data: Vec) -> Self { + Self { + kind, + group_id, + data, + state: IntentState::ToPublish, + } + } +} + +impl EncryptedMessageStore { + // Query for group_intents by group_id, optionally filtering by state and kind + pub fn find_group_intents( + &self, + conn: &mut DbConnection, + group_id: Vec, + allowed_states: Option>, + allowed_kinds: Option>, + ) -> Result, StorageError> { + let mut query = dsl::group_intents + .into_boxed() + .filter(dsl::group_id.eq(group_id)); + + if let Some(allowed_states) = allowed_states { + query = query.filter(dsl::state.eq_any(allowed_states)); + } + + if let Some(allowed_kinds) = allowed_kinds { + query = query.filter(dsl::kind.eq_any(allowed_kinds)); + } + + Ok(query.load::(conn)?) + } + + // Set the intent with the given ID to `Published` and set the payload hash. Optionally add `post_commit_data` + pub fn set_group_intent_published( + &self, + conn: &mut DbConnection, + intent_id: ID, + payload_hash: Vec, + post_commit_data: Option>, + ) -> Result<(), StorageError> { + let res = diesel::update(dsl::group_intents) + .filter(dsl::id.eq(intent_id)) + // State machine requires that the only valid state transition to Published is from ToPublish + .filter(dsl::state.eq(IntentState::ToPublish)) + .set(( + dsl::state.eq(IntentState::Published), + dsl::payload_hash.eq(payload_hash), + dsl::post_commit_data.eq(post_commit_data), + )) + .execute(conn)?; + + match res { + // If nothing matched the query, return an error. Either ID or state was wrong + 0 => Err(StorageError::NotFound), + _ => Ok(()), + } + } + + // Set the intent with the given ID to `Committed` + pub fn set_group_intent_committed( + &self, + conn: &mut DbConnection, + intent_id: ID, + ) -> Result<(), StorageError> { + let res = diesel::update(dsl::group_intents) + .filter(dsl::id.eq(intent_id)) + // State machine requires that the only valid state transition to Committed is from Published + .filter(dsl::state.eq(IntentState::Published)) + .set(dsl::state.eq(IntentState::Committed)) + .execute(conn)?; + + match res { + // If nothing matched the query, return an error. Either ID or state was wrong + 0 => Err(StorageError::NotFound), + _ => Ok(()), + } + } + + // Set the intent with the given ID to `ToPublish`. Wipe any values for `payload_hash` and `post_commit_data` + pub fn set_group_intent_to_publish( + &self, + conn: &mut DbConnection, + intent_id: ID, + ) -> Result<(), StorageError> { + let res = diesel::update(dsl::group_intents) + .filter(dsl::id.eq(intent_id)) + // State machine requires that the only valid state transition to ToPublish is from Published + .filter(dsl::state.eq(IntentState::Published)) + .set(( + dsl::state.eq(IntentState::ToPublish), + // When moving to ToPublish, clear the payload hash and post commit data + dsl::payload_hash.eq(None::>), + dsl::post_commit_data.eq(None::>), + )) + .execute(conn)?; + + match res { + // If nothing matched the query, return an error. Either ID or state was wrong + 0 => Err(StorageError::NotFound), + _ => Ok(()), + } + } + + // Simple lookup of intents by payload hash, meant to be used when processing messages off the network + pub fn find_group_intent_by_payload_hash( + &self, + conn: &mut DbConnection, + payload_hash: Vec, + ) -> Result, StorageError> { + let result = dsl::group_intents + .filter(dsl::payload_hash.eq(payload_hash)) + .first::(conn) + .optional()?; + + Ok(result) + } +} + +impl ToSql for IntentKind +where + i32: ToSql, +{ + fn to_sql<'b>(&'b self, out: &mut Output<'b, '_, Sqlite>) -> serialize::Result { + out.set_value(*self as i32); + Ok(IsNull::No) + } +} + +impl FromSql for IntentKind +where + i32: FromSql, +{ + fn from_sql(bytes: ::RawValue<'_>) -> deserialize::Result { + match i32::from_sql(bytes)? { + 1 => Ok(IntentKind::SendMessage), + 2 => Ok(IntentKind::AddMembers), + 3 => Ok(IntentKind::RemoveMembers), + 4 => Ok(IntentKind::KeyUpdate), + x => Err(format!("Unrecognized variant {}", x).into()), + } + } +} + +impl ToSql for IntentState +where + i32: ToSql, +{ + fn to_sql<'b>(&'b self, out: &mut Output<'b, '_, Sqlite>) -> serialize::Result { + out.set_value(*self as i32); + Ok(IsNull::No) + } +} + +impl FromSql for IntentState +where + i32: FromSql, +{ + fn from_sql(bytes: ::RawValue<'_>) -> deserialize::Result { + match i32::from_sql(bytes)? { + 1 => Ok(IntentState::ToPublish), + 2 => Ok(IntentState::Published), + 3 => Ok(IntentState::Committed), + x => Err(format!("Unrecognized variant {}", x).into()), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::storage::encrypted_store::group::{GroupMembershipState, StoredGroup}; + use crate::storage::encrypted_store::tests::{rand_vec, with_store}; + use crate::{Fetch, Store}; + + fn insert_group(conn: &mut DbConnection, group_id: Vec) { + let group = StoredGroup::new(group_id, 100, GroupMembershipState::Allowed); + group.store(conn).unwrap(); + } + + impl NewGroupIntent { + // Real group intents must always start as ToPublish. But for tests we allow forcing the state + pub fn new_test( + kind: IntentKind, + group_id: Vec, + data: Vec, + state: IntentState, + ) -> Self { + Self { + kind, + group_id, + data, + state, + } + } + } + + fn find_first_intent(conn: &mut DbConnection, group_id: group::ID) -> StoredGroupIntent { + dsl::group_intents + .filter(dsl::group_id.eq(group_id)) + .first(conn) + .unwrap() + } + + #[test] + fn test_store_and_fetch() { + let group_id = rand_vec(); + let data = rand_vec(); + let kind = IntentKind::AddMembers; + let state = IntentState::ToPublish; + + let to_insert = NewGroupIntent::new_test(kind, group_id.clone(), data.clone(), state); + + with_store(|store| { + let mut conn = store.conn().unwrap(); + + // Group needs to exist or FK constraint will fail + insert_group(&mut conn, group_id.clone()); + + to_insert.store(&mut conn).unwrap(); + + let results = store + .find_group_intents( + &mut conn, + group_id.clone(), + Some(vec![IntentState::ToPublish]), + None, + ) + .unwrap(); + + assert_eq!(results.len(), 1); + assert_eq!(results[0].kind, kind); + assert_eq!(results[0].data, data); + assert_eq!(results[0].group_id, group_id); + + let id = results[0].id; + + let fetched: StoredGroupIntent = conn.fetch(id).unwrap().unwrap(); + + assert_eq!(fetched.id, id); + }) + } + + #[test] + fn test_query() { + let group_id = rand_vec(); + + let test_intents: Vec = vec![ + NewGroupIntent::new_test( + IntentKind::AddMembers, + group_id.clone(), + rand_vec(), + IntentState::ToPublish, + ), + NewGroupIntent::new_test( + IntentKind::RemoveMembers, + group_id.clone(), + rand_vec(), + IntentState::Published, + ), + NewGroupIntent::new_test( + IntentKind::RemoveMembers, + group_id.clone(), + rand_vec(), + IntentState::Committed, + ), + ]; + + with_store(|store| { + let mut conn = store.conn().unwrap(); + + // Group needs to exist or FK constraint will fail + insert_group(&mut conn, group_id.clone()); + + for case in test_intents { + case.store(&mut conn).unwrap(); + } + + // Can query for multiple states + let mut results = store + .find_group_intents( + &mut conn, + group_id.clone(), + Some(vec![IntentState::ToPublish, IntentState::Published]), + None, + ) + .unwrap(); + + assert_eq!(results.len(), 2); + + // Can query by kind + results = store + .find_group_intents( + &mut conn, + group_id.clone(), + None, + Some(vec![IntentKind::RemoveMembers]), + ) + .unwrap(); + assert_eq!(results.len(), 2); + + // Can query by kind and state + results = store + .find_group_intents( + &mut conn, + group_id.clone(), + Some(vec![IntentState::Committed]), + Some(vec![IntentKind::RemoveMembers]), + ) + .unwrap(); + + assert_eq!(results.len(), 1); + + // Can get no results + results = store + .find_group_intents( + &mut conn, + group_id.clone(), + Some(vec![IntentState::Committed]), + Some(vec![IntentKind::SendMessage]), + ) + .unwrap(); + + assert_eq!(results.len(), 0); + + // Can get all intents + results = store + .find_group_intents(&mut conn, group_id, None, None) + .unwrap(); + assert_eq!(results.len(), 3); + }) + } + + #[test] + fn find_by_payload_hash() { + let group_id = rand_vec(); + + with_store(|store| { + let mut conn = store.conn().unwrap(); + + insert_group(&mut conn, group_id.clone()); + + // Store the intent + NewGroupIntent::new(IntentKind::AddMembers, group_id.clone(), rand_vec()) + .store(&mut conn) + .unwrap(); + + // Find the intent with the ID populated + let intent = find_first_intent(&mut conn, group_id.clone()); + + // Set the payload hash + let payload_hash = rand_vec(); + let post_commit_data = rand_vec(); + store + .set_group_intent_published( + &mut conn, + intent.id, + payload_hash.clone(), + Some(post_commit_data.clone()), + ) + .unwrap(); + + let find_result = store + .find_group_intent_by_payload_hash(&mut conn, payload_hash) + .unwrap() + .unwrap(); + + assert_eq!(find_result.id, intent.id); + }) + } + + #[test] + fn test_happy_path_state_transitions() { + let group_id = rand_vec(); + + with_store(|store| { + let mut conn = store.conn().unwrap(); + + insert_group(&mut conn, group_id.clone()); + + // Store the intent + NewGroupIntent::new(IntentKind::AddMembers, group_id.clone(), rand_vec()) + .store(&mut conn) + .unwrap(); + + let mut intent = find_first_intent(&mut conn, group_id.clone()); + + // Set to published + let payload_hash = rand_vec(); + let post_commit_data = rand_vec(); + store + .set_group_intent_published( + &mut conn, + intent.id, + payload_hash.clone(), + Some(post_commit_data.clone()), + ) + .unwrap(); + + intent = conn.fetch(intent.id).unwrap().unwrap(); + assert_eq!(intent.state, IntentState::Published); + assert_eq!(intent.payload_hash, Some(payload_hash.clone())); + assert_eq!(intent.post_commit_data, Some(post_commit_data.clone())); + + store + .set_group_intent_committed(&mut conn, intent.id) + .unwrap(); + // Refresh from the DB + intent = conn.fetch(intent.id).unwrap().unwrap(); + assert_eq!(intent.state, IntentState::Committed); + // Make sure we haven't lost the payload hash + assert_eq!(intent.payload_hash, Some(payload_hash.clone())); + }) + } + + #[test] + fn test_republish_state_transition() { + let group_id = rand_vec(); + + with_store(|store| { + let mut conn = store.conn().unwrap(); + + insert_group(&mut conn, group_id.clone()); + + // Store the intent + NewGroupIntent::new(IntentKind::AddMembers, group_id.clone(), rand_vec()) + .store(&mut conn) + .unwrap(); + + let mut intent = find_first_intent(&mut conn, group_id.clone()); + + // Set to published + let payload_hash = rand_vec(); + let post_commit_data = rand_vec(); + store + .set_group_intent_published( + &mut conn, + intent.id, + payload_hash.clone(), + Some(post_commit_data.clone()), + ) + .unwrap(); + + intent = conn.fetch(intent.id).unwrap().unwrap(); + assert_eq!(intent.state, IntentState::Published); + assert_eq!(intent.payload_hash, Some(payload_hash.clone())); + + // Now revert back to ToPublish + store + .set_group_intent_to_publish(&mut conn, intent.id) + .unwrap(); + intent = conn.fetch(intent.id).unwrap().unwrap(); + assert_eq!(intent.state, IntentState::ToPublish); + assert!(intent.payload_hash.is_none()); + assert!(intent.post_commit_data.is_none()); + }) + } + + #[test] + fn test_invalid_state_transition() { + let group_id = rand_vec(); + + with_store(|store| { + let mut conn = store.conn().unwrap(); + + insert_group(&mut conn, group_id.clone()); + + // Store the intent + NewGroupIntent::new(IntentKind::AddMembers, group_id.clone(), rand_vec()) + .store(&mut conn) + .unwrap(); + + let intent = find_first_intent(&mut conn, group_id.clone()); + + let commit_result = store.set_group_intent_committed(&mut conn, intent.id); + assert!(commit_result.is_err()); + assert_eq!(commit_result.err().unwrap(), StorageError::NotFound); + + let to_publish_result = store.set_group_intent_to_publish(&mut conn, intent.id); + assert!(to_publish_result.is_err()); + assert_eq!(to_publish_result.err().unwrap(), StorageError::NotFound); + }) + } } diff --git a/xmtp_mls/src/storage/errors.rs b/xmtp_mls/src/storage/errors.rs index 13d431f43..fd5f1850d 100644 --- a/xmtp_mls/src/storage/errors.rs +++ b/xmtp_mls/src/storage/errors.rs @@ -16,4 +16,6 @@ pub enum StorageError { Serialization, #[error("deserialization error")] Deserialization, + #[error("not found")] + NotFound, }