Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into coda/stitching
Browse files Browse the repository at this point in the history
  • Loading branch information
codabrink committed Dec 20, 2024
2 parents 5477865 + d5a2f6d commit 14693de
Show file tree
Hide file tree
Showing 9 changed files with 163 additions and 126 deletions.
11 changes: 11 additions & 0 deletions xmtp_debug/src/args.rs
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,17 @@ pub enum EntityKind {
Identity,
}

impl std::fmt::Display for EntityKind {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
use EntityKind::*;
match self {
Group => write!(f, "group"),
Message => write!(f, "message"),
Identity => write!(f, "identity"),
}
}
}

/// specify the log output
#[derive(Args, Debug)]
pub struct LogOptions {
Expand Down
62 changes: 24 additions & 38 deletions xmtp_mls/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ use crate::{
group_message::StoredGroupMessage,
refresh_state::EntityKind,
wallet_addresses::WalletEntry,
EncryptedMessageStore, StorageError,
EncryptedMessageStore, NotFound, StorageError,
},
subscriptions::{LocalEventError, LocalEvents},
types::InstallationId,
Expand Down Expand Up @@ -108,6 +108,12 @@ pub enum ClientError {
Generic(String),
}

impl From<NotFound> for ClientError {
fn from(value: NotFound) -> Self {
ClientError::Storage(StorageError::NotFound(value))
}
}

impl From<GroupError> for ClientError {
fn from(err: GroupError) -> ClientError {
ClientError::Group(Box::new(err))
Expand Down Expand Up @@ -309,11 +315,7 @@ where
address: String,
) -> Result<Option<String>, ClientError> {
let results = self.find_inbox_ids_from_addresses(conn, &[address]).await?;
if let Some(first_result) = results.into_iter().next() {
Ok(first_result)
} else {
Ok(None)
}
Ok(results.into_iter().next().flatten())
}

/// Calls the server to look up the `inbox_id`s` associated with a list of addresses.
Expand Down Expand Up @@ -556,10 +558,9 @@ where
{
Some(id) => id,
None => {
return Err(ClientError::Storage(StorageError::NotFound(format!(
"inbox id for address {} not found",
account_address
))))
return Err(ClientError::Storage(StorageError::NotFound(
NotFound::InboxIdForAddress(account_address),
)));
}
};

Expand Down Expand Up @@ -610,13 +611,10 @@ where
group_id: Vec<u8>,
) -> Result<MlsGroup<Self>, ClientError> {
let stored_group: Option<StoredGroup> = conn.fetch(&group_id)?;
match stored_group {
Some(group) => Ok(MlsGroup::new(self.clone(), group.id, group.created_at_ns)),
None => Err(ClientError::Storage(StorageError::NotFound(format!(
"group {}",
hex::encode(group_id)
)))),
}
stored_group
.map(|g| MlsGroup::new(self.clone(), g.id, g.created_at_ns))
.ok_or(NotFound::GroupById(group_id))
.map_err(Into::into)
}

/// Look up a group by its ID
Expand All @@ -638,34 +636,22 @@ where
target_inbox_id: String,
) -> Result<MlsGroup<Self>, ClientError> {
let conn = self.store().conn()?;
match conn.find_dm_group(&DmMembers {
member_one_inbox_id: self.inbox_id(),
member_two_inbox_id: &target_inbox_id,
})? {
Some(dm_group) => Ok(MlsGroup::new(
self.clone(),
dm_group.id,
dm_group.created_at_ns,
)),
None => Err(ClientError::Storage(StorageError::NotFound(format!(
"dm_target_inbox_id {}",
hex::encode(target_inbox_id)
)))),
}

let group = conn
.find_dm_group(&DmMembers {
member_one_inbox_id: self.inbox_id(),
member_two_inbox_id: &target_inbox_id,
})?
.ok_or(NotFound::DmByInbox(target_inbox_id))?;
Ok(MlsGroup::new(self.clone(), group.id, group.created_at_ns))
}

/// Look up a message by its ID
/// Returns a [`StoredGroupMessage`] if the message exists, or an error if it does not
pub fn message(&self, message_id: Vec<u8>) -> Result<StoredGroupMessage, ClientError> {
let conn = &mut self.store().conn()?;
let message = conn.get_group_message(&message_id)?;
match message {
Some(message) => Ok(message),
None => Err(ClientError::Storage(StorageError::NotFound(format!(
"message {}",
hex::encode(message_id)
)))),
}
Ok(message.ok_or(NotFound::MessageById(message_id))?)
}

/// Query for groups with optional filters
Expand Down
26 changes: 15 additions & 11 deletions xmtp_mls/src/groups/device_sync.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,9 @@ use crate::{
configuration::NS_IN_HOUR,
storage::{
consent_record::StoredConsentRecord,
group::StoredGroup,
group::{ConversationType, GroupQueryArgs},
group_message::MsgQueryArgs,
group_message::{GroupMessageKind, StoredGroupMessage},
DbConnection, StorageError,
group::{ConversationType, GroupQueryArgs, StoredGroup},
group_message::{GroupMessageKind, MsgQueryArgs, StoredGroupMessage},
DbConnection, NotFound, StorageError,
},
subscriptions::{LocalEvents, StreamMessages, SubscribeError, SyncMessage},
xmtp_openmls_provider::XmtpOpenMlsProvider,
Expand Down Expand Up @@ -115,6 +113,12 @@ impl RetryableError for DeviceSyncError {
}
}

impl From<NotFound> for DeviceSyncError {
fn from(value: NotFound) -> Self {
DeviceSyncError::Storage(StorageError::NotFound(value))
}
}

impl<ApiClient, V> Client<ApiClient, V>
where
ApiClient: XmtpApi + Send + Sync + 'static,
Expand Down Expand Up @@ -211,9 +215,9 @@ where
retry,
(async {
conn.get_group_message(&message_id)?
.ok_or(DeviceSyncError::Storage(StorageError::NotFound(format!(
"Message id {message_id:?} not found."
))))
.ok_or(DeviceSyncError::from(NotFound::MessageById(
message_id.clone(),
)))
})
)?;

Expand All @@ -240,9 +244,9 @@ where
retry,
(async {
conn.get_group_message(&message_id)?
.ok_or(DeviceSyncError::Storage(StorageError::NotFound(format!(
"Message id {message_id:?} not found."
))))
.ok_or(DeviceSyncError::from(NotFound::MessageById(
message_id.clone(),
)))
})
)?;

Expand Down
6 changes: 4 additions & 2 deletions xmtp_mls/src/groups/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ use self::{
intents::IntentError,
validated_commit::CommitValidationError,
};
use crate::storage::{group::DmIdExt, group_message::ContentType, StorageError};
use crate::storage::{group::DmIdExt, group_message::ContentType, NotFound, StorageError};
use xmtp_common::time::now_ns;
use xmtp_proto::xmtp::mls::{
api::v1::{
Expand Down Expand Up @@ -418,7 +418,9 @@ impl<ScopedClient: ScopedGroupClient> MlsGroup<ScopedClient> {
let mls_group =
OpenMlsGroup::load(provider.storage(), &GroupId::from_slice(&self.group_id))
.map_err(crate::StorageError::from)?
.ok_or(crate::StorageError::NotFound("Group Not Found".into()))?;
.ok_or(StorageError::from(NotFound::GroupById(
self.group_id.to_vec(),
)))?;

// Perform the operation with the MLS group
operation(mls_group).await.map_err(Into::into)
Expand Down
14 changes: 7 additions & 7 deletions xmtp_mls/src/storage/encrypted_store/group.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,13 @@ use super::{
schema::groups::{self, dsl},
Sqlite,
};

use crate::{
groups::group_metadata::DmMembers, impl_fetch, impl_store, DuplicateItem, StorageError,
};

use crate::storage::NotFound;

use diesel::{
backend::Backend,
deserialize::{self, FromSql, FromSqlRow},
Expand Down Expand Up @@ -399,9 +403,8 @@ impl DbConnection {
Ok::<Option<i64>, StorageError>(ts)
})?;

last_ts.ok_or(StorageError::NotFound(format!(
"installation time for group {}",
hex::encode(group_id)
last_ts.ok_or(StorageError::NotFound(NotFound::InstallationTimeForGroup(
group_id,
)))
}

Expand All @@ -427,10 +430,7 @@ impl DbConnection {
Ok::<_, StorageError>(ts)
})?;

last_ts.ok_or(StorageError::NotFound(format!(
"installation time for group {}",
hex::encode(group_id)
)))
last_ts.ok_or(NotFound::InstallationTimeForGroup(group_id).into())
}

/// Updates the 'last time checked' we checked for new installations.
Expand Down
69 changes: 29 additions & 40 deletions xmtp_mls/src/storage/encrypted_store/group_intent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use super::{
use crate::{
groups::intents::{IntentError, SendMessageIntentData},
impl_fetch, impl_store,
storage::StorageError,
storage::{NotFound, StorageError},
utils::id::calculate_message_id,
Delete,
};
Expand Down Expand Up @@ -197,7 +197,7 @@ impl DbConnection {
staged_commit: Option<Vec<u8>>,
published_in_epoch: i64,
) -> Result<(), StorageError> {
let res = self.raw_query(|conn| {
let rows_changed = self.raw_query(|conn| {
diesel::update(dsl::group_intents)
.filter(dsl::id.eq(intent_id))
// State machine requires that the only valid state transition to Published is from
Expand All @@ -213,30 +213,25 @@ impl DbConnection {
.execute(conn)
})?;

match res {
// If nothing matched the query, check if its already published, otherwise return an error. Either ID or state was wrong
0 => {
let already_published = self.raw_query(|conn| {
dsl::group_intents
.filter(dsl::id.eq(intent_id))
.first::<StoredGroupIntent>(conn)
});

if already_published.is_ok() {
Ok(())
} else {
Err(StorageError::NotFound(format!(
"Published intent {intent_id} for commit"
)))
}
if rows_changed == 0 {
let already_published = self.raw_query(|conn| {
dsl::group_intents
.filter(dsl::id.eq(intent_id))
.first::<StoredGroupIntent>(conn)
});

if already_published.is_ok() {
return Ok(());
} else {
return Err(NotFound::IntentForToPublish(intent_id).into());
}
_ => Ok(()),
}
Ok(())
}

// Set the intent with the given ID to `Committed`
pub fn set_group_intent_committed(&self, intent_id: ID) -> Result<(), StorageError> {
let res = self.raw_query(|conn| {
let rows_changed = self.raw_query(|conn| {
diesel::update(dsl::group_intents)
.filter(dsl::id.eq(intent_id))
// State machine requires that the only valid state transition to Committed is from
Expand All @@ -246,19 +241,18 @@ impl DbConnection {
.execute(conn)
})?;

match res {
// If nothing matched the query, return an error. Either ID or state was wrong
0 => Err(StorageError::NotFound(format!(
"Published intent {intent_id} for commit"
))),
_ => Ok(()),
// If nothing matched the query, return an error. Either ID or state was wrong
if rows_changed == 0 {
return Err(NotFound::IntentForCommitted(intent_id).into());
}

Ok(())
}

// Set the intent with the given ID to `ToPublish`. Wipe any values for `payload_hash` and
// `post_commit_data`
pub fn set_group_intent_to_publish(&self, intent_id: ID) -> Result<(), StorageError> {
let res = self.raw_query(|conn| {
let rows_changed = self.raw_query(|conn| {
diesel::update(dsl::group_intents)
.filter(dsl::id.eq(intent_id))
// State machine requires that the only valid state transition to ToPublish is from
Expand All @@ -275,32 +269,27 @@ impl DbConnection {
.execute(conn)
})?;

match res {
// If nothing matched the query, return an error. Either ID or state was wrong
0 => Err(StorageError::NotFound(format!(
"Published intent {intent_id} for ToPublish"
))),
_ => Ok(()),
if rows_changed == 0 {
return Err(NotFound::IntentForPublish(intent_id).into());
}
Ok(())
}

/// Set the intent with the given ID to `Error`
#[tracing::instrument(level = "trace", skip(self))]
pub fn set_group_intent_error(&self, intent_id: ID) -> Result<(), StorageError> {
let res = self.raw_query(|conn| {
let rows_changed = self.raw_query(|conn| {
diesel::update(dsl::group_intents)
.filter(dsl::id.eq(intent_id))
.set(dsl::state.eq(IntentState::Error))
.execute(conn)
})?;

match res {
// If nothing matched the query, return an error. Either ID or state was wrong
0 => Err(StorageError::NotFound(format!(
"state for intent {intent_id}"
))),
_ => Ok(()),
if rows_changed == 0 {
return Err(NotFound::IntentById(intent_id).into());
}

Ok(())
}

// Simple lookup of intents by payload hash, meant to be used when processing messages off the
Expand Down
Loading

0 comments on commit 14693de

Please sign in to comment.