Skip to content

Commit

Permalink
please work
Browse files Browse the repository at this point in the history
Created using spr 1.3.6-beta.1
  • Loading branch information
sunshowers committed Oct 15, 2024
1 parent d9147e2 commit 01be5c0
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 119 deletions.
146 changes: 58 additions & 88 deletions src/async_traits.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,13 +48,13 @@ where
Conn: 'static + DieselConnection + R2D2Connection,
Self: Send + Sized + 'static,
{
async fn ping_async(&mut self) -> Result<(), RunError<DieselError>> {
async fn ping_async(&mut self) -> Result<(), RunError> {
self.as_async_conn().run(|conn| conn.ping()).await
}

async fn is_broken_async(&mut self) -> bool {
self.as_async_conn()
.run(|conn| Ok::<bool, ()>(conn.is_broken()))
.run(|conn| Ok::<bool, _>(conn.is_broken()))
.await
.unwrap()
}
Expand All @@ -75,42 +75,36 @@ where
fn as_async_conn(&self) -> &Connection<Conn>;

/// Runs the function `f` in an context where blocking is safe.
async fn run<R, E, Func>(&self, f: Func) -> Result<R, RunError<E>>
async fn run<R, Func>(&self, f: Func) -> Result<R, RunError>
where
R: Send + 'static,
E: Send + 'static,
Func: FnOnce(&mut Conn) -> Result<R, E> + Send + 'static,
Func: FnOnce(&mut Conn) -> Result<R, DieselError> + Send + 'static,
{
let connection = self.get_owned_connection();
connection.run_with_connection(f).await
}

#[doc(hidden)]
async fn run_with_connection<R, E, Func>(self, f: Func) -> Result<R, RunError<E>>
async fn run_with_connection<R, Func>(self, f: Func) -> Result<R, RunError>
where
R: Send + 'static,
E: Send + 'static,
Func: FnOnce(&mut Conn) -> Result<R, E> + Send + 'static,
Func: FnOnce(&mut Conn) -> Result<R, DieselError> + Send + 'static,
{
handle_spawn_blocking_error(spawn_blocking(move || f(&mut *self.as_sync_conn())).await)
}

#[doc(hidden)]
async fn run_with_shared_connection<R, E, Func>(
self: &Arc<Self>,
f: Func,
) -> Result<R, RunError<E>>
async fn run_with_shared_connection<R, Func>(self: &Arc<Self>, f: Func) -> Result<R, RunError>
where
R: Send + 'static,
E: Send + 'static,
Func: FnOnce(&mut Conn) -> Result<R, E> + Send + 'static,
Func: FnOnce(&mut Conn) -> Result<R, DieselError> + Send + 'static,
{
let conn = self.clone();
handle_spawn_blocking_error(spawn_blocking(move || f(&mut *conn.as_sync_conn())).await)
}

#[doc(hidden)]
async fn transaction_depth(&self) -> Result<u32, RunError<DieselError>> {
async fn transaction_depth(&self) -> Result<u32, RunError> {
let conn = self.get_owned_connection();

Self::run_with_connection(conn, |conn| {
Expand All @@ -130,9 +124,9 @@ where
// This method is a wrapper around that call, with validation that
// we're actually issuing the BEGIN statement here.
#[doc(hidden)]
async fn start_transaction(self: &Arc<Self>) -> Result<(), RunError<DieselError>> {
async fn start_transaction(self: &Arc<Self>) -> Result<(), RunError> {
if self.transaction_depth().await? != 0 {
return Err(RunError::User(DieselError::AlreadyInTransaction));
return Err(RunError::DieselError(DieselError::AlreadyInTransaction));
}
self.run_with_shared_connection(|conn| Conn::TransactionManager::begin_transaction(conn))
.await?;
Expand All @@ -145,11 +139,11 @@ where
// This method is a wrapper around that call, with validation that
// we're actually issuing our first SAVEPOINT here.
#[doc(hidden)]
async fn add_retry_savepoint(self: &Arc<Self>) -> Result<(), RunError<DieselError>> {
async fn add_retry_savepoint(self: &Arc<Self>) -> Result<(), RunError> {
match self.transaction_depth().await? {
0 => return Err(RunError::User(DieselError::NotInTransaction)),
0 => return Err(RunError::DieselError(DieselError::NotInTransaction)),
1 => (),
_ => return Err(RunError::User(DieselError::AlreadyInTransaction)),
_ => return Err(RunError::DieselError(DieselError::AlreadyInTransaction)),
};

self.run_with_shared_connection(|conn| Conn::TransactionManager::begin_transaction(conn))
Expand All @@ -158,14 +152,14 @@ where
}

#[doc(hidden)]
async fn commit_transaction(self: &Arc<Self>) -> Result<(), RunError<DieselError>> {
async fn commit_transaction(self: &Arc<Self>) -> Result<(), RunError> {
self.run_with_shared_connection(|conn| Conn::TransactionManager::commit_transaction(conn))
.await?;
Ok(())
}

#[doc(hidden)]
async fn rollback_transaction(self: &Arc<Self>) -> Result<(), RunError<DieselError>> {
async fn rollback_transaction(self: &Arc<Self>) -> Result<(), RunError> {
self.run_with_shared_connection(|conn| {
Conn::TransactionManager::rollback_transaction(conn)
})
Expand All @@ -184,10 +178,10 @@ where
&'a self,
f: Func,
retry: RetryFunc,
) -> Result<R, RunError<DieselError>>
) -> Result<R, RunError>
where
R: Any + Send + 'static,
Fut: FutureExt<Output = Result<R, RunError<DieselError>>> + Send,
Fut: FutureExt<Output = Result<R, RunError>> + Send,
Func: (Fn(Connection<Conn>) -> Fut) + Send + Sync,
RetryFut: FutureExt<Output = bool> + Send,
RetryFunc: Fn() -> RetryFut + Send + Sync,
Expand Down Expand Up @@ -220,13 +214,11 @@ where
#[cfg(feature = "cockroach")]
async fn transaction_async_with_retry_inner(
&self,
f: &(dyn Fn(
Connection<Conn>,
) -> BoxFuture<'_, Result<Box<dyn Any + Send>, RunError<DieselError>>>
f: &(dyn Fn(Connection<Conn>) -> BoxFuture<'_, Result<Box<dyn Any + Send>, RunError>>
+ Send
+ Sync),
retry: &(dyn Fn() -> BoxFuture<'_, bool> + Send + Sync),
) -> Result<Box<dyn Any + Send>, RunError<DieselError>> {
) -> Result<Box<dyn Any + Send>, RunError> {
// Check out a connection once, and use it for the duration of the
// operation.
let conn = Arc::new(self.get_owned_connection());
Expand Down Expand Up @@ -264,7 +256,7 @@ where
// We're still in the transaction, but we at least
// tried to ROLLBACK to our savepoint.
let retried = match &err {
RunError::User(err) => retryable_error(err) && retry().await,
RunError::DieselError(err) => retryable_error(err) && retry().await,
RunError::RuntimeShutdown => false,
};
if retried {
Expand All @@ -282,7 +274,7 @@ where
Self::commit_transaction(&conn).await?;
return Ok(value);
}
Err(RunError::User(user_error)) => {
Err(RunError::DieselError(user_error)) => {
// The user-level operation failed: ROLLBACK to the retry
// savepoint.
if let Err(first_rollback_err) = Self::rollback_transaction(&conn).await {
Expand All @@ -302,7 +294,7 @@ where

// If we aren't retrying, ROLLBACK the BEGIN statement too.
return match Self::rollback_transaction(&conn).await {
Ok(()) => Err(RunError::User(user_error)),
Ok(()) => Err(RunError::DieselError(user_error)),
Err(err) => Err(err),
};
}
Expand All @@ -321,7 +313,7 @@ where
async fn transaction_async<R, E, Func, Fut, 'a>(&'a self, f: Func) -> Result<R, E>
where
R: Send + 'static,
E: From<RunError<DieselError>> + Send + 'static,
E: From<RunError> + Send + 'static,
Fut: Future<Output = Result<R, E>> + Send,
Func: FnOnce(Connection<Conn>) -> Fut + Send,
{
Expand Down Expand Up @@ -354,7 +346,7 @@ where
>,
) -> Result<Box<dyn Any + Send>, E>
where
E: From<RunError<DieselError>> + Send + 'static,
E: From<RunError> + Send + 'static,
{
// Check out a connection once, and use it for the duration of the
// operation.
Expand All @@ -365,15 +357,8 @@ where
//
// However, it modifies all callsites to instead issue
// known-to-be-synchronous operations from an asynchronous context.
conn.run_with_shared_connection(|conn| {
Conn::TransactionManager::begin_transaction(conn)
.map_err(|err| E::from(RunError::User(err)))
})
.await
.map_err(|err| match err {
RunError::User(err) => err,
RunError::RuntimeShutdown => RunError::RuntimeShutdown.into(),
})?;
conn.run_with_shared_connection(|conn| Conn::TransactionManager::begin_transaction(conn))
.await?;

// TODO: The ideal interface would pass the "async_conn" object to the
// underlying function "f" by reference.
Expand All @@ -390,42 +375,29 @@ where
let async_conn = Connection(Self::as_async_conn(&conn).0.clone());
match f(async_conn).await {
Ok(value) => {
match conn
.run_with_shared_connection(|conn| {
Conn::TransactionManager::commit_transaction(conn)
.map_err(|err| E::from(RunError::User(err)))
})
.await
{
Ok(()) => Ok(value),
// XXX: we should try to roll this back
Err(RunError::User(err)) => Err(err),
Err(RunError::RuntimeShutdown) => Err(RunError::RuntimeShutdown.into()),
}
conn.run_with_shared_connection(|conn| {
Conn::TransactionManager::commit_transaction(conn)
})
.await?;
Ok(value)
}
Err(user_error) => {
match conn
.run_with_shared_connection(|conn| {
Conn::TransactionManager::rollback_transaction(conn)
.map_err(|err| E::from(RunError::User(err)))
})
.await
{
Ok(()) => Err(user_error),
Err(RunError::User(err)) => Err(err),
Err(RunError::RuntimeShutdown) => Err(RunError::RuntimeShutdown.into()),
}
conn.run_with_shared_connection(|conn| {
Conn::TransactionManager::rollback_transaction(conn)
})
.await?;
Err(user_error)
}
}
}
}

fn handle_spawn_blocking_error<T, E>(
result: Result<Result<T, E>, JoinError>,
) -> Result<T, RunError<E>> {
fn handle_spawn_blocking_error<T>(
result: Result<Result<T, DieselError>, JoinError>,
) -> Result<T, RunError> {
match result {
Ok(Ok(v)) => Ok(v),
Ok(Err(err)) => Err(RunError::User(err)),
Ok(Err(err)) => Err(RunError::DieselError(err)),
Err(err) => {
if err.is_cancelled() {
// The only way a spawn_blocking task can be marked cancelled
Expand All @@ -438,7 +410,11 @@ fn handle_spawn_blocking_error<T, E>(
} else {
// Not possible to reach this as of Tokio 1.40, but maybe in
// future versions.
panic!("unexpected JoinError: {:?}", err);
panic!(
"unexpected JoinError, did a new version of Tokio add \
a source other than panics and cancellations? {:?}",
err
);
}
}
}
Expand All @@ -450,26 +426,26 @@ pub trait AsyncRunQueryDsl<Conn, AsyncConn>
where
Conn: 'static + DieselConnection,
{
async fn execute_async(self, asc: &AsyncConn) -> Result<usize, RunError<DieselError>>
async fn execute_async(self, asc: &AsyncConn) -> Result<usize, RunError>
where
Self: ExecuteDsl<Conn>;

async fn load_async<U>(self, asc: &AsyncConn) -> Result<Vec<U>, RunError<DieselError>>
async fn load_async<U>(self, asc: &AsyncConn) -> Result<Vec<U>, RunError>
where
U: Send + 'static,
Self: LoadQuery<'static, Conn, U>;

async fn get_result_async<U>(self, asc: &AsyncConn) -> Result<U, RunError<DieselError>>
async fn get_result_async<U>(self, asc: &AsyncConn) -> Result<U, RunError>
where
U: Send + 'static,
Self: LoadQuery<'static, Conn, U>;

async fn get_results_async<U>(self, asc: &AsyncConn) -> Result<Vec<U>, RunError<DieselError>>
async fn get_results_async<U>(self, asc: &AsyncConn) -> Result<Vec<U>, RunError>
where
U: Send + 'static,
Self: LoadQuery<'static, Conn, U>;

async fn first_async<U>(self, asc: &AsyncConn) -> Result<U, RunError<DieselError>>
async fn first_async<U>(self, asc: &AsyncConn) -> Result<U, RunError>
where
U: Send + 'static,
Self: LimitDsl,
Expand All @@ -483,38 +459,38 @@ where
Conn: 'static + DieselConnection,
AsyncConn: Send + Sync + AsyncConnection<Conn>,
{
async fn execute_async(self, asc: &AsyncConn) -> Result<usize, RunError<DieselError>>
async fn execute_async(self, asc: &AsyncConn) -> Result<usize, RunError>
where
Self: ExecuteDsl<Conn>,
{
asc.run(|conn| self.execute(conn)).await
}

async fn load_async<U>(self, asc: &AsyncConn) -> Result<Vec<U>, RunError<DieselError>>
async fn load_async<U>(self, asc: &AsyncConn) -> Result<Vec<U>, RunError>
where
U: Send + 'static,
Self: LoadQuery<'static, Conn, U>,
{
asc.run(|conn| self.load(conn)).await
}

async fn get_result_async<U>(self, asc: &AsyncConn) -> Result<U, RunError<DieselError>>
async fn get_result_async<U>(self, asc: &AsyncConn) -> Result<U, RunError>
where
U: Send + 'static,
Self: LoadQuery<'static, Conn, U>,
{
asc.run(|conn| self.get_result(conn)).await
}

async fn get_results_async<U>(self, asc: &AsyncConn) -> Result<Vec<U>, RunError<DieselError>>
async fn get_results_async<U>(self, asc: &AsyncConn) -> Result<Vec<U>, RunError>
where
U: Send + 'static,
Self: LoadQuery<'static, Conn, U>,
{
asc.run(|conn| self.get_results(conn)).await
}

async fn first_async<U>(self, asc: &AsyncConn) -> Result<U, RunError<DieselError>>
async fn first_async<U>(self, asc: &AsyncConn) -> Result<U, RunError>
where
U: Send + 'static,
Self: LimitDsl,
Expand All @@ -529,10 +505,7 @@ pub trait AsyncSaveChangesDsl<Conn, AsyncConn>
where
Conn: 'static + DieselConnection,
{
async fn save_changes_async<Output>(
self,
asc: &AsyncConn,
) -> Result<Output, RunError<DieselError>>
async fn save_changes_async<Output>(self, asc: &AsyncConn) -> Result<Output, RunError>
where
Self: Sized,
Conn: diesel::query_dsl::UpdateAndFetchResults<Self, Output>,
Expand All @@ -546,10 +519,7 @@ where
Conn: 'static + DieselConnection,
AsyncConn: Send + Sync + AsyncConnection<Conn>,
{
async fn save_changes_async<Output>(
self,
asc: &AsyncConn,
) -> Result<Output, RunError<DieselError>>
async fn save_changes_async<Output>(self, asc: &AsyncConn) -> Result<Output, RunError>
where
Conn: diesel::query_dsl::UpdateAndFetchResults<Self, Output>,
Output: Send + 'static,
Expand Down
Loading

0 comments on commit 01be5c0

Please sign in to comment.