diff --git a/Cargo.toml b/Cargo.toml index a58a375..c52f985 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "stubborn-io" -version = "0.3.4" +version = "0.3.5" authors = ["David Raifaizen "] edition = "2021" description = "io traits/structs that automatically recover from potential disconnections/interruptions." @@ -9,7 +9,7 @@ keywords = ["reconnect", "retry", "stubborn", "io", "StubbornTcpStream"] repository = "https://github.com/craftytrickster/stubborn-io" documentation = "https://docs.rs/stubborn-io" readme = "README.md" - + [dependencies] tokio = { version = "1", features = ["time", "net"] } rand = "0.8" diff --git a/src/tokio/io.rs b/src/tokio/io.rs index b1a6705..cdbccba 100644 --- a/src/tokio/io.rs +++ b/src/tokio/io.rs @@ -5,6 +5,7 @@ use std::io::{self, ErrorKind, IoSlice}; use std::marker::PhantomData; use std::ops::{Deref, DerefMut}; use std::pin::Pin; +use std::sync::{Arc, Mutex}; use std::task::{Context, Poll}; use std::time::Duration; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; @@ -53,12 +54,13 @@ where struct AttemptsTracker { attempt_num: usize, - retries_remaining: Box + Send>, + retries_remaining: Box + Send + Sync>, } struct ReconnectStatus { attempts_tracker: AttemptsTracker, - reconnect_attempt: Pin> + Send>>, + #[allow(clippy::type_complexity)] + reconnect_attempt: Arc> + Send>>>>, _phantom_data: PhantomData, } @@ -73,7 +75,9 @@ where attempt_num: 0, retries_remaining: (options.retries_to_attempt_fn)(), }, - reconnect_attempt: Box::pin(async { unreachable!("Not going to happen") }), + reconnect_attempt: Arc::new(Mutex::new(Box::pin(async { + unreachable!("Not going to happen") + }))), _phantom_data: PhantomData, } } @@ -244,7 +248,7 @@ where T::establish(ctor_arg).await }; - reconnect_status.reconnect_attempt = Box::pin(reconnect_attempt); + reconnect_status.reconnect_attempt = Arc::new(Mutex::new(Box::pin(reconnect_attempt))); info!( "Will perform reconnect attempt #{} in {:?}.", @@ -256,16 +260,18 @@ where } fn poll_disconnect(mut self: Pin<&mut Self>, cx: &mut Context) { - let (attempt, attempt_num) = match &mut self.status { + let (attempt, attempt_num) = match self.status { Status::Connected => unreachable!(), Status::Disconnected(ref mut status) => ( - Pin::new(&mut status.reconnect_attempt), + status.reconnect_attempt.clone(), status.attempts_tracker.attempt_num, ), Status::FailedAndExhausted => unreachable!(), }; - match attempt.poll(cx) { + let mut attempt = attempt.lock().unwrap(); + + match attempt.as_mut().poll(cx) { Poll::Ready(Ok(underlying_io)) => { info!("Connection re-established"); cx.waker().wake_by_ref(); diff --git a/tests/dummy_tests.rs b/tests/dummy_tests.rs index 260df34..68fce3c 100644 --- a/tests/dummy_tests.rs +++ b/tests/dummy_tests.rs @@ -290,3 +290,25 @@ mod already_connected { assert!(msg.unwrap().is_err()); } } + +#[tokio::test] +async fn test_that_works_with_sync() { + fn make_framed(_stream: T) + where + T: AsyncRead + AsyncWrite + Send + Sync + 'static, + { + let _ = _stream; + } + + let options = ReconnectOptions::new(); + let connect_outcomes = Arc::new(Mutex::new(vec![true])); + let ctor = DummyCtor { + connect_outcomes, + ..DummyCtor::default() + }; + let dummy = StubbornDummy::connect_with_options(ctor, options) + .await + .unwrap(); + + make_framed(dummy); +}