Skip to content

Commit

Permalink
Add ConnectorType to Target::Trino
Browse files Browse the repository at this point in the history
We will need to distinguish between Trino connectors like `memory`,
`hive` and `iceberg` when emitting Trino SQL, because different
connectors support different subsets of Trino's features and types.

We modify the `Target` enum to include a `Target::Trino(ConnectorType)`
variant instead of just `Target::Trino`. This will allow `emit` code to
know which Trino connector it is generating code for.

This also removes the final pretense of a "standalone" mode where we
can run without access to a specific Trino server. This is fine; we're
moving in that direction anyway.

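Note the replacement of `t == Target::Trino` with `matches!(t,
Target::Trino(_))`, which is a handy Rust feature. It is needed here
because the variant now carries a connector type, so `t` can no longer
be compared against a bare `Target::Trino` with `==`.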
  • Loading branch information
emk committed Dec 6, 2024
1 parent 8438cce commit 625dcbe
Show file tree
Hide file tree
Showing 8 changed files with 72 additions and 36 deletions.
4 changes: 2 additions & 2 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

7 changes: 6 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,12 @@ async-trait = "0.1.73"
clap = { version = "4.4.6", features = ["derive", "wrap_help"] }
codespan-reporting = "0.11.1"
csv = "1.2.2"
dbcrossbar_trino = { version = "0.2.2", features = ["macros", "values", "client", "rustls-tls"] }
dbcrossbar_trino = { version = "0.2.3", features = [
"macros",
"values",
"client",
"rustls-tls",
] }
derive-visitor = "0.4.0"
glob = "0.3.1"
joinery_macros = { path = "joinery_macros" }
Expand Down
10 changes: 7 additions & 3 deletions Justfile
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ trino-image: trino-plugin
# Run our Trino image.
docker-run-trino: trino-image
(cd trino && docker compose up)
docker compose exec trino install-udfs
(cd trino && docker compose exec trino install-udfs)

# Stop and delete our Trino container.
docker-rm-trino:
Expand All @@ -33,10 +33,14 @@ check:
#cargo deny check
cargo test

# Check Trino. Assumes `docker-run-trino` has been run.
# Check Trino (memory). Assumes `docker-run-trino` has been run.
check-trino:
cargo run -- sql-test --database "trino://admin@localhost/memory/default" ./tests/sql/

# Check Trino (Hive). Assumes `docker-run-trino` has been run.
check-trino-hive:
cargo run -- sql-test --database "trino://admin@localhost/hive/default" ./tests/sql/

# Access a Trino shell.
trino-shell:
docker compose exec trino-joinery trino
(cd trino && docker compose exec trino-joinery trino)
28 changes: 17 additions & 11 deletions src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ use std::{
mem::take,
};

use dbcrossbar_trino::ConnectorType as TrinoConnectorType;
use derive_visitor::{Drive, DriveMut};
use joinery_macros::{Emit, EmitDefault, Spanned, ToTokens};

Expand Down Expand Up @@ -65,16 +66,21 @@ static KEYWORDS: phf::Set<&'static str> = phf::phf_set! {
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[allow(dead_code)]
pub enum Target {
/// We're transpiling BigQuery to BigQuery, which is mostly useful for
/// dumping the AST back out as SQL.
BigQuery,
Trino,
/// We're transpiling BigQuery to Trino. Note that we need to specify what
/// kind of Trino connector we're using, because many connectors are missing
/// certain data types, and we need to fix the SQL we emit accordingly.
Trino(TrinoConnectorType),
}

impl Target {
/// Is the specified string a keyword?
pub fn is_keyword(self, s: &str) -> bool {
let keywords = match self {
Target::BigQuery => &KEYWORDS,
Target::Trino => &TRINO_KEYWORDS,
Target::Trino(_) => &TRINO_KEYWORDS,
};
keywords.contains(s.to_ascii_uppercase().as_str())
}
Expand All @@ -84,7 +90,7 @@ impl fmt::Display for Target {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
Target::BigQuery => write!(f, "bigquery"),
Target::Trino => write!(f, "trino"),
Target::Trino(_) => write!(f, "trino"),
}
}
}
Expand Down Expand Up @@ -187,7 +193,7 @@ impl Emit for Ident {
if t.is_keyword(&self.name) || !is_c_ident(&self.name) {
match t {
Target::BigQuery => write!(f, "{}", BigQueryName(&self.name))?,
Target::Trino => {
Target::Trino(_) => {
write!(f, "{}", AnsiIdent(&self.name))?;
}
}
Expand Down Expand Up @@ -232,7 +238,7 @@ impl Emit for LiteralValue {
LiteralValue::Float64(fl) => write!(f, "{}", fl),
LiteralValue::String(s) => match t {
Target::BigQuery => write!(f, "{}", BigQueryString(s)),
Target::Trino => write!(f, "{}", TrinoString(s)),
Target::Trino(_) => write!(f, "{}", TrinoString(s)),
},
}
}
Expand Down Expand Up @@ -914,7 +920,7 @@ pub enum CastType {
impl Emit for CastType {
fn emit(&self, t: Target, f: &mut TokenWriter<'_>) -> io::Result<()> {
match self {
CastType::SafeCast { safe_cast_token } if t == Target::Trino => {
CastType::SafeCast { safe_cast_token } if matches!(t, Target::Trino(_)) => {
safe_cast_token.ident.token.with_str("TRY_CAST").emit(t, f)
}
_ => self.emit_default(t, f),
Expand Down Expand Up @@ -1080,7 +1086,7 @@ impl Emit for ArrayExpression {
definition: ArrayDefinition::Query { select },
delim2,
..
} if t == Target::Trino => {
} if matches!(t, Target::Trino(_)) => {
// We can't do this with a transform and sql_quote because it
// outputs Trino-specific closure syntax.
let ArraySelectExpression {
Expand Down Expand Up @@ -1130,7 +1136,7 @@ impl Emit for ArrayExpression {
last_token.with_ws_only().emit(t, f)?;
}
_ => match t {
Target::Trino => {
Target::Trino(_) => {
let needs_cast = self.definition.has_zero_element_expressions();
if needs_cast {
f.write_token_start("CAST(")?;
Expand Down Expand Up @@ -1231,7 +1237,7 @@ pub struct StructExpression {
impl Emit for StructExpression {
fn emit(&self, t: Target, f: &mut TokenWriter<'_>) -> io::Result<()> {
match t {
Target::Trino => {
Target::Trino(_) => {
f.write_token_start("CAST(")?;
self.struct_token.ident.token.with_str("ROW").emit(t, f)?;
self.paren1.emit(t, f)?;
Expand Down Expand Up @@ -1537,7 +1543,7 @@ pub enum DataType {
impl Emit for DataType {
fn emit(&self, t: Target, f: &mut TokenWriter<'_>) -> io::Result<()> {
match t {
Target::Trino => match self {
Target::Trino(_) => match self {
DataType::Bool(token) => token.ident.token.with_str("BOOLEAN").emit(t, f),
DataType::Bytes(token) => token.ident.token.with_str("VARBINARY").emit(t, f),
DataType::Date(token) => token.emit(t, f),
Expand Down Expand Up @@ -1660,7 +1666,7 @@ impl Emit for FromItem {
FromItem {
table_expression: table_expression @ FromTableExpression::Unnest(..),
alias: Some(Alias { as_token, ident }),
} if t == Target::Trino => {
} if matches!(t, Target::Trino(_)) => {
table_expression.emit(t, f)?;
as_token.emit(t, f)?;
// UNNEST aliases aren't like other aliases, and Trino treats
Expand Down
2 changes: 1 addition & 1 deletion src/cmd/sql_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ pub async fn cmd_sql_test(files: &mut KnownFiles, opt: &SqlTestOpt) -> Result<()
if !opt.pending {
let short_path = path.strip_prefix(&base_dir).unwrap_or(&path);
if let Some(pending_test_info) =
PendingTestInfo::for_target(locator.target(), short_path, sql)
PendingTestInfo::for_target(locator.target().await?, short_path, sql)
{
progress('P');
pending.push(pending_test_info);
Expand Down
2 changes: 1 addition & 1 deletion src/cmd/transpile.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ pub async fn cmd_transpile(files: &mut KnownFiles, opt: &TranspileOpt) -> Result
for statement in rewritten_ast.extra.native_setup_sql {
println!("{};", statement);
}
let transpiled_sql = rewritten_ast.ast.emit_to_string(locator.target());
let transpiled_sql = rewritten_ast.ast.emit_to_string(locator.target().await?);
println!("{}", transpiled_sql);
for statement in rewritten_ast.extra.native_teardown_sql {
println!("{};", statement);
Expand Down
5 changes: 3 additions & 2 deletions src/drivers/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,9 @@ pub mod trino;
/// A URL-like locator for a database.
#[async_trait]
pub trait Locator: fmt::Display + fmt::Debug + Send + Sync + 'static {
/// Get the target for this locator.
fn target(&self) -> Target;
/// Get the target for this locator. For some databases, this might need to
/// query the database to get details about the target.
async fn target(&self) -> Result<Target>;

/// Get the driver for this locator.
async fn driver(&self) -> Result<Box<dyn Driver>>;
Expand Down
50 changes: 35 additions & 15 deletions src/drivers/trino/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use std::{fmt, str::FromStr, time::Duration};
use async_trait::async_trait;
use dbcrossbar_trino::{
client::{Client, ClientBuilder, ClientError, QueryError},
DataType, Ident as TrinoIdent, Value,
ConnectorType as TrinoConnectorType, DataType, Ident as TrinoIdent, Value,
};
use joinery_macros::sql_quote;
use once_cell::sync::Lazy;
Expand Down Expand Up @@ -105,6 +105,29 @@ pub struct TrinoLocator {
schema: String,
}

impl TrinoLocator {
/// Create a client for this locator.
fn client(&self) -> Client {
ClientBuilder::new(
self.user.clone(),
self.host.clone(),
self.port.unwrap_or(8080),
)
.catalog_and_schema(self.catalog.clone(), self.schema.clone())
.build()
}

/// Get our Trino connector type.
async fn connector_type(&self) -> Result<TrinoConnectorType> {
let client = self.client();
let catalog = TrinoIdent::new(&self.catalog).map_err(Error::other)?;
client
.catalog_connector_type(&catalog)
.await
.map_err(Error::other)
}
}

impl fmt::Display for TrinoLocator {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "trino://{}@{}", self.user, self.host,)?;
Expand Down Expand Up @@ -152,45 +175,42 @@ impl FromStr for TrinoLocator {

#[async_trait]
impl Locator for TrinoLocator {
fn target(&self) -> Target {
Target::Trino
async fn target(&self) -> Result<Target> {
let connector_type = self.connector_type().await?;
Ok(Target::Trino(connector_type))
}

async fn driver(&self) -> Result<Box<dyn Driver>> {
Ok(Box::new(TrinoDriver::from_locator(self)?))
Ok(Box::new(TrinoDriver::from_locator(self).await?))
}
}

/// A Trino driver.
pub struct TrinoDriver {
connector_type: TrinoConnectorType,
catalog: String,
schema: String,
client: Client,
}

impl TrinoDriver {
/// Create a new Trino driver from a locator.
pub fn from_locator(locator: &TrinoLocator) -> Result<Self> {
let client = ClientBuilder::new(
locator.user.clone(),
locator.host.clone(),
locator.port.unwrap_or(8080),
)
.catalog_and_schema(locator.catalog.clone(), locator.schema.clone())
.build();

pub async fn from_locator(locator: &TrinoLocator) -> Result<Self> {
let connector_type = locator.connector_type().await?;
let client = locator.client();
Ok(Self {
client,
connector_type,
catalog: locator.catalog.clone(),
schema: locator.schema.clone(),
client,
})
}
}

#[async_trait]
impl Driver for TrinoDriver {
fn target(&self) -> Target {
Target::Trino
Target::Trino(self.connector_type)
}

#[tracing::instrument(skip_all)]
Expand Down

0 comments on commit 625dcbe

Please sign in to comment.