Skip to content

Commit

Permalink
Add cockroach feature flag, simplify connection management, more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
smklein committed Nov 14, 2023
1 parent 858e5ed commit 14bd98b
Show file tree
Hide file tree
Showing 5 changed files with 65 additions and 69 deletions.
8 changes: 8 additions & 0 deletions .github/workflows/rust.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,14 @@ jobs:
- name: Check style
run: cargo fmt -- --check

check-without-cockroach:
runs-on: ubuntu-latest
steps:
# actions/checkout@v2
- uses: actions/checkout@72f2cec99f417b1a1c5e2e88945068983b7965f9
- name: Cargo check
run: cargo check --no-default-features

build-and-test:
runs-on: ${{ matrix.os }}
strategy:
Expand Down
5 changes: 5 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@ license = "MIT"
repository = "https://github.com/oxidecomputer/async-bb8-diesel"
keywords = ["diesel", "r2d2", "pool", "tokio", "async"]

[features]
# Enables CockroachDB-specific functions.
cockroach = []
default = [ "cockroach" ]

[dependencies]
bb8 = "0.8"
async-trait = "0.1.73"
Expand Down
97 changes: 35 additions & 62 deletions src/async_traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ use diesel::{
methods::{ExecuteDsl, LimitDsl, LoadQuery},
RunQueryDsl,
},
result::DatabaseErrorKind,
result::Error as DieselError,
};
use std::future::Future;
Expand All @@ -29,10 +28,11 @@ where
async fn batch_execute_async(&self, query: &str) -> Result<(), DieselError>;
}

#[cfg(feature = "cockroach")]
fn retryable_error(err: &DieselError) -> bool {
match err {
DieselError::DatabaseError(kind, boxed_error_information) => match kind {
DatabaseErrorKind::SerializationFailure => {
diesel::result::DatabaseErrorKind::SerializationFailure => {
return boxed_error_information
.message()
.starts_with("restart transaction");
Expand All @@ -48,16 +48,14 @@ fn retryable_error(err: &DieselError) -> bool {
pub trait AsyncConnection<Conn>: AsyncSimpleConnection<Conn>
where
Conn: 'static + DieselConnection,
Self: Send,
Self: Send + Sized + 'static,
{
type OwnedConnection: Sync + Send + 'static;

#[doc(hidden)]
fn get_owned_connection(&self) -> Self::OwnedConnection;
fn get_owned_connection(&self) -> Self;
#[doc(hidden)]
fn as_sync_conn(owned: &Self::OwnedConnection) -> MutexGuard<'_, Conn>;
fn as_sync_conn(&self) -> MutexGuard<'_, Conn>;
#[doc(hidden)]
fn as_async_conn(owned: &Self::OwnedConnection) -> &SingleConnection<Conn>;
fn as_async_conn(&self) -> &SingleConnection<Conn>;

/// Runs the function `f` in an context where blocking is safe.
async fn run<R, E, Func>(&self, f: Func) -> Result<R, E>
Expand All @@ -67,41 +65,36 @@ where
Func: FnOnce(&mut Conn) -> Result<R, E> + Send + 'static,
{
let connection = self.get_owned_connection();
Self::run_with_connection(connection, f).await
connection.run_with_connection(f).await
}

#[doc(hidden)]
async fn run_with_connection<R, E, Func>(
connection: Self::OwnedConnection,
f: Func,
) -> Result<R, E>
async fn run_with_connection<R, E, Func>(self, f: Func) -> Result<R, E>
where
R: Send + 'static,
E: Send + 'static,
Func: FnOnce(&mut Conn) -> Result<R, E> + Send + 'static,
{
spawn_blocking(move || f(&mut *Self::as_sync_conn(&connection)))
spawn_blocking(move || f(&mut *self.as_sync_conn()))
.await
.unwrap() // Propagate panics
}

#[doc(hidden)]
async fn run_with_shared_connection<R, E, Func>(
connection: Arc<Self::OwnedConnection>,
f: Func,
) -> Result<R, E>
async fn run_with_shared_connection<R, E, Func>(self: &Arc<Self>, f: Func) -> Result<R, E>
where
R: Send + 'static,
E: Send + 'static,
Func: FnOnce(&mut Conn) -> Result<R, E> + Send + 'static,
{
spawn_blocking(move || f(&mut *Self::as_sync_conn(&connection)))
let conn = self.clone();
spawn_blocking(move || f(&mut *conn.as_sync_conn()))
.await
.unwrap() // Propagate panics
}

#[doc(hidden)]
async fn transaction_depth2(&self) -> Result<u32, DieselError> {
async fn transaction_depth(&self) -> Result<u32, DieselError> {
let conn = self.get_owned_connection();

Self::run_with_connection(conn, |conn| {
Expand All @@ -118,60 +111,38 @@ where
}

#[doc(hidden)]
fn transaction_depth(conn: &Self::OwnedConnection) -> Result<u32, DieselError> {
// Verifying pre-requisites: Ensure we aren't already running
// in a "broken" transaction state, nor in a nested transation.
match Conn::TransactionManager::transaction_manager_status_mut(&mut *Self::as_sync_conn(
&conn,
)) {
TransactionManagerStatus::Valid(status) => {
return Ok(status.transaction_depth().map(|d| d.into()).unwrap_or(0));
}
TransactionManagerStatus::InError => {
return Err(DieselError::BrokenTransactionManager);
}
}
}

#[doc(hidden)]
async fn start_transaction(conn: &Arc<Self::OwnedConnection>) -> Result<(), DieselError> {
if Self::transaction_depth(conn)? != 0 {
async fn start_transaction(self: &Arc<Self>) -> Result<(), DieselError> {
if self.transaction_depth().await? != 0 {
return Err(DieselError::AlreadyInTransaction);
}
Self::run_with_shared_connection(conn.clone(), |conn| {
Conn::TransactionManager::begin_transaction(conn)
})
.await?;
self.run_with_shared_connection(|conn| Conn::TransactionManager::begin_transaction(conn))
.await?;
Ok(())
}

#[doc(hidden)]
async fn add_retry_savepoint(conn: &Arc<Self::OwnedConnection>) -> Result<(), DieselError> {
match Self::transaction_depth(conn)? {
async fn add_retry_savepoint(self: &Arc<Self>) -> Result<(), DieselError> {
match self.transaction_depth().await? {
0 => return Err(DieselError::NotInTransaction),
1 => (),
_ => return Err(DieselError::AlreadyInTransaction),
};

Self::run_with_shared_connection(conn.clone(), |conn| {
Conn::TransactionManager::begin_transaction(conn)
})
.await?;
self.run_with_shared_connection(|conn| Conn::TransactionManager::begin_transaction(conn))
.await?;
Ok(())
}

#[doc(hidden)]
async fn commit_transaction(conn: &Arc<Self::OwnedConnection>) -> Result<(), DieselError> {
Self::run_with_shared_connection(conn.clone(), |conn| {
Conn::TransactionManager::commit_transaction(conn)
})
.await?;
async fn commit_transaction(self: &Arc<Self>) -> Result<(), DieselError> {
self.run_with_shared_connection(|conn| Conn::TransactionManager::commit_transaction(conn))
.await?;
Ok(())
}

#[doc(hidden)]
async fn rollback_transaction(conn: &Arc<Self::OwnedConnection>) -> Result<(), DieselError> {
Self::run_with_shared_connection(conn.clone(), |conn| {
async fn rollback_transaction(self: &Arc<Self>) -> Result<(), DieselError> {
self.run_with_shared_connection(|conn| {
Conn::TransactionManager::rollback_transaction(conn)
})
.await?;
Expand All @@ -181,6 +152,7 @@ where
/// Issues a function `f` as a transaction.
///
/// If it fails, asynchronously calls `retry` to decide if to retry.
#[cfg(feature = "cockroach")]
async fn transaction_async_with_retry<R, Func, Fut, RetryFut, RetryFunc, 'a>(
&'a self,
f: Func,
Expand All @@ -206,7 +178,7 @@ where
// that'll require more interaction with how sessions with the database
// are constructed.
Self::start_transaction(&conn).await?;
Self::run_with_shared_connection(conn.clone(), |conn| {
conn.run_with_shared_connection(|conn| {
conn.batch_execute("SET LOCAL force_savepoint_restart = true")
})
.await?;
Expand Down Expand Up @@ -277,7 +249,7 @@ where
//
// However, it modifies all callsites to instead issue
// known-to-be-synchronous operations from an asynchronous context.
Self::run_with_shared_connection(conn.clone(), |conn| {
conn.run_with_shared_connection(|conn| {
Conn::TransactionManager::begin_transaction(conn).map_err(E::from)
})
.await?;
Expand All @@ -297,17 +269,18 @@ where
let async_conn = SingleConnection(Self::as_async_conn(&conn).0.clone());
match f(async_conn).await {
Ok(value) => {
Self::run_with_shared_connection(conn.clone(), |conn| {
conn.run_with_shared_connection(|conn| {
Conn::TransactionManager::commit_transaction(conn).map_err(E::from)
})
.await?;
Ok(value)
}
Err(user_error) => {
match Self::run_with_shared_connection(conn.clone(), |conn| {
Conn::TransactionManager::rollback_transaction(conn).map_err(E::from)
})
.await
match conn
.run_with_shared_connection(|conn| {
Conn::TransactionManager::rollback_transaction(conn).map_err(E::from)
})
.await
{
Ok(()) => Err(user_error),
Err(err) => Err(err),
Expand Down
13 changes: 6 additions & 7 deletions src/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -50,20 +50,19 @@ where
Conn: 'static + R2D2Connection,
Connection<Conn>: crate::AsyncSimpleConnection<Conn>,
{
type OwnedConnection = Self;

fn get_owned_connection(&self) -> Self::OwnedConnection {
fn get_owned_connection(&self) -> Self {
Connection(self.0.clone())
}

// Accesses the connection synchronously, protected by a mutex.
//
// Avoid calling from asynchronous contexts.
fn as_sync_conn(owned: &Self::OwnedConnection) -> MutexGuard<'_, Conn> {
owned.inner()
fn as_sync_conn(&self) -> MutexGuard<'_, Conn> {
self.inner()
}

fn as_async_conn(owned: &Self::OwnedConnection) -> &Connection<Conn> {
owned
// TODO: Consider removing me.
fn as_async_conn(&self) -> &Connection<Conn> {
self
}
}
11 changes: 11 additions & 0 deletions tests/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -164,8 +164,10 @@ async fn test_transaction_automatic_retry_success_case() {
use user::dsl;

// Transaction that can retry but does not need to.
assert_eq!(conn.transaction_depth().await.unwrap(), 0);
conn.transaction_async_with_retry(
|conn| async move {
assert!(conn.transaction_depth().await.unwrap() > 0);
diesel::insert_into(dsl::user)
.values((dsl::id.eq(3), dsl::name.eq("Sally")))
.execute_async(&conn)
Expand All @@ -176,6 +178,7 @@ async fn test_transaction_automatic_retry_success_case() {
)
.await
.expect("Transaction failed");
assert_eq!(conn.transaction_depth().await.unwrap(), 0);

test_end(crdb).await;
}
Expand All @@ -197,6 +200,7 @@ async fn test_transaction_automatic_retry_explicit_rollback() {
//
// 1. Retries on the first call
// 2. Explicitly rolls back on the second call
assert_eq!(conn.transaction_depth().await.unwrap(), 0);
let err = conn
.transaction_async_with_retry(
|_conn| {
Expand Down Expand Up @@ -225,6 +229,7 @@ async fn test_transaction_automatic_retry_explicit_rollback() {
.expect_err("Transaction should have failed");

assert_eq!(err, diesel::result::Error::RollbackTransaction);
assert_eq!(conn.transaction_depth().await.unwrap(), 0);

// The transaction closure should have been attempted twice, but
// we should have only asked whether or not to retry once -- after
Expand Down Expand Up @@ -259,6 +264,7 @@ async fn test_transaction_automatic_retry_injected_errors() {
conn.batch_execute_async("SET inject_retry_errors_enabled = true")
.await
.expect("Failed to inject error");
assert_eq!(conn.transaction_depth().await.unwrap(), 0);
conn.transaction_async_with_retry(
|conn| {
let transaction_attempted_count = transaction_attempted_count.clone();
Expand All @@ -280,6 +286,7 @@ async fn test_transaction_automatic_retry_injected_errors() {
)
.await
.expect("Transaction should have succeeded");
assert_eq!(conn.transaction_depth().await.unwrap(), 0);

// The transaction closure should have been attempted twice, but
// we should have only asked whether or not to retry once -- after
Expand Down Expand Up @@ -307,6 +314,7 @@ async fn test_transaction_automatic_retry_does_not_retry_non_retryable_errors()
// Test a transaction that:
//
// Fails with a non-retryable error. It should exit immediately.
assert_eq!(conn.transaction_depth().await.unwrap(), 0);
assert_eq!(
conn.transaction_async_with_retry(
|_| async { Err::<(), _>(diesel::result::Error::NotFound) },
Expand All @@ -316,6 +324,7 @@ async fn test_transaction_automatic_retry_does_not_retry_non_retryable_errors()
.expect_err("Transaction should have failed"),
diesel::result::Error::NotFound,
);
assert_eq!(conn.transaction_depth().await.unwrap(), 0);

test_end(crdb).await;
}
Expand All @@ -329,6 +338,7 @@ async fn test_transaction_automatic_retry_nested_transactions_fail() {
let conn = pool.get().await.unwrap();

// This outer transaction should succeed immediately...
assert_eq!(conn.transaction_depth().await.unwrap(), 0);
conn.transaction_async_with_retry(
|conn| async move {
// ... but this inner transaction should fail! We do not support
Expand All @@ -351,6 +361,7 @@ async fn test_transaction_automatic_retry_nested_transactions_fail() {
)
.await
.expect("Transaction should have succeeded");
assert_eq!(conn.transaction_depth().await.unwrap(), 0);

test_end(crdb).await;
}
Expand Down

0 comments on commit 14bd98b

Please sign in to comment.