From 625dcbe246b5e4e07d8a2a45b9f3ecbc06af869c Mon Sep 17 00:00:00 2001 From: Eric Kidd Date: Thu, 5 Dec 2024 19:07:17 -0500 Subject: [PATCH] Add ConnectorType to Target::Trino We will need to distinguish between Trino connectors like `memory`, `hive` and `iceberg` when emitting Trino SQL, because the different connectors support subsets of Trino's features and types. We modify the `Target` enum to include a `Target::Trino(ConnectorType)` variant instead of just `Target::Trino`. This will allow `emit` code to know what target it is generating code for. This also removes the final pretense of a "standalone" mode where we can run without access to a specific Trino server. This is fine; we're moving in that direction anyway. Note the replacement of `t == Target::Trino` with `matches!(t, Target::Trino(_))`, which is a handy Rust feature. --- Cargo.lock | 4 ++-- Cargo.toml | 7 +++++- Justfile | 10 +++++--- src/ast.rs | 28 +++++++++++++--------- src/cmd/sql_test.rs | 2 +- src/cmd/transpile.rs | 2 +- src/drivers/mod.rs | 5 ++-- src/drivers/trino/mod.rs | 50 ++++++++++++++++++++++++++++------------ 8 files changed, 72 insertions(+), 36 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 18e82ed..54bc26c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -459,9 +459,9 @@ dependencies = [ [[package]] name = "dbcrossbar_trino" -version = "0.2.2" +version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3f202d41d8f1b6b50fee260837ae3a06723510ea9df3cb8724a0665385d99960" +checksum = "2c824a64db21129755b9ae5eb528dbe48c5eae4cf1be90b885e4ed873ddd7ea4" dependencies = [ "base64", "chrono", diff --git a/Cargo.toml b/Cargo.toml index 054dd23..25e6c89 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -17,7 +17,12 @@ async-trait = "0.1.73" clap = { version = "4.4.6", features = ["derive", "wrap_help"] } codespan-reporting = "0.11.1" csv = "1.2.2" -dbcrossbar_trino = { version = "0.2.2", features = ["macros", "values", "client", "rustls-tls"] } +dbcrossbar_trino = { 
version = "0.2.3", features = [ + "macros", + "values", + "client", + "rustls-tls", +] } derive-visitor = "0.4.0" glob = "0.3.1" joinery_macros = { path = "joinery_macros" } diff --git a/Justfile b/Justfile index 1d78ed6..e930cd2 100644 --- a/Justfile +++ b/Justfile @@ -19,7 +19,7 @@ trino-image: trino-plugin # Run our Trino image. docker-run-trino: trino-image (cd trino && docker compose up) - docker compose exec trino install-udfs + (cd trino && docker compose exec trino install-udfs) # Stop and delete our Trino container. docker-rm-trino: @@ -33,10 +33,14 @@ check: #cargo deny check cargo test -# Check Trino. Assumes `docker-run-trino` has been run. +# Check Trino (memory). Assumes `docker-run-trino` has been run. check-trino: cargo run -- sql-test --database "trino://admin@localhost/memory/default" ./tests/sql/ +# Check Trino (Hive). Assumes `docker-run-trino` has been run. +check-trino-hive: + cargo run -- sql-test --database "trino://admin@localhost/hive/default" ./tests/sql/ + # Access a Trino shell. trino-shell: - docker compose exec trino-joinery trino \ No newline at end of file + (cd trino && docker compose exec trino-joinery trino) diff --git a/src/ast.rs b/src/ast.rs index 1db1a6d..cd61692 100644 --- a/src/ast.rs +++ b/src/ast.rs @@ -24,6 +24,7 @@ use std::{ mem::take, }; +use dbcrossbar_trino::ConnectorType as TrinoConnectorType; use derive_visitor::{Drive, DriveMut}; use joinery_macros::{Emit, EmitDefault, Spanned, ToTokens}; @@ -65,8 +66,13 @@ static KEYWORDS: phf::Set<&'static str> = phf::phf_set! { #[derive(Clone, Copy, Debug, Eq, PartialEq)] #[allow(dead_code)] pub enum Target { + /// We're transpiling BigQuery to BigQuery, which is mostly useful for + /// dumping the AST back out as SQL. BigQuery, - Trino, + /// We're transpiling BigQuery to Trino. Note that we need to specify what + /// kind of Trino connector we're using, because many connectors are missing + /// certain data types, and we need to fix the SQL we emit accordingly. 
+ Trino(TrinoConnectorType), } impl Target { @@ -74,7 +80,7 @@ impl Target { pub fn is_keyword(self, s: &str) -> bool { let keywords = match self { Target::BigQuery => &KEYWORDS, - Target::Trino => &TRINO_KEYWORDS, + Target::Trino(_) => &TRINO_KEYWORDS, }; keywords.contains(s.to_ascii_uppercase().as_str()) } @@ -84,7 +90,7 @@ impl fmt::Display for Target { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { Target::BigQuery => write!(f, "bigquery"), - Target::Trino => write!(f, "trino"), + Target::Trino(_) => write!(f, "trino"), } } } @@ -187,7 +193,7 @@ impl Emit for Ident { if t.is_keyword(&self.name) || !is_c_ident(&self.name) { match t { Target::BigQuery => write!(f, "{}", BigQueryName(&self.name))?, - Target::Trino => { + Target::Trino(_) => { write!(f, "{}", AnsiIdent(&self.name))?; } } @@ -232,7 +238,7 @@ impl Emit for LiteralValue { LiteralValue::Float64(fl) => write!(f, "{}", fl), LiteralValue::String(s) => match t { Target::BigQuery => write!(f, "{}", BigQueryString(s)), - Target::Trino => write!(f, "{}", TrinoString(s)), + Target::Trino(_) => write!(f, "{}", TrinoString(s)), }, } } @@ -914,7 +920,7 @@ pub enum CastType { impl Emit for CastType { fn emit(&self, t: Target, f: &mut TokenWriter<'_>) -> io::Result<()> { match self { - CastType::SafeCast { safe_cast_token } if t == Target::Trino => { + CastType::SafeCast { safe_cast_token } if matches!(t, Target::Trino(_)) => { safe_cast_token.ident.token.with_str("TRY_CAST").emit(t, f) } _ => self.emit_default(t, f), @@ -1080,7 +1086,7 @@ impl Emit for ArrayExpression { definition: ArrayDefinition::Query { select }, delim2, .. - } if t == Target::Trino => { + } if matches!(t, Target::Trino(_)) => { // We can't do this with a transform and sql_quote because it // outputs Trino-specific closure syntax. 
let ArraySelectExpression { @@ -1130,7 +1136,7 @@ impl Emit for ArrayExpression { last_token.with_ws_only().emit(t, f)?; } _ => match t { - Target::Trino => { + Target::Trino(_) => { let needs_cast = self.definition.has_zero_element_expressions(); if needs_cast { f.write_token_start("CAST(")?; @@ -1231,7 +1237,7 @@ pub struct StructExpression { impl Emit for StructExpression { fn emit(&self, t: Target, f: &mut TokenWriter<'_>) -> io::Result<()> { match t { - Target::Trino => { + Target::Trino(_) => { f.write_token_start("CAST(")?; self.struct_token.ident.token.with_str("ROW").emit(t, f)?; self.paren1.emit(t, f)?; @@ -1537,7 +1543,7 @@ pub enum DataType { impl Emit for DataType { fn emit(&self, t: Target, f: &mut TokenWriter<'_>) -> io::Result<()> { match t { - Target::Trino => match self { + Target::Trino(_) => match self { DataType::Bool(token) => token.ident.token.with_str("BOOLEAN").emit(t, f), DataType::Bytes(token) => token.ident.token.with_str("VARBINARY").emit(t, f), DataType::Date(token) => token.emit(t, f), @@ -1660,7 +1666,7 @@ impl Emit for FromItem { FromItem { table_expression: table_expression @ FromTableExpression::Unnest(..), alias: Some(Alias { as_token, ident }), - } if t == Target::Trino => { + } if matches!(t, Target::Trino(_)) => { table_expression.emit(t, f)?; as_token.emit(t, f)?; // UNNEST aliases aren't like other aliases, and Trino treats diff --git a/src/cmd/sql_test.rs b/src/cmd/sql_test.rs index 22082ee..1c9f9df 100644 --- a/src/cmd/sql_test.rs +++ b/src/cmd/sql_test.rs @@ -77,7 +77,7 @@ pub async fn cmd_sql_test(files: &mut KnownFiles, opt: &SqlTestOpt) -> Result<() if !opt.pending { let short_path = path.strip_prefix(&base_dir).unwrap_or(&path); if let Some(pending_test_info) = - PendingTestInfo::for_target(locator.target(), short_path, sql) + PendingTestInfo::for_target(locator.target().await?, short_path, sql) { progress('P'); pending.push(pending_test_info); diff --git a/src/cmd/transpile.rs b/src/cmd/transpile.rs index 
39a6146..24736d4 100644 --- a/src/cmd/transpile.rs +++ b/src/cmd/transpile.rs @@ -51,7 +51,7 @@ pub async fn cmd_transpile(files: &mut KnownFiles, opt: &TranspileOpt) -> Result for statement in rewritten_ast.extra.native_setup_sql { println!("{};", statement); } - let transpiled_sql = rewritten_ast.ast.emit_to_string(locator.target()); + let transpiled_sql = rewritten_ast.ast.emit_to_string(locator.target().await?); println!("{}", transpiled_sql); for statement in rewritten_ast.extra.native_teardown_sql { println!("{};", statement); diff --git a/src/drivers/mod.rs b/src/drivers/mod.rs index 42258ed..bccb006 100644 --- a/src/drivers/mod.rs +++ b/src/drivers/mod.rs @@ -22,8 +22,9 @@ pub mod trino; /// A URL-like locator for a database. #[async_trait] pub trait Locator: fmt::Display + fmt::Debug + Send + Sync + 'static { - /// Get the target for this locator. - fn target(&self) -> Target; + /// Get the target for this locator. For some databases, this might need to + /// query the database to get details about the target. + async fn target(&self) -> Result; /// Get the driver for this locator. async fn driver(&self) -> Result>; diff --git a/src/drivers/trino/mod.rs b/src/drivers/trino/mod.rs index 6ac9379..57c6953 100644 --- a/src/drivers/trino/mod.rs +++ b/src/drivers/trino/mod.rs @@ -5,7 +5,7 @@ use std::{fmt, str::FromStr, time::Duration}; use async_trait::async_trait; use dbcrossbar_trino::{ client::{Client, ClientBuilder, ClientError, QueryError}, - DataType, Ident as TrinoIdent, Value, + ConnectorType as TrinoConnectorType, DataType, Ident as TrinoIdent, Value, }; use joinery_macros::sql_quote; use once_cell::sync::Lazy; @@ -105,6 +105,29 @@ pub struct TrinoLocator { schema: String, } +impl TrinoLocator { + /// Create a client for this locator. 
+ fn client(&self) -> Client { + ClientBuilder::new( + self.user.clone(), + self.host.clone(), + self.port.unwrap_or(8080), + ) + .catalog_and_schema(self.catalog.clone(), self.schema.clone()) + .build() + } + + /// Get our Trino connector type. + async fn connector_type(&self) -> Result { + let client = self.client(); + let catalog = TrinoIdent::new(&self.catalog).map_err(Error::other)?; + client + .catalog_connector_type(&catalog) + .await + .map_err(Error::other) + } +} + impl fmt::Display for TrinoLocator { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "trino://{}@{}", self.user, self.host,)?; @@ -152,17 +175,19 @@ impl FromStr for TrinoLocator { #[async_trait] impl Locator for TrinoLocator { - fn target(&self) -> Target { - Target::Trino + async fn target(&self) -> Result { + let connector_type = self.connector_type().await?; + Ok(Target::Trino(connector_type)) } async fn driver(&self) -> Result> { - Ok(Box::new(TrinoDriver::from_locator(self)?)) + Ok(Box::new(TrinoDriver::from_locator(self).await?)) } } /// A Trino driver. pub struct TrinoDriver { + connector_type: TrinoConnectorType, catalog: String, schema: String, client: Client, @@ -170,19 +195,14 @@ pub struct TrinoDriver { impl TrinoDriver { /// Create a new Trino driver from a locator. 
- pub fn from_locator(locator: &TrinoLocator) -> Result { - let client = ClientBuilder::new( - locator.user.clone(), - locator.host.clone(), - locator.port.unwrap_or(8080), - ) - .catalog_and_schema(locator.catalog.clone(), locator.schema.clone()) - .build(); - + pub async fn from_locator(locator: &TrinoLocator) -> Result { + let connector_type = locator.connector_type().await?; + let client = locator.client(); Ok(Self { - client, + connector_type, catalog: locator.catalog.clone(), schema: locator.schema.clone(), + client, }) } } @@ -190,7 +210,7 @@ impl TrinoDriver { #[async_trait] impl Driver for TrinoDriver { fn target(&self) -> Target { - Target::Trino + Target::Trino(self.connector_type) } #[tracing::instrument(skip_all)]