Skip to content

Commit

Permalink
make all Mutex's parking_lot, fix lint/compile errors
Browse files Browse the repository at this point in the history
  • Loading branch information
insipx committed Aug 9, 2024
1 parent c42a2c2 commit 2601f70
Show file tree
Hide file tree
Showing 11 changed files with 65 additions and 52 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 3 additions & 3 deletions bindings_ffi/src/logger.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ pub trait FfiLogger: Send + Sync {
}

struct RustLogger {
logger: std::sync::Mutex<Box<dyn FfiLogger>>,
logger: parking_lot::Mutex<Box<dyn FfiLogger>>,
}

impl log::Log for RustLogger {
Expand All @@ -17,7 +17,7 @@ impl log::Log for RustLogger {
fn log(&self, record: &Record) {
if self.enabled(record.metadata()) {
// TODO handle errors
self.logger.lock().expect("Logger mutex is poisoned!").log(
self.logger.lock().log(
record.level() as u32,
record.level().to_string(),
format!("[libxmtp][t:{}] {}", thread_id::get(), record.args()),
Expand All @@ -33,7 +33,7 @@ pub fn init_logger(logger: Box<dyn FfiLogger>) {
// TODO handle errors
LOGGER_INIT.call_once(|| {
let logger = RustLogger {
logger: std::sync::Mutex::new(logger),
logger: parking_lot::Mutex::new(logger),
};
log::set_boxed_logger(Box::new(logger))
.map(|()| log::set_max_level(LevelFilter::Info))
Expand Down
1 change: 1 addition & 0 deletions xmtp_api_http/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ xmtp_proto = { path = "../xmtp_proto", features = ["proto_full"] }
bytes = "1.7"
tokio = { workspace = true, default-features = false, features = ["sync", "rt"] }
tokio-stream = { version = "0.1", default-features = false }
async-stream = "0.3"

[dev-dependencies]
tokio = { workspace = true, features = ["macros", "rt-multi-thread", "time"] }
Expand Down
23 changes: 7 additions & 16 deletions xmtp_api_http/src/util.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use futures::{stream::BoxStream, StreamExt};
use futures::stream::BoxStream;
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use serde_json::Deserializer;
use std::io::Read;
Expand Down Expand Up @@ -48,10 +48,8 @@ pub async fn create_grpc_stream<
endpoint: String,
http_client: reqwest::Client,
) -> BoxStream<'static, Result<R, Error>> {
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();

tokio::task::spawn(async move {
let mut bytes_stream = http_client
let stream = async_stream::stream! {
let bytes_stream = http_client
.post(endpoint)
.json(&request)
.send()
Expand All @@ -61,7 +59,7 @@ pub async fn create_grpc_stream<

log::debug!("Spawning grpc http stream");
let mut remaining = vec![];
while let Some(bytes) = bytes_stream.next().await {
for await bytes in bytes_stream {
let bytes = bytes
.map_err(|e| Error::new(ErrorKind::SubscriptionUpdateError).with(e.to_string()))?;

Expand All @@ -88,19 +86,12 @@ pub async fn create_grpc_stream<
Some(Ok(GrpcResponse::Empty {})) => continue 'messages,
None => break 'messages,
};
if tx.send(res).is_err() {
break 'messages;
}
}
// this will ensure the spawned task is dropped if the receiver stream is dropped
if tx.is_closed() {
break;
yield res;
}
}
Ok::<_, Error>(())
});
};

Box::pin(tokio_stream::wrappers::UnboundedReceiverStream::new(rx))
Box::pin(stream)
}

#[cfg(test)]
Expand Down
5 changes: 4 additions & 1 deletion xmtp_mls/benches/group_limit.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,16 @@ use xmtp_mls::{
bench::{create_identities_if_dont_exist, init_logging, Identity, BENCH_ROOT_SPAN},
test::TestClient,
},
Client,
};

pub type BenchClient = Client<TestClient>;

pub const IDENTITY_SAMPLES: [usize; 9] = [10, 20, 40, 80, 100, 200, 300, 400, 450];
pub const MAX_IDENTITIES: usize = 1_000;
pub const SAMPLE_SIZE: usize = 10;

fn setup() -> (Arc<TestClient>, Vec<Identity>, Runtime) {
fn setup() -> (Arc<BenchClient>, Vec<Identity>, Runtime) {
let runtime = Builder::new_multi_thread()
.enable_time()
.enable_io()
Expand Down
32 changes: 28 additions & 4 deletions xmtp_mls/src/groups/subscriptions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,7 @@ mod tests {
}

#[tokio::test(flavor = "multi_thread", worker_threads = 10)]
#[ignore]
async fn test_subscribe_membership_changes() {
let amal = Arc::new(ClientBuilder::new_test_client(&generate_local_wallet()).await);
let bola = ClientBuilder::new_test_client(&generate_local_wallet()).await;
Expand All @@ -280,22 +281,45 @@ mod tests {
.create_group(None, GroupMetadataOptions::default())
.unwrap();

let mut stream = amal_group.stream(amal.clone()).await.unwrap();
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
let amal_ptr = amal.clone();
let amal_group_ptr = amal_group.clone();
let notify = Delivery::new(Some(Duration::from_secs(20)));
let notify_ptr = notify.clone();

let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
let (start_tx, start_rx) = tokio::sync::oneshot::channel();
let mut stream = UnboundedReceiverStream::new(rx);
tokio::spawn(async move {
let mut stream = amal_group_ptr.stream(amal_ptr).await.unwrap();
let _ = start_tx.send(());
while let Some(item) = stream.next().await {
let _ = tx.send(item);
notify_ptr.notify_one();
}
});
// just to make sure stream is started
let _ = start_rx.await;

log::info!("ADDING AMAL TO GROUP");
amal_group
.add_members_by_inbox_id(&amal, vec![bola.inbox_id()])
.await
.unwrap();
tokio::time::sleep(std::time::Duration::from_millis(50)).await;

notify
.wait_for_delivery()
.await
.expect("Never received group membership change from stream");
let first_val = stream.next().await.unwrap();
assert_eq!(first_val.kind, GroupMessageKind::MembershipChange);

amal_group
.send_message("hello".as_bytes(), &amal)
.await
.unwrap();
notify
.wait_for_delivery()
.await
.expect("Never received second message from stream");
let second_val = stream.next().await.unwrap();
assert_eq!(second_val.decrypted_message_bytes, "hello".as_bytes());
}
Expand Down
2 changes: 1 addition & 1 deletion xmtp_mls/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ where
#[cfg(test)]
impl<T> XmtpApi for T where T: XmtpMlsClient + XmtpIdentityClient + XmtpTestClient + ?Sized {}

#[cfg(test)]
#[cfg(any(test, feature = "test-utils", feature = "bench"))]
#[async_trait::async_trait]
pub trait XmtpTestClient {
async fn create_local() -> Self;
Expand Down
12 changes: 3 additions & 9 deletions xmtp_mls/src/storage/encrypted_store/db_connection.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use parking_lot::Mutex;
use std::fmt;
use std::sync::{Arc, Mutex};
use std::sync::Arc;

use crate::storage::RawDbConnection;

Expand Down Expand Up @@ -27,14 +28,7 @@ impl DbConnection {
where
F: FnOnce(&mut RawDbConnection) -> Result<T, diesel::result::Error>,
{
let mut lock = self.wrapped_conn.lock().unwrap_or_else(
|err| {
log::error!(
"Recovering from poisoned mutex - a thread has previously panicked holding this lock"
);
err.into_inner()
},
);
let mut lock = self.wrapped_conn.lock();
fun(&mut lock)
}
}
Expand Down
29 changes: 14 additions & 15 deletions xmtp_mls/src/subscriptions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,8 @@ use std::{collections::HashMap, pin::Pin, sync::Arc};

use futures::{FutureExt, Stream, StreamExt};
use prost::Message;
use tokio::{
sync::{mpsc, oneshot},
task::JoinHandle,
};
use tokio_stream::wrappers::{errors::BroadcastStreamRecvError, UnboundedReceiverStream};
use tokio::{sync::oneshot, task::JoinHandle};
use tokio_stream::wrappers::errors::BroadcastStreamRecvError;
use xmtp_proto::xmtp::mls::api::v1::WelcomeMessage;

use crate::{
Expand Down Expand Up @@ -339,6 +336,7 @@ where
}
};

log::debug!("switching streams");
// attempt to drain all ready messages from existing stream
while let Some(Some(message)) = messages_stream.next().now_or_never() {
extra_messages.push(message);
Expand Down Expand Up @@ -385,9 +383,10 @@ mod tests {
storage::group_message::StoredGroupMessage, Client,
};
use futures::StreamExt;
use parking_lot::Mutex;
use std::sync::{
atomic::{AtomicU64, Ordering},
Arc, Mutex,
Arc,
};
use xmtp_cryptography::utils::generate_local_wallet;

Expand Down Expand Up @@ -482,7 +481,7 @@ mod tests {
let mut handle = Client::<TestClient>::stream_all_messages_with_callback(
Arc::new(caro),
move |message| {
(*messages_clone.lock().unwrap()).push(message);
(*messages_clone.lock()).push(message);
notify_pointer.notify_one();
},
);
Expand Down Expand Up @@ -512,7 +511,7 @@ mod tests {
.unwrap();
notify.wait_for_delivery().await.unwrap();

let messages = messages.lock().unwrap();
let messages = messages.lock();
assert_eq!(messages[0].decrypted_message_bytes, b"first");
assert_eq!(messages[1].decrypted_message_bytes, b"second");
assert_eq!(messages[2].decrypted_message_bytes, b"third");
Expand Down Expand Up @@ -540,7 +539,7 @@ mod tests {
let mut handle =
Client::<TestClient>::stream_all_messages_with_callback(caro.clone(), move |message| {
delivery_pointer.notify_one();
(*messages_clone.lock().unwrap()).push(message);
(*messages_clone.lock()).push(message);
});
handle.wait_for_ready().await;

Expand Down Expand Up @@ -606,7 +605,7 @@ mod tests {
.expect("timed out waiting for `fifth`");

{
let messages = messages.lock().unwrap();
let messages = messages.lock();
assert_eq!(messages.len(), 5);
}

Expand All @@ -621,7 +620,7 @@ mod tests {
.unwrap();
tokio::time::sleep(std::time::Duration::from_millis(100)).await;

let messages = messages.lock().unwrap();
let messages = messages.lock();
assert_eq!(messages.len(), 5);
}

Expand All @@ -647,7 +646,7 @@ mod tests {
let blocked_pointer = blocked.clone();
let mut handle =
Client::<TestClient>::stream_all_messages_with_callback(caro.clone(), move |message| {
(*messages_clone.lock().unwrap()).push(message);
(*messages_clone.lock()).push(message);
blocked_pointer.fetch_sub(1, Ordering::SeqCst);
});
handle.wait_for_ready().await;
Expand Down Expand Up @@ -703,7 +702,7 @@ mod tests {

let closer =
Client::<TestClient>::stream_conversations_with_callback(alix.clone(), move |g| {
let mut groups = groups_pointer.lock().unwrap();
let mut groups = groups_pointer.lock();
groups.push(g);
notify_pointer.notify_one();
});
Expand All @@ -717,7 +716,7 @@ mod tests {
.expect("Stream never received group");

{
let grps = groups.lock().unwrap();
let grps = groups.lock();
assert_eq!(grps.len(), 1);
}

Expand All @@ -732,7 +731,7 @@ mod tests {
notify.wait_for_delivery().await.unwrap();

{
let grps = groups.lock().unwrap();
let grps = groups.lock();
assert_eq!(grps.len(), 2);
}

Expand Down
4 changes: 2 additions & 2 deletions xmtp_mls/src/utils/bench.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
//! Utilities for xmtp_mls benchmarks
//! Utilities mostly include pre-generating identities in order to save time when writing/testing
//! benchmarks.
use crate::builder::ClientBuilder;
use crate::{builder::ClientBuilder, Client};
use ethers::signers::{LocalWallet, Signer};
use indicatif::{ProgressBar, ProgressStyle};
use once_cell::sync::OnceCell;
Expand Down Expand Up @@ -174,7 +174,7 @@ async fn create_identities(n: usize, is_dev_network: bool) -> Vec<Identity> {
/// node still has those identities.
pub async fn create_identities_if_dont_exist(
identities: usize,
client: &TestClient,
client: &Client<TestClient>,
is_dev_network: bool,
) -> Vec<Identity> {
match load_identities(is_dev_network) {
Expand Down
2 changes: 1 addition & 1 deletion xmtp_mls/src/utils/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ impl Delivery {
}
}

impl Client<GrpcClient> {
impl Client<TestClient> {
pub async fn is_registered(&self, address: &String) -> bool {
let ids = self
.api_client
Expand Down

0 comments on commit 2601f70

Please sign in to comment.