diff --git a/Cargo.toml b/Cargo.toml index 7102210..5cbec97 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,6 +8,9 @@ license = "MIT" repository = "https://github.com/oxidecomputer/async-bb8-diesel" keywords = ["diesel", "r2d2", "pool", "tokio", "async"] +[features] +use_has_broken_as_valid_check = [] + [dependencies] bb8 = "0.8" async-trait = "0.1.73" diff --git a/src/connection_manager.rs b/src/connection_manager.rs index 5da755a..2d50029 100644 --- a/src/connection_manager.rs +++ b/src/connection_manager.rs @@ -57,6 +57,17 @@ impl ConnectionManager { // Intentionally panic if the inner closure panics. .unwrap() } + + #[cfg(feature = "use_has_broken_as_valid_check")] + fn run(&self, f: F) -> R + where + R: Send + 'static, + F: Send + 'static + FnOnce(&r2d2::ConnectionManager) -> R, + { + let cloned = self.inner.clone(); + let cloned = cloned.lock().unwrap(); + f(&*cloned) + } } #[async_trait] @@ -76,11 +87,17 @@ where async fn is_valid(&self, conn: &mut Self::Connection) -> Result<(), Self::Error> { let c = Connection(conn.0.clone()); - self.run_blocking(move |m| { - m.is_valid(&mut *c.inner())?; - Ok(()) - }) - .await + + #[cfg(not(feature = "use_has_broken_as_valid_check"))] + { + self.run_blocking(move |m| closure_for_is_valid_of_manager(m, c)) + .await + } + + #[cfg(feature = "use_has_broken_as_valid_check")] + { + self.run(move |m| closure_for_is_valid_of_manager(m, c)) + } } fn has_broken(&self, _: &mut Self::Connection) -> bool { @@ -90,3 +107,33 @@ where false } } + +#[cfg(feature = "use_has_broken_as_valid_check")] +fn closure_for_is_valid_of_manager( + m: &r2d2::ConnectionManager, + conn: Connection, +) -> Result<(), ConnectionError> +where + T: R2D2Connection + Send + 'static, +{ + if m.has_broken(&mut *conn.inner()) { + return Err(ConnectionError::Connection( + diesel::r2d2::Error::ConnectionError(diesel::ConnectionError::BadConnection( + "connection brokenn".to_string(), + )), + )); + } + Ok(()) +} + +#[cfg(not(feature = "use_has_broken_as_valid_check"))] +fn closure_for_is_valid_of_manager( + m: &r2d2::ConnectionManager, + conn: Connection, +) -> Result<(), ConnectionError> +where + T: R2D2Connection + Send + 'static, +{ + m.is_valid(&mut *conn.inner())?; + Ok(()) +}