diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml index c8b3c70..e4f6e9f 100644 --- a/.github/workflows/rust.yml +++ b/.github/workflows/rust.yml @@ -23,6 +23,14 @@ jobs: - name: Check style run: cargo fmt -- --check + check-without-cockroach: + runs-on: ubuntu-latest + steps: + # actions/checkout@v2 + - uses: actions/checkout@72f2cec99f417b1a1c5e2e88945068983b7965f9 + - name: Cargo check + run: cargo check --no-default-features + build-and-test: runs-on: ${{ matrix.os }} strategy: diff --git a/Cargo.toml b/Cargo.toml index 7102210..77eb826 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,6 +8,11 @@ license = "MIT" repository = "https://github.com/oxidecomputer/async-bb8-diesel" keywords = ["diesel", "r2d2", "pool", "tokio", "async"] +[features] +# Enables CockroachDB-specific functions. +cockroach = [] +default = [ "cockroach" ] + [dependencies] bb8 = "0.8" async-trait = "0.1.73" diff --git a/src/async_traits.rs b/src/async_traits.rs index 5ed2fe2..0e50d1a 100644 --- a/src/async_traits.rs +++ b/src/async_traits.rs @@ -12,7 +12,6 @@ use diesel::{ methods::{ExecuteDsl, LimitDsl, LoadQuery}, RunQueryDsl, }, - result::DatabaseErrorKind, result::Error as DieselError, }; use std::future::Future; @@ -29,10 +28,11 @@ where async fn batch_execute_async(&self, query: &str) -> Result<(), DieselError>; } +#[cfg(feature = "cockroach")] fn retryable_error(err: &DieselError) -> bool { match err { DieselError::DatabaseError(kind, boxed_error_information) => match kind { - DatabaseErrorKind::SerializationFailure => { + diesel::result::DatabaseErrorKind::SerializationFailure => { return boxed_error_information .message() .starts_with("restart transaction"); @@ -48,16 +48,14 @@ fn retryable_error(err: &DieselError) -> bool { pub trait AsyncConnection: AsyncSimpleConnection where Conn: 'static + DieselConnection, - Self: Send, + Self: Send + Sized + 'static, { - type OwnedConnection: Sync + Send + 'static; - #[doc(hidden)] - fn get_owned_connection(&self) -> Self::OwnedConnection; + fn get_owned_connection(&self) -> Self; #[doc(hidden)] - fn as_sync_conn(owned: &Self::OwnedConnection) -> MutexGuard<'_, Conn>; + fn as_sync_conn(&self) -> MutexGuard<'_, Conn>; #[doc(hidden)] - fn as_async_conn(owned: &Self::OwnedConnection) -> &SingleConnection; + fn as_async_conn(&self) -> &SingleConnection; /// Runs the function `f` in an context where blocking is safe. async fn run(&self, f: Func) -> Result @@ -67,41 +65,36 @@ where Func: FnOnce(&mut Conn) -> Result + Send + 'static, { let connection = self.get_owned_connection(); - Self::run_with_connection(connection, f).await + connection.run_with_connection(f).await } #[doc(hidden)] - async fn run_with_connection( - connection: Self::OwnedConnection, - f: Func, - ) -> Result + async fn run_with_connection(self, f: Func) -> Result where R: Send + 'static, E: Send + 'static, Func: FnOnce(&mut Conn) -> Result + Send + 'static, { - spawn_blocking(move || f(&mut *Self::as_sync_conn(&connection))) + spawn_blocking(move || f(&mut *self.as_sync_conn())) .await .unwrap() // Propagate panics } #[doc(hidden)] - async fn run_with_shared_connection( - connection: Arc, - f: Func, - ) -> Result + async fn run_with_shared_connection(self: &Arc, f: Func) -> Result where R: Send + 'static, E: Send + 'static, Func: FnOnce(&mut Conn) -> Result + Send + 'static, { - spawn_blocking(move || f(&mut *Self::as_sync_conn(&connection))) + let conn = self.clone(); + spawn_blocking(move || f(&mut *conn.as_sync_conn())) .await .unwrap() // Propagate panics } #[doc(hidden)] - async fn transaction_depth2(&self) -> Result { + async fn transaction_depth(&self) -> Result { let conn = self.get_owned_connection(); Self::run_with_connection(conn, |conn| { @@ -118,60 +111,38 @@ where } #[doc(hidden)] - fn transaction_depth(conn: &Self::OwnedConnection) -> Result { - // Verifying pre-requisites: Ensure we aren't already running - // in a "broken" transaction state, nor in a nested transation. - match Conn::TransactionManager::transaction_manager_status_mut(&mut *Self::as_sync_conn( - &conn, - )) { - TransactionManagerStatus::Valid(status) => { - return Ok(status.transaction_depth().map(|d| d.into()).unwrap_or(0)); - } - TransactionManagerStatus::InError => { - return Err(DieselError::BrokenTransactionManager); - } - } - } - - #[doc(hidden)] - async fn start_transaction(conn: &Arc) -> Result<(), DieselError> { - if Self::transaction_depth(conn)? != 0 { + async fn start_transaction(self: &Arc) -> Result<(), DieselError> { + if self.transaction_depth().await? != 0 { return Err(DieselError::AlreadyInTransaction); } - Self::run_with_shared_connection(conn.clone(), |conn| { - Conn::TransactionManager::begin_transaction(conn) - }) - .await?; + self.run_with_shared_connection(|conn| Conn::TransactionManager::begin_transaction(conn)) + .await?; Ok(()) } #[doc(hidden)] - async fn add_retry_savepoint(conn: &Arc) -> Result<(), DieselError> { - match Self::transaction_depth(conn)? { + async fn add_retry_savepoint(self: &Arc) -> Result<(), DieselError> { + match self.transaction_depth().await? { 0 => return Err(DieselError::NotInTransaction), 1 => (), _ => return Err(DieselError::AlreadyInTransaction), }; - Self::run_with_shared_connection(conn.clone(), |conn| { - Conn::TransactionManager::begin_transaction(conn) - }) - .await?; + self.run_with_shared_connection(|conn| Conn::TransactionManager::begin_transaction(conn)) + .await?; Ok(()) } #[doc(hidden)] - async fn commit_transaction(conn: &Arc) -> Result<(), DieselError> { - Self::run_with_shared_connection(conn.clone(), |conn| { - Conn::TransactionManager::commit_transaction(conn) - }) - .await?; + async fn commit_transaction(self: &Arc) -> Result<(), DieselError> { + self.run_with_shared_connection(|conn| Conn::TransactionManager::commit_transaction(conn)) + .await?; Ok(()) } #[doc(hidden)] - async fn rollback_transaction(conn: &Arc) -> Result<(), DieselError> { - Self::run_with_shared_connection(conn.clone(), |conn| { + async fn rollback_transaction(self: &Arc) -> Result<(), DieselError> { + self.run_with_shared_connection(|conn| { Conn::TransactionManager::rollback_transaction(conn) }) .await?; @@ -181,6 +152,7 @@ where /// Issues a function `f` as a transaction. /// /// If it fails, asynchronously calls `retry` to decide if to retry. + #[cfg(feature = "cockroach")] async fn transaction_async_with_retry( &'a self, f: Func, @@ -206,7 +178,7 @@ where // that'll require more interaction with how sessions with the database // are constructed. Self::start_transaction(&conn).await?; - Self::run_with_shared_connection(conn.clone(), |conn| { + conn.run_with_shared_connection(|conn| { conn.batch_execute("SET LOCAL force_savepoint_restart = true") }) .await?; @@ -277,7 +249,7 @@ where // // However, it modifies all callsites to instead issue // known-to-be-synchronous operations from an asynchronous context. - Self::run_with_shared_connection(conn.clone(), |conn| { + conn.run_with_shared_connection(|conn| { Conn::TransactionManager::begin_transaction(conn).map_err(E::from) }) .await?; @@ -297,17 +269,18 @@ where let async_conn = SingleConnection(Self::as_async_conn(&conn).0.clone()); match f(async_conn).await { Ok(value) => { - Self::run_with_shared_connection(conn.clone(), |conn| { + conn.run_with_shared_connection(|conn| { Conn::TransactionManager::commit_transaction(conn).map_err(E::from) }) .await?; Ok(value) } Err(user_error) => { - match Self::run_with_shared_connection(conn.clone(), |conn| { - Conn::TransactionManager::rollback_transaction(conn).map_err(E::from) - }) - .await + match conn + .run_with_shared_connection(|conn| { + Conn::TransactionManager::rollback_transaction(conn).map_err(E::from) + }) + .await { Ok(()) => Err(user_error), Err(err) => Err(err), diff --git a/src/connection.rs b/src/connection.rs index 66f9f02..1f00b10 100644 --- a/src/connection.rs +++ b/src/connection.rs @@ -50,20 +50,19 @@ where Conn: 'static + R2D2Connection, Connection: crate::AsyncSimpleConnection, { - type OwnedConnection = Self; - - fn get_owned_connection(&self) -> Self::OwnedConnection { + fn get_owned_connection(&self) -> Self { Connection(self.0.clone()) } // Accesses the connection synchronously, protected by a mutex. // // Avoid calling from asynchronous contexts. - fn as_sync_conn(owned: &Self::OwnedConnection) -> MutexGuard<'_, Conn> { - owned.inner() + fn as_sync_conn(&self) -> MutexGuard<'_, Conn> { + self.inner() } - fn as_async_conn(owned: &Self::OwnedConnection) -> &Connection { - owned + // TODO: Consider removing me. + fn as_async_conn(&self) -> &Connection { + self } } diff --git a/tests/test.rs b/tests/test.rs index bd7af80..91ae1ec 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -164,8 +164,10 @@ async fn test_transaction_automatic_retry_success_case() { use user::dsl; // Transaction that can retry but does not need to. + assert_eq!(conn.transaction_depth().await.unwrap(), 0); conn.transaction_async_with_retry( |conn| async move { + assert!(conn.transaction_depth().await.unwrap() > 0); diesel::insert_into(dsl::user) .values((dsl::id.eq(3), dsl::name.eq("Sally"))) .execute_async(&conn) @@ -176,6 +178,7 @@ async fn test_transaction_automatic_retry_success_case() { ) .await .expect("Transaction failed"); + assert_eq!(conn.transaction_depth().await.unwrap(), 0); test_end(crdb).await; } @@ -197,6 +200,7 @@ async fn test_transaction_automatic_retry_explicit_rollback() { // // 1. Retries on the first call // 2. Explicitly rolls back on the second call + assert_eq!(conn.transaction_depth().await.unwrap(), 0); let err = conn .transaction_async_with_retry( |_conn| { @@ -225,6 +229,7 @@ async fn test_transaction_automatic_retry_explicit_rollback() { .expect_err("Transaction should have failed"); assert_eq!(err, diesel::result::Error::RollbackTransaction); + assert_eq!(conn.transaction_depth().await.unwrap(), 0); // The transaction closure should have been attempted twice, but // we should have only asked whether or not to retry once -- after @@ -259,6 +264,7 @@ async fn test_transaction_automatic_retry_injected_errors() { conn.batch_execute_async("SET inject_retry_errors_enabled = true") .await .expect("Failed to inject error"); + assert_eq!(conn.transaction_depth().await.unwrap(), 0); conn.transaction_async_with_retry( |conn| { let transaction_attempted_count = transaction_attempted_count.clone(); @@ -280,6 +286,7 @@ async fn test_transaction_automatic_retry_injected_errors() { ) .await .expect("Transaction should have succeeded"); + assert_eq!(conn.transaction_depth().await.unwrap(), 0); // The transaction closure should have been attempted twice, but // we should have only asked whether or not to retry once -- after @@ -307,6 +314,7 @@ async fn test_transaction_automatic_retry_does_not_retry_non_retryable_errors() // Test a transaction that: // // Fails with a non-retryable error. It should exit immediately. + assert_eq!(conn.transaction_depth().await.unwrap(), 0); assert_eq!( conn.transaction_async_with_retry( |_| async { Err::<(), _>(diesel::result::Error::NotFound) }, @@ -316,6 +324,7 @@ async fn test_transaction_automatic_retry_does_not_retry_non_retryable_errors() .expect_err("Transaction should have failed"), diesel::result::Error::NotFound, ); + assert_eq!(conn.transaction_depth().await.unwrap(), 0); test_end(crdb).await; } @@ -329,6 +338,7 @@ async fn test_transaction_automatic_retry_nested_transactions_fail() { let conn = pool.get().await.unwrap(); // This outer transaction should succeed immediately... + assert_eq!(conn.transaction_depth().await.unwrap(), 0); conn.transaction_async_with_retry( |conn| async move { // ... but this inner transaction should fail! We do not support @@ -351,6 +361,7 @@ async fn test_transaction_automatic_retry_nested_transactions_fail() { ) .await .expect("Transaction should have succeeded"); + assert_eq!(conn.transaction_depth().await.unwrap(), 0); test_end(crdb).await; }