From 05b9e533bc44b2b0655165eef55354d485cd208c Mon Sep 17 00:00:00 2001 From: Eric Kidd Date: Thu, 29 Feb 2024 06:05:55 -0500 Subject: [PATCH] Retry some Trino errors --- src/drivers/trino/mod.rs | 78 +++++++++++++++++++++++++++------------- 1 file changed, 53 insertions(+), 25 deletions(-) diff --git a/src/drivers/trino/mod.rs b/src/drivers/trino/mod.rs index 08c287f..0a4f37b 100644 --- a/src/drivers/trino/mod.rs +++ b/src/drivers/trino/mod.rs @@ -1,6 +1,6 @@ //! Trino and maybe Presto driver. -use std::{fmt, str::FromStr, sync::Arc}; +use std::{fmt, str::FromStr, sync::Arc, time::Duration}; use async_trait::async_trait; use codespan_reporting::{diagnostic::Diagnostic, files::Files}; @@ -8,6 +8,7 @@ use joinery_macros::sql_quote; use once_cell::sync::Lazy; use prusto::{error::Error as PrustoError, Client, ClientBuilder, Presto, QueryError, Row}; use regex::Regex; +use tokio::time::sleep; use tracing::debug; use crate::{ @@ -71,6 +72,27 @@ fn rewrite_approx_quantiles(call: &FunctionCall) -> TokenStream { } } +/// Quick and dirty retry loop for transient Trino errors: up to 3 retries, with +/// the delay doubling from 500ms. We will probably want this to be configurable. +macro_rules! retry_trino_error { + ($e:expr) => {{ + let mut max_tries = 3; + let mut sleep_duration = Duration::from_millis(500); + loop { + match $e { + Ok(val) => break Ok(val), + Err(e) if should_retry(&e) && max_tries > 0 => { + sleep(sleep_duration).await; + max_tries -= 1; + sleep_duration *= 2; + continue; + } + Err(e) => break Err(e), + } + } + }}; +} + /// A locator for a Trino database. May or may not also work for Presto. #[derive(Debug)] pub struct TrinoLocator { @@ -170,10 +192,10 @@ impl Driver for TrinoDriver { #[tracing::instrument(skip_all)] async fn execute_native_sql_statement(&mut self, sql: &str) -> Result<()> { debug!(%sql, "Executing native SQL statement"); - self.client - .execute(sql.to_owned()) - .await - .map_err(|err| abbreviate_trino_error(sql, err))?; + retry_trino_error! 
{ + self.client.execute(sql.to_owned()).await + } + .map_err(|err| abbreviate_trino_error(sql, err))?; Ok(()) } @@ -205,11 +227,11 @@ impl Driver for TrinoDriver { #[tracing::instrument(skip(self))] async fn drop_table_if_exists(&mut self, table_name: &str) -> Result<()> { let sql = format!("DROP TABLE IF EXISTS {}", AnsiIdent(table_name)); - self.client - .execute(sql.clone()) - .await - .map_err(|err| abbreviate_trino_error(&sql, err)) - .with_context(|| format!("Failed to drop table: {}", table_name))?; + retry_trino_error! { + self.client.execute(sql.clone()).await + } + .map_err(|err| abbreviate_trino_error(&sql, err)) + .with_context(|| format!("Failed to drop table: {}", table_name))?; Ok(()) } @@ -243,12 +265,12 @@ impl DriverImpl for TrinoDriver { TrinoString(&self.schema), TrinoString(table_name) ); - Ok(self - .client - .get_all::(sql.clone()) - .await - .map_err(|err| abbreviate_trino_error(&sql, err)) - .with_context(|| format!("Failed to get columns for table: {}", table_name))? + let dataset = retry_trino_error! { + self.client.get_all::(sql.clone()).await + } + .map_err(|err| abbreviate_trino_error(&sql, err)) + .with_context(|| format!("Failed to get columns for table: {}", table_name))?; + Ok(dataset .into_vec() .into_iter() .map(|c| Column { @@ -275,15 +297,12 @@ impl DriverImpl for TrinoDriver { AnsiIdent(table_name), cols_sql ); - let rows = self - .client - .get_all::(sql.clone()) - .await - .map_err(|err| abbreviate_trino_error(&sql, err)) - .with_context(|| format!("Failed to query table: {}", table_name))? - .into_vec() - .into_iter() - .map(|r| Ok(r.into_json())); + let dataset = retry_trino_error! 
{ + self.client.get_all::(sql.clone()).await + } + .map_err(|err| abbreviate_trino_error(&sql, err)) + .with_context(|| format!("Failed to query table: {}", table_name))?; + let rows = dataset.into_vec().into_iter().map(|r| Ok(r.into_json())); Ok(Box::new(rows)) } } @@ -324,6 +343,15 @@ impl fmt::Display for TrinoString<'_> { } } +/// Should an error be retried? +/// +/// Note that the `prusto` crate has internal support for retrying connection +/// and network errors, so we don't need to worry about that. But we do need +/// to look out for `QueryError`s that might need to be retried. +fn should_retry(e: &PrustoError) -> bool { + matches!(e, PrustoError::QueryError(QueryError { error_type, .. }) if error_type == "NO_NODES_AVAILABLE") +} + /// These errors are pages long. fn abbreviate_trino_error(sql: &str, e: PrustoError) -> Error { if let PrustoError::QueryError(e) = &e {