diff --git a/src/drivers/trino/mod.rs b/src/drivers/trino/mod.rs index 08c287f..0a4f37b 100644 --- a/src/drivers/trino/mod.rs +++ b/src/drivers/trino/mod.rs @@ -1,6 +1,6 @@ //! Trino and maybe Presto driver. -use std::{fmt, str::FromStr, sync::Arc}; +use std::{fmt, str::FromStr, sync::Arc, time::Duration}; use async_trait::async_trait; use codespan_reporting::{diagnostic::Diagnostic, files::Files}; @@ -8,6 +8,7 @@ use joinery_macros::sql_quote; use once_cell::sync::Lazy; use prusto::{error::Error as PrustoError, Client, ClientBuilder, Presto, QueryError, Row}; use regex::Regex; +use tokio::time::sleep; use tracing::debug; use crate::{ @@ -71,6 +72,27 @@ fn rewrite_approx_quantiles(call: &FunctionCall) -> TokenStream { } } +/// Quick and dirty retry loop to deal with transient Trino errors. We will +/// probably ultimately want exponential backoff, etc. +macro_rules! retry_trino_error { + ($e:expr) => {{ + let mut max_tries = 3; + let mut sleep_duration = Duration::from_millis(500); + loop { + match $e { + Ok(val) => break Ok(val), + Err(e) if should_retry(&e) && max_tries > 0 => { + sleep(sleep_duration).await; + max_tries -= 1; + sleep_duration *= 2; + continue; + } + Err(e) => break Err(e), + } + } + }}; +} + /// A locator for a Trino database. May or may not also work for Presto. #[derive(Debug)] pub struct TrinoLocator { @@ -170,10 +192,10 @@ impl Driver for TrinoDriver { #[tracing::instrument(skip_all)] async fn execute_native_sql_statement(&mut self, sql: &str) -> Result<()> { debug!(%sql, "Executing native SQL statement"); - self.client - .execute(sql.to_owned()) - .await - .map_err(|err| abbreviate_trino_error(sql, err))?; + retry_trino_error! { + self.client.execute(sql.to_owned()).await + } + .map_err(|err| abbreviate_trino_error(sql, err))?; Ok(()) } @@ -205,11 +227,11 @@ impl Driver for TrinoDriver { #[tracing::instrument(skip(self))] async fn drop_table_if_exists(&mut self, table_name: &str) -> Result<()> { let sql = format!("DROP TABLE IF EXISTS {}", AnsiIdent(table_name)); - self.client - .execute(sql.clone()) - .await - .map_err(|err| abbreviate_trino_error(&sql, err)) - .with_context(|| format!("Failed to drop table: {}", table_name))?; + retry_trino_error! { + self.client.execute(sql.clone()).await + } + .map_err(|err| abbreviate_trino_error(&sql, err)) + .with_context(|| format!("Failed to drop table: {}", table_name))?; Ok(()) } @@ -243,12 +265,12 @@ impl DriverImpl for TrinoDriver { TrinoString(&self.schema), TrinoString(table_name) ); - Ok(self - .client - .get_all::