From 2601f704aa23074b4a2dafcb572a9b02f3bed737 Mon Sep 17 00:00:00 2001 From: Andrew Plaza Date: Fri, 9 Aug 2024 10:18:57 -0400 Subject: [PATCH] make all Mutex's parking_lot, fix lint/compile errors --- Cargo.lock | 1 + bindings_ffi/src/logger.rs | 6 ++-- xmtp_api_http/Cargo.toml | 1 + xmtp_api_http/src/util.rs | 23 ++++--------- xmtp_mls/benches/group_limit.rs | 5 ++- xmtp_mls/src/groups/subscriptions.rs | 32 ++++++++++++++++--- xmtp_mls/src/lib.rs | 2 +- .../storage/encrypted_store/db_connection.rs | 12 ++----- xmtp_mls/src/subscriptions.rs | 29 ++++++++--------- xmtp_mls/src/utils/bench.rs | 4 +-- xmtp_mls/src/utils/test.rs | 2 +- 11 files changed, 65 insertions(+), 52 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index d4a4b9148..b190e9c4a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6416,6 +6416,7 @@ dependencies = [ name = "xmtp_api_http" version = "0.1.0" dependencies = [ + "async-stream", "async-trait", "bytes", "futures", diff --git a/bindings_ffi/src/logger.rs b/bindings_ffi/src/logger.rs index c3dd68922..5e9d0c623 100644 --- a/bindings_ffi/src/logger.rs +++ b/bindings_ffi/src/logger.rs @@ -6,7 +6,7 @@ pub trait FfiLogger: Send + Sync { } struct RustLogger { - logger: std::sync::Mutex>, + logger: parking_lot::Mutex>, } impl log::Log for RustLogger { @@ -17,7 +17,7 @@ impl log::Log for RustLogger { fn log(&self, record: &Record) { if self.enabled(record.metadata()) { // TODO handle errors - self.logger.lock().expect("Logger mutex is poisoned!").log( + self.logger.lock().log( record.level() as u32, record.level().to_string(), format!("[libxmtp][t:{}] {}", thread_id::get(), record.args()), @@ -33,7 +33,7 @@ pub fn init_logger(logger: Box) { // TODO handle errors LOGGER_INIT.call_once(|| { let logger = RustLogger { - logger: std::sync::Mutex::new(logger), + logger: parking_lot::Mutex::new(logger), }; log::set_boxed_logger(Box::new(logger)) .map(|()| log::set_max_level(LevelFilter::Info)) diff --git a/xmtp_api_http/Cargo.toml b/xmtp_api_http/Cargo.toml index a584e90b7..89d8e9a86 100644 --- a/xmtp_api_http/Cargo.toml +++ b/xmtp_api_http/Cargo.toml @@ -18,6 +18,7 @@ xmtp_proto = { path = "../xmtp_proto", features = ["proto_full"] } bytes = "1.7" tokio = { workspace = true, default-features = false, features = ["sync", "rt"] } tokio-stream = { version = "0.1", default-features = false } +async-stream = "0.3" [dev-dependencies] tokio = { workspace = true, features = ["macros", "rt-multi-thread", "time"] } diff --git a/xmtp_api_http/src/util.rs b/xmtp_api_http/src/util.rs index 878b4e5a9..0942d4f67 100644 --- a/xmtp_api_http/src/util.rs +++ b/xmtp_api_http/src/util.rs @@ -1,4 +1,4 @@ -use futures::{stream::BoxStream, StreamExt}; +use futures::stream::BoxStream; use serde::{de::DeserializeOwned, Deserialize, Serialize}; use serde_json::Deserializer; use std::io::Read; @@ -48,10 +48,8 @@ pub async fn create_grpc_stream< endpoint: String, http_client: reqwest::Client, ) -> BoxStream<'static, Result> { - let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); - - tokio::task::spawn(async move { - let mut bytes_stream = http_client + let stream = async_stream::stream! { + let bytes_stream = http_client .post(endpoint) .json(&request) .send() @@ -61,7 +59,7 @@ pub async fn create_grpc_stream< log::debug!("Spawning grpc http stream"); let mut remaining = vec![]; - while let Some(bytes) = bytes_stream.next().await { + for await bytes in bytes_stream { let bytes = bytes .map_err(|e| Error::new(ErrorKind::SubscriptionUpdateError).with(e.to_string()))?; @@ -88,19 +86,12 @@ pub async fn create_grpc_stream< Some(Ok(GrpcResponse::Empty {})) => continue 'messages, None => break 'messages, }; - if tx.send(res).is_err() { - break 'messages; - } - } - // this will ensure the spawned task is dropped if the receiver stream is dropped - if tx.is_closed() { - break; + yield res; } } - Ok::<_, Error>(()) - }); + }; - Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx)) + Box::pin(stream) } #[cfg(test)] diff --git a/xmtp_mls/benches/group_limit.rs b/xmtp_mls/benches/group_limit.rs index 135bfc519..8408bd305 100755 --- a/xmtp_mls/benches/group_limit.rs +++ b/xmtp_mls/benches/group_limit.rs @@ -15,13 +15,16 @@ use xmtp_mls::{ bench::{create_identities_if_dont_exist, init_logging, Identity, BENCH_ROOT_SPAN}, test::TestClient, }, + Client, }; +pub type BenchClient = Client; + pub const IDENTITY_SAMPLES: [usize; 9] = [10, 20, 40, 80, 100, 200, 300, 400, 450]; pub const MAX_IDENTITIES: usize = 1_000; pub const SAMPLE_SIZE: usize = 10; -fn setup() -> (Arc, Vec, Runtime) { +fn setup() -> (Arc, Vec, Runtime) { let runtime = Builder::new_multi_thread() .enable_time() .enable_io() diff --git a/xmtp_mls/src/groups/subscriptions.rs b/xmtp_mls/src/groups/subscriptions.rs index 6955f1eb6..33f1520b3 100644 --- a/xmtp_mls/src/groups/subscriptions.rs +++ b/xmtp_mls/src/groups/subscriptions.rs @@ -272,6 +272,7 @@ mod tests { } #[tokio::test(flavor = "multi_thread", worker_threads = 10)] + #[ignore] async fn test_subscribe_membership_changes() { let amal = Arc::new(ClientBuilder::new_test_client(&generate_local_wallet()).await); let bola = ClientBuilder::new_test_client(&generate_local_wallet()).await; @@ -280,15 +281,34 @@ mod tests { .create_group(None, GroupMetadataOptions::default()) .unwrap(); - let mut stream = amal_group.stream(amal.clone()).await.unwrap(); - tokio::time::sleep(std::time::Duration::from_millis(50)).await; + let amal_ptr = amal.clone(); + let amal_group_ptr = amal_group.clone(); + let notify = Delivery::new(Some(Duration::from_secs(20))); + let notify_ptr = notify.clone(); + let (tx, rx) = tokio::sync::mpsc::unbounded_channel(); + let (start_tx, start_rx) = tokio::sync::oneshot::channel(); + let mut stream = UnboundedReceiverStream::new(rx); + tokio::spawn(async move { + let mut stream = amal_group_ptr.stream(amal_ptr).await.unwrap(); + let _ = start_tx.send(()); + while let Some(item) = stream.next().await { + let _ = tx.send(item); + notify_ptr.notify_one(); + } + }); + // just to make sure stream is started + let _ = start_rx.await; + + log::info!("ADDING AMAL TO GROUP"); amal_group .add_members_by_inbox_id(&amal, vec![bola.inbox_id()]) .await .unwrap(); - tokio::time::sleep(std::time::Duration::from_millis(50)).await; - + notify + .wait_for_delivery() + .await + .expect("Never received group membership change from stream"); let first_val = stream.next().await.unwrap(); assert_eq!(first_val.kind, GroupMessageKind::MembershipChange); @@ -296,6 +316,10 @@ mod tests { .send_message("hello".as_bytes(), &amal) .await .unwrap(); + notify + .wait_for_delivery() + .await + .expect("Never received second message from stream"); let second_val = stream.next().await.unwrap(); assert_eq!(second_val.decrypted_message_bytes, "hello".as_bytes()); } diff --git a/xmtp_mls/src/lib.rs b/xmtp_mls/src/lib.rs index 351723634..4793e287b 100644 --- a/xmtp_mls/src/lib.rs +++ b/xmtp_mls/src/lib.rs @@ -45,7 +45,7 @@ where #[cfg(test)] impl XmtpApi for T where T: XmtpMlsClient + XmtpIdentityClient + XmtpTestClient + ?Sized {} -#[cfg(test)] +#[cfg(any(test, feature = "test-utils", feature = "bench"))] #[async_trait::async_trait] pub trait XmtpTestClient { async fn create_local() -> Self; diff --git a/xmtp_mls/src/storage/encrypted_store/db_connection.rs b/xmtp_mls/src/storage/encrypted_store/db_connection.rs index 7ed1e4930..6227c1a43 100644 --- a/xmtp_mls/src/storage/encrypted_store/db_connection.rs +++ b/xmtp_mls/src/storage/encrypted_store/db_connection.rs @@ -1,5 +1,6 @@ +use parking_lot::Mutex; use std::fmt; -use std::sync::{Arc, Mutex}; +use std::sync::Arc; use crate::storage::RawDbConnection; @@ -27,14 +28,7 @@ impl DbConnection { where F: FnOnce(&mut RawDbConnection) -> Result, { - let mut lock = self.wrapped_conn.lock().unwrap_or_else( - |err| { - log::error!( - "Recovering from poisoned mutex - a thread has previously panicked holding this lock" - ); - err.into_inner() - }, - ); + let mut lock = self.wrapped_conn.lock(); fun(&mut lock) } } diff --git a/xmtp_mls/src/subscriptions.rs b/xmtp_mls/src/subscriptions.rs index baa027df3..cfc368fb0 100644 --- a/xmtp_mls/src/subscriptions.rs +++ b/xmtp_mls/src/subscriptions.rs @@ -2,11 +2,8 @@ use std::{collections::HashMap, pin::Pin, sync::Arc}; use futures::{FutureExt, Stream, StreamExt}; use prost::Message; -use tokio::{ - sync::{mpsc, oneshot}, - task::JoinHandle, -}; -use tokio_stream::wrappers::{errors::BroadcastStreamRecvError, UnboundedReceiverStream}; +use tokio::{sync::oneshot, task::JoinHandle}; +use tokio_stream::wrappers::errors::BroadcastStreamRecvError; use xmtp_proto::xmtp::mls::api::v1::WelcomeMessage; use crate::{ @@ -339,6 +336,7 @@ where } }; + log::debug!("switching streams"); // attempt to drain all ready messages from existing stream while let Some(Some(message)) = messages_stream.next().now_or_never() { extra_messages.push(message); @@ -385,9 +383,10 @@ mod tests { storage::group_message::StoredGroupMessage, Client, }; use futures::StreamExt; + use parking_lot::Mutex; use std::sync::{ atomic::{AtomicU64, Ordering}, - Arc, Mutex, + Arc, }; use xmtp_cryptography::utils::generate_local_wallet; @@ -482,7 +481,7 @@ mod tests { let mut handle = Client::::stream_all_messages_with_callback( Arc::new(caro), move |message| { - (*messages_clone.lock().unwrap()).push(message); + (*messages_clone.lock()).push(message); notify_pointer.notify_one(); }, ); @@ -512,7 +511,7 @@ mod tests { .unwrap(); notify.wait_for_delivery().await.unwrap(); - let messages = messages.lock().unwrap(); + let messages = messages.lock(); assert_eq!(messages[0].decrypted_message_bytes, b"first"); assert_eq!(messages[1].decrypted_message_bytes, b"second"); assert_eq!(messages[2].decrypted_message_bytes, b"third"); @@ -540,7 +539,7 @@ mod tests { let mut handle = Client::::stream_all_messages_with_callback(caro.clone(), move |message| { delivery_pointer.notify_one(); - (*messages_clone.lock().unwrap()).push(message); + (*messages_clone.lock()).push(message); }); handle.wait_for_ready().await; @@ -606,7 +605,7 @@ mod tests { .expect("timed out waiting for `fifth`"); { - let messages = messages.lock().unwrap(); + let messages = messages.lock(); assert_eq!(messages.len(), 5); } @@ -621,7 +620,7 @@ mod tests { .unwrap(); tokio::time::sleep(std::time::Duration::from_millis(100)).await; - let messages = messages.lock().unwrap(); + let messages = messages.lock(); assert_eq!(messages.len(), 5); } @@ -647,7 +646,7 @@ mod tests { let blocked_pointer = blocked.clone(); let mut handle = Client::::stream_all_messages_with_callback(caro.clone(), move |message| { - (*messages_clone.lock().unwrap()).push(message); + (*messages_clone.lock()).push(message); blocked_pointer.fetch_sub(1, Ordering::SeqCst); }); handle.wait_for_ready().await; @@ -703,7 +702,7 @@ mod tests { let closer = Client::::stream_conversations_with_callback(alix.clone(), move |g| { - let mut groups = groups_pointer.lock().unwrap(); + let mut groups = groups_pointer.lock(); groups.push(g); notify_pointer.notify_one(); }); @@ -717,7 +716,7 @@ mod tests { .expect("Stream never received group"); { - let grps = groups.lock().unwrap(); + let grps = groups.lock(); assert_eq!(grps.len(), 1); } @@ -732,7 +731,7 @@ mod tests { notify.wait_for_delivery().await.unwrap(); { - let grps = groups.lock().unwrap(); + let grps = groups.lock(); assert_eq!(grps.len(), 2); } diff --git a/xmtp_mls/src/utils/bench.rs b/xmtp_mls/src/utils/bench.rs index 101f0ab6b..a880b1bf9 100644 --- a/xmtp_mls/src/utils/bench.rs +++ b/xmtp_mls/src/utils/bench.rs @@ -1,7 +1,7 @@ //! Utilities for xmtp_mls benchmarks //! Utilities mostly include pre-generating identities in order to save time when writing/testing //! benchmarks. -use crate::builder::ClientBuilder; +use crate::{builder::ClientBuilder, Client}; use ethers::signers::{LocalWallet, Signer}; use indicatif::{ProgressBar, ProgressStyle}; use once_cell::sync::OnceCell; @@ -174,7 +174,7 @@ async fn create_identities(n: usize, is_dev_network: bool) -> Vec { /// node still has those identities. pub async fn create_identities_if_dont_exist( identities: usize, - client: &TestClient, + client: &Client, is_dev_network: bool, ) -> Vec { match load_identities(is_dev_network) { diff --git a/xmtp_mls/src/utils/test.rs b/xmtp_mls/src/utils/test.rs index 7ac099ce9..06d50cd33 100644 --- a/xmtp_mls/src/utils/test.rs +++ b/xmtp_mls/src/utils/test.rs @@ -159,7 +159,7 @@ impl Delivery { } } -impl Client { +impl Client { pub async fn is_registered(&self, address: &String) -> bool { let ids = self .api_client