Skip to content

Commit

Permalink
Merge branch 'main' into mc/thread-safe-groups
Browse files Browse the repository at this point in the history
  • Loading branch information
mchenani authored Dec 13, 2024
2 parents d288c6b + 7c7dbdb commit d455367
Show file tree
Hide file tree
Showing 12 changed files with 704 additions and 495 deletions.
11 changes: 5 additions & 6 deletions bindings_wasm/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,25 +7,24 @@ version.workspace = true
crate-type = ["cdylib", "rlib"]

[dependencies]
console_error_panic_hook.workspace = true
hex.workspace = true
js-sys.workspace = true
prost.workspace = true
serde-wasm-bindgen = "0.6.5"
serde.workspace = true
tokio.workspace = true
tracing-subscriber = { workspace = true, features = ["env-filter", "json"] }
tracing-web = "0.1"
tracing.workspace = true
wasm-bindgen-futures.workspace = true
wasm-bindgen.workspace = true
xmtp_api_http = { path = "../xmtp_api_http" }
xmtp_common.workspace = true
xmtp_cryptography = { path = "../xmtp_cryptography" }
xmtp_id = { path = "../xmtp_id" }
xmtp_mls = { path = "../xmtp_mls", features = ["test-utils", "http-api"] }
xmtp_common.workspace = true
xmtp_proto = { path = "../xmtp_proto", features = ["proto_full"] }
tracing-web = "0.1"
tracing.workspace = true
tracing-subscriber = { workspace = true, features = ["env-filter", "json"] }
console_error_panic_hook.workspace = true

[dev-dependencies]
wasm-bindgen-test.workspace = true
xmtp_mls = { path = "../xmtp_mls", features = ["test-utils", "http-api"] }
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
DROP TABLE user_preferences;

CREATE TABLE user_preferences(
id INTEGER PRIMARY KEY ASC,
hmac_key BLOB NOT NULL
);
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
DROP TABLE user_preferences;

CREATE TABLE user_preferences(
id INTEGER PRIMARY KEY ASC NOT NULL,
hmac_key BLOB
);
11 changes: 11 additions & 0 deletions xmtp_mls/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ use xmtp_proto::xmtp::mls::api::v1::{
GroupMessage, WelcomeMessage,
};

#[cfg(any(test, feature = "test-utils"))]
use crate::groups::device_sync::WorkerHandle;

use crate::{
api::ApiClientWrapper,
groups::{
Expand Down Expand Up @@ -144,6 +147,9 @@ pub struct Client<ApiClient, V = RemoteSignatureVerifier<ApiClient>> {
pub(crate) local_events: broadcast::Sender<LocalEvents<Self>>,
/// The method of verifying smart contract wallet signatures for this Client
pub(crate) scw_verifier: Arc<V>,

#[cfg(any(test, feature = "test-utils"))]
pub(crate) sync_worker_handle: Arc<parking_lot::Mutex<Option<Arc<WorkerHandle>>>>,
}

// most of these things are `Arc`'s
Expand All @@ -155,6 +161,9 @@ impl<ApiClient, V> Clone for Client<ApiClient, V> {
history_sync_url: self.history_sync_url.clone(),
local_events: self.local_events.clone(),
scw_verifier: self.scw_verifier.clone(),

#[cfg(any(test, feature = "test-utils"))]
sync_worker_handle: self.sync_worker_handle.clone(),
}
}
}
Expand Down Expand Up @@ -240,6 +249,8 @@ where
context,
history_sync_url,
local_events: tx,
#[cfg(any(test, feature = "test-utils"))]
sync_worker_handle: Arc::new(parking_lot::Mutex::default()),
scw_verifier: scw_verifier.into(),
}
}
Expand Down
129 changes: 97 additions & 32 deletions xmtp_mls/src/groups/device_sync.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,14 @@ use futures::{Stream, StreamExt};
use preference_sync::UserPreferenceUpdate;
use rand::{Rng, RngCore};
use serde::{Deserialize, Serialize};
use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use thiserror::Error;
use tokio::sync::OnceCell;
use tokio::sync::{Notify, OnceCell};
use tokio::time::error::Elapsed;
use tokio::time::timeout;
use tracing::{instrument, warn};
use xmtp_common::time::{now_ns, Duration};
use xmtp_common::{retry_async, Retry, RetryableError};
Expand Down Expand Up @@ -104,8 +109,8 @@ pub enum DeviceSyncError {
SyncPayloadTooOld,
#[error(transparent)]
Subscribe(#[from] SubscribeError),
#[error("Unable to serialize: {0}")]
Bincode(String),
#[error(transparent)]
Bincode(#[from] bincode::Error),
}

impl RetryableError for DeviceSyncError {
Expand All @@ -114,6 +119,17 @@ impl RetryableError for DeviceSyncError {
}
}

#[cfg(any(test, feature = "test-utils"))]
impl<ApiClient, V> Client<ApiClient, V> {
pub fn sync_worker_handle(&self) -> Option<Arc<WorkerHandle>> {
self.sync_worker_handle.lock().clone()
}

pub(crate) fn set_sync_worker_handle(&self, handle: Arc<WorkerHandle>) {
*self.sync_worker_handle.lock() = Some(handle);
}
}

impl<ApiClient, V> Client<ApiClient, V>
where
ApiClient: XmtpApi + Send + Sync + 'static,
Expand All @@ -128,7 +144,10 @@ where
"starting sync worker"
);

SyncWorker::new(client).spawn_worker();
let worker = SyncWorker::new(client);
#[cfg(any(test, feature = "test-utils"))]
self.set_sync_worker_handle(worker.handle.clone());
worker.spawn_worker();
}
}

Expand All @@ -141,6 +160,57 @@ pub struct SyncWorker<ApiClient, V> {
>,
init: OnceCell<()>,
retry: Retry,

// Number of events processed
#[cfg(any(test, feature = "test-utils"))]
handle: Arc<WorkerHandle>,
}

#[cfg(any(test, feature = "test-utils"))]
pub struct WorkerHandle {
processed: AtomicUsize,
notify: Notify,
}

#[cfg(any(test, feature = "test-utils"))]
impl WorkerHandle {
pub async fn wait_for_new_events(&self, mut count: usize) -> Result<(), Elapsed> {
timeout(Duration::from_secs(3), async {
while count > 0 {
self.notify.notified().await;
count -= 1;
}
})
.await?;

Ok(())
}

pub async fn wait_for_processed_count(&self, expected: usize) -> Result<(), Elapsed> {
timeout(Duration::from_secs(3), async {
while self.processed.load(Ordering::SeqCst) < expected {
self.notify.notified().await;
}
})
.await?;

Ok(())
}

pub async fn block_for_num_events<Fut>(&self, num_events: usize, op: Fut) -> Result<(), Elapsed>
where
Fut: Future<Output = ()>,
{
let processed_count = self.processed_count();
op.await;
self.wait_for_processed_count(processed_count + num_events)
.await?;
Ok(())
}

pub fn processed_count(&self) -> usize {
self.processed.load(Ordering::SeqCst)
}
}

impl<ApiClient, V> SyncWorker<ApiClient, V>
Expand Down Expand Up @@ -168,33 +238,22 @@ where
self.on_request(message_id, &provider).await?
}
},
LocalEvents::OutgoingPreferenceUpdates(consent_records) => {
let provider = self.client.mls_provider()?;
for record in consent_records {
let UserPreferenceUpdate::ConsentUpdate(consent_record) = record else {
continue;
};

self.client
.send_consent_update(&provider, consent_record)
.await?;
}
LocalEvents::OutgoingPreferenceUpdates(preference_updates) => {
tracing::error!("Outgoing preference update {preference_updates:?}");
UserPreferenceUpdate::sync_across_devices(preference_updates, &self.client)
.await?;
}
LocalEvents::IncomingPreferenceUpdate(updates) => {
let provider = self.client.mls_provider()?;
let consent_records = updates
.into_iter()
.filter_map(|pu| match pu {
UserPreferenceUpdate::ConsentUpdate(cr) => Some(cr),
_ => None,
})
.collect::<Vec<_>>();
provider
.conn_ref()
.insert_or_replace_consent_records(&consent_records)?;
LocalEvents::IncomingPreferenceUpdate(_) => {
tracing::error!("Incoming preference update");
}
_ => {}
}

#[cfg(any(test, feature = "test-utils"))]
{
self.handle.processed.fetch_add(1, Ordering::SeqCst);
self.handle.notify.notify_waiters();
}
}
Ok(())
}
Expand Down Expand Up @@ -319,6 +378,12 @@ where
stream,
init: OnceCell::new(),
retry,

#[cfg(any(test, feature = "test-utils"))]
handle: Arc::new(WorkerHandle {
processed: AtomicUsize::new(0),
notify: Notify::new(),
}),
}
}

Expand Down Expand Up @@ -404,10 +469,10 @@ where

let _message_id = sync_group.prepare_message(&content_bytes, provider, {
let request = request.clone();
move |_time_ns| PlaintextEnvelope {
move |now| PlaintextEnvelope {
content: Some(Content::V2(V2 {
message_type: Some(MessageType::DeviceSyncRequest(request)),
idempotency_key: new_request_id(),
idempotency_key: now.to_string(),
})),
}
})?;
Expand Down Expand Up @@ -471,14 +536,14 @@ where
(content_bytes, contents)
};

sync_group.prepare_message(&content_bytes, provider, |_time_ns| PlaintextEnvelope {
sync_group.prepare_message(&content_bytes, provider, |now| PlaintextEnvelope {
content: Some(Content::V2(V2 {
idempotency_key: new_request_id(),
message_type: Some(MessageType::DeviceSyncReply(contents)),
idempotency_key: now.to_string(),
})),
})?;

sync_group.sync_until_last_intent_resolved(provider).await?;
sync_group.publish_intents(provider).await?;

Ok(())
}
Expand Down
Loading

0 comments on commit d455367

Please sign in to comment.