From a215e6f1159555ce24e3ee212d1911310dffa408 Mon Sep 17 00:00:00 2001 From: Evan Rittenhouse Date: Sat, 13 Jul 2024 18:41:29 -0500 Subject: [PATCH] sync: add Sender::closed future --- tokio/src/loom/std/mutex.rs | 2 +- tokio/src/sync/broadcast.rs | 59 +++++++++++++++++++++++++++++++++-- tokio/tests/sync_broadcast.rs | 17 ++++++++++ 3 files changed, 75 insertions(+), 3 deletions(-) diff --git a/tokio/src/loom/std/mutex.rs b/tokio/src/loom/std/mutex.rs index 7b8f9ba1e24..95f6d73ba60 100644 --- a/tokio/src/loom/std/mutex.rs +++ b/tokio/src/loom/std/mutex.rs @@ -1,7 +1,7 @@ use std::sync::{self, MutexGuard, TryLockError}; /// Adapter for `std::Mutex` that removes the poisoning aspects -/// from its api. +/// from its API. #[derive(Debug)] pub(crate) struct Mutex(sync::Mutex); diff --git a/tokio/src/sync/broadcast.rs b/tokio/src/sync/broadcast.rs index ba0a44fb8b9..094161bc706 100644 --- a/tokio/src/sync/broadcast.rs +++ b/tokio/src/sync/broadcast.rs @@ -116,6 +116,7 @@ //! } //! ``` +use crate::future::poll_fn; use crate::loom::cell::UnsafeCell; use crate::loom::sync::atomic::{AtomicBool, AtomicUsize}; use crate::loom::sync::{Arc, Mutex, MutexGuard, RwLock, RwLockReadGuard}; @@ -163,6 +164,7 @@ use std::task::{Context, Poll, Waker}; /// [`broadcast`]: crate::sync::broadcast pub struct Sender { shared: Arc>, + notify_rx_closed: Arc, } /// Receiving-half of the [`broadcast`] channel. @@ -300,6 +302,8 @@ pub mod error { use self::error::{RecvError, SendError, TryRecvError}; +use super::Notify; + /// Data shared between senders and receivers. struct Shared { /// slots in the channel. @@ -313,6 +317,9 @@ struct Shared { /// Number of outstanding Sender handles. num_tx: AtomicUsize, + + /// Notify when a subscribed [`Receiver`] is dropped. + notify_rx_drop: Notify, } /// Next position to write a value. @@ -527,9 +534,15 @@ impl Sender { waiters: LinkedList::new(), }), num_tx: AtomicUsize::new(1), + notify_rx_drop: Notify::new(), }); - Sender { shared } + let notify_rx_closed = Arc::new(Notify::new()); + + Sender { + shared, + notify_rx_closed, + } } /// Attempts to send a value to all active [`Receiver`] handles, returning @@ -804,6 +817,38 @@ impl Sender { Arc::ptr_eq(&self.shared, &other.shared) } + /// A future which completes when the number of [Receiver]s subscribed to this `Sender` reaches + /// zero. + /// + /// # Examples + /// + /// ``` + /// use futures::FutureExt; + /// use tokio::sync::broadcast; + /// + /// #[tokio::main] + /// async fn main() { + /// let (tx, mut rx1) = broadcast::channel::(16); + /// let mut rx2 = tx.subscribe(); + /// + /// tokio::spawn(async move { + /// assert_eq!(rx1.recv().await.unwrap(), 10); + /// }); + /// + /// let _ = tx.send(10); + /// assert!(tx.closed().now_or_never().is_none()); + /// + /// let _ = tokio::spawn(async move { + /// assert_eq!(rx2.recv().await.unwrap(), 10); + /// }).await; + /// + /// assert!(tx.closed().now_or_never().is_some()); + /// } + /// ``` + pub async fn closed(&self) { + self.shared.notify_rx_drop.notified().await; + } + fn close_channel(&self) { let mut tail = self.shared.tail.lock(); tail.closed = true; @@ -946,7 +991,12 @@ impl Clone for Sender { let shared = self.shared.clone(); shared.num_tx.fetch_add(1, SeqCst); - Sender { shared } + let notify_rx_closed = Arc::clone(&self.notify_rx_closed); + + Sender { + shared, + notify_rx_closed, + } } } @@ -1346,9 +1396,14 @@ impl Drop for Receiver { tail.rx_cnt -= 1; let until = tail.pos; + let remaining_rx = tail.rx_cnt; drop(tail); + if remaining_rx == 0 { + self.shared.notify_rx_drop.notify_waiters(); + } + while self.next < until { match self.recv_ref(None) { Ok(_) => {} diff --git a/tokio/tests/sync_broadcast.rs b/tokio/tests/sync_broadcast.rs index 2638c1f33d4..f397782206e 100644 --- a/tokio/tests/sync_broadcast.rs +++ b/tokio/tests/sync_broadcast.rs @@ -640,3 +640,20 @@ fn send_in_waker_drop() { // Shouldn't deadlock. let _ = tx.send(()); } + +#[test] +fn broadcast_sender_closed() { + let (tx, rx) = broadcast::channel::<()>(1); + let rx2 = tx.subscribe(); + + let mut task = task::spawn(tx.closed()); + assert_pending!(task.poll()); + + drop(rx); + assert!(!task.is_woken()); + assert_pending!(task.poll()); + + drop(rx2); + assert!(task.is_woken()); + assert_ready!(task.poll()); +}