From 61ef25bc2664fb6fc65dbac86cf9909b4867f4ae Mon Sep 17 00:00:00 2001 From: jkomyno Date: Fri, 10 Nov 2023 13:49:23 +0100 Subject: [PATCH 01/34] feat(quaint): allow wasm32-unknown-unknown compilation; currently fails on native --- Cargo.lock | 1 + Cargo.toml | 1 + quaint/Cargo.toml | 41 +- quaint/src/connector.rs | 50 ++- quaint/src/connector/mssql.rs | 363 +-------------- quaint/src/connector/mssql_wasm.rs | 383 ++++++++++++++++ quaint/src/connector/mysql.rs | 297 +------------ quaint/src/connector/mysql_wasm.rs | 318 +++++++++++++ quaint/src/connector/postgres.rs | 423 +----------------- quaint/src/connector/postgres_wasm.rs | 612 ++++++++++++++++++++++++++ quaint/src/connector/sqlite.rs | 104 +---- quaint/src/connector/sqlite_wasm.rs | 103 +++++ quaint/src/error.rs | 6 +- quaint/src/pooled/manager.rs | 30 +- quaint/src/single.rs | 10 +- 15 files changed, 1533 insertions(+), 1209 deletions(-) create mode 100644 quaint/src/connector/mssql_wasm.rs create mode 100644 quaint/src/connector/mysql_wasm.rs create mode 100644 quaint/src/connector/postgres_wasm.rs create mode 100644 quaint/src/connector/sqlite_wasm.rs diff --git a/Cargo.lock b/Cargo.lock index 35eff530999a..ff8323e356e5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3570,6 +3570,7 @@ dependencies = [ "connection-string", "either", "futures", + "getrandom 0.2.10", "hex", "indoc 0.3.6", "lru-cache", diff --git a/Cargo.toml b/Cargo.toml index 4a3cd1450caf..66f4399ff6db 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -68,6 +68,7 @@ features = [ "pooled", "postgresql", "sqlite", + "connectors", ] [profile.dev.package.backtrace] diff --git a/quaint/Cargo.toml b/quaint/Cargo.toml index b699518d0910..2da9ec0929c0 100644 --- a/quaint/Cargo.toml +++ b/quaint/Cargo.toml @@ -23,20 +23,28 @@ resolver = "2" features = ["docs", "all"] [features] -default = [] +default = ["mysql", "postgresql", "mssql", "sqlite"] docs = [] # Expose the underlying database drivers when a connector is enabled. This is a # way to access database-specific methods when you need extra control. expose-drivers = [] -all = ["mssql", "mysql", "pooled", "postgresql", "sqlite"] +connectors = [ + "postgresql-connector", + "mysql-connector", + "mssql-connector", + "sqlite-connector", +] + +all = ["connectors", "pooled"] vendored-openssl = [ "postgres-native-tls/vendored-openssl", "mysql_async/vendored-openssl", ] -postgresql = [ +postgresql-connector = [ + "postgresql", "native-tls", "tokio-postgres", "postgres-types", @@ -47,11 +55,24 @@ postgresql = [ "lru-cache", "byteorder", ] +postgresql = [] + +mssql-connector = [ + "mssql", + "tiberius", + "tokio-util", + "tokio/time", + "tokio/net", +] +mssql = [] + +mysql-connector = ["mysql", "mysql_async", "tokio/time", "lru-cache"] +mysql = ["chrono/std"] -mssql = ["tiberius", "tokio-util", "tokio/time", "tokio/net", "either"] -mysql = ["mysql_async", "tokio/time", "lru-cache"] pooled = ["mobc"] -sqlite = ["rusqlite", "tokio/sync"] +sqlite-connector = ["sqlite", "rusqlite", "tokio/sync"] +sqlite = [] + fmt-sql = ["sqlformat"] [dependencies] @@ -67,7 +88,7 @@ futures = "0.3" url = "2.1" hex = "0.4" -either = { version = "1.6", optional = true } +either = { version = "1.6" } base64 = { version = "0.12.3" } chrono = { version = "0.4", default-features = false, features = ["serde"] } lru-cache = { version = "0.1", optional = true } @@ -88,7 +109,11 @@ paste = "1.0" serde = { version = "1.0", features = ["derive"] } quaint-test-macros = { path = "quaint-test-macros" } quaint-test-setup = { path = "quaint-test-setup" } -tokio = { version = "1.0", features = ["rt-multi-thread", "macros", "time"] } +tokio = { version = "1.0", features = ["macros", "time"] } + +[target.'cfg(target_arch = "wasm32")'.dependencies.getrandom] +version = "0.2" +features = ["js"] [dependencies.byteorder] default-features = false diff --git a/quaint/src/connector.rs b/quaint/src/connector.rs index de8bc64d22bb..898aac8fcb46 100644 --- a/quaint/src/connector.rs +++ b/quaint/src/connector.rs @@ -10,36 +10,62 @@ //! querying interface. mod connection_info; + pub mod metrics; mod queryable; mod result_set; -#[cfg(any(feature = "mssql", feature = "postgresql", feature = "mysql"))] +#[cfg(any( + feature = "mssql-connector", + feature = "postgresql-connector", + feature = "mysql-connector" +))] mod timeout; mod transaction; mod type_identifier; -#[cfg(feature = "mssql")] +#[cfg(feature = "mssql-connector")] pub(crate) mod mssql; -#[cfg(feature = "mysql")] +#[cfg(feature = "mssql")] +pub(crate) mod mssql_wasm; +#[cfg(feature = "mysql-connector")] pub(crate) mod mysql; -#[cfg(feature = "postgresql")] +#[cfg(feature = "mysql")] +pub(crate) mod mysql_wasm; +#[cfg(feature = "postgresql-connector")] pub(crate) mod postgres; -#[cfg(feature = "sqlite")] +#[cfg(feature = "postgresql")] +pub(crate) mod postgres_wasm; +#[cfg(feature = "sqlite-connector")] pub(crate) mod sqlite; +#[cfg(feature = "sqlite")] +pub(crate) mod sqlite_wasm; -#[cfg(feature = "mysql")] +#[cfg(feature = "mysql-connector")] pub use self::mysql::*; -#[cfg(feature = "postgresql")] +#[cfg(feature = "mysql")] +pub use self::mysql_wasm::*; +#[cfg(feature = "postgresql-connector")] pub use self::postgres::*; +#[cfg(feature = "postgresql")] +pub use self::postgres_wasm::*; +#[cfg(feature = "mssql-connector")] +pub use mssql::*; +#[cfg(feature = "mssql")] +pub use mssql_wasm::*; +#[cfg(feature = "sqlite-connector")] +pub use sqlite::*; +#[cfg(feature = "sqlite")] +pub use sqlite_wasm::*; + pub use self::result_set::*; pub use connection_info::*; -#[cfg(feature = "mssql")] -pub use mssql::*; pub use queryable::*; -#[cfg(feature = "sqlite")] -pub use sqlite::*; pub use transaction::*; -#[cfg(any(feature = "sqlite", feature = "mysql", feature = "postgresql"))] +#[cfg(any( + feature = "mssql-connector", + feature = "postgresql-connector", + feature = "mysql-connector" +))] #[allow(unused_imports)] pub(crate) use type_identifier::*; diff --git a/quaint/src/connector/mssql.rs b/quaint/src/connector/mssql.rs index cef092edb9d7..16c31551768c 100644 --- a/quaint/src/connector/mssql.rs +++ b/quaint/src/connector/mssql.rs @@ -1,21 +1,19 @@ mod conversion; mod error; +pub(crate) use super::mssql_wasm::MssqlUrl; use super::{IsolationLevel, Transaction, TransactionOptions}; use crate::{ ast::{Query, Value}, connector::{metrics, queryable::*, DefaultTransaction, ResultSet}, - error::{Error, ErrorKind}, visitor::{self, Visitor}, }; use async_trait::async_trait; -use connection_string::JdbcString; use futures::lock::Mutex; use std::{ convert::TryFrom, fmt, future::Future, - str::FromStr, sync::atomic::{AtomicBool, Ordering}, time::Duration, }; @@ -27,69 +25,6 @@ use tokio_util::compat::{Compat, TokioAsyncWriteCompatExt}; #[cfg(feature = "expose-drivers")] pub use tiberius; -/// Wraps a connection url and exposes the parsing logic used by Quaint, -/// including default values. -#[derive(Debug, Clone)] -pub struct MssqlUrl { - connection_string: String, - query_params: MssqlQueryParams, -} - -/// TLS mode when connecting to SQL Server. -#[derive(Debug, Clone, Copy)] -pub enum EncryptMode { - /// All traffic is encrypted. - On, - /// Only the login credentials are encrypted. - Off, - /// Nothing is encrypted. - DangerPlainText, -} - -impl fmt::Display for EncryptMode { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Self::On => write!(f, "true"), - Self::Off => write!(f, "false"), - Self::DangerPlainText => write!(f, "DANGER_PLAINTEXT"), - } - } -} - -impl FromStr for EncryptMode { - type Err = Error; - - fn from_str(s: &str) -> crate::Result { - let mode = match s.parse::() { - Ok(true) => Self::On, - _ if s == "DANGER_PLAINTEXT" => Self::DangerPlainText, - _ => Self::Off, - }; - - Ok(mode) - } -} - -#[derive(Debug, Clone)] -pub(crate) struct MssqlQueryParams { - encrypt: EncryptMode, - port: Option, - host: Option, - user: Option, - password: Option, - database: String, - schema: String, - trust_server_certificate: bool, - trust_server_certificate_ca: Option, - connection_limit: Option, - socket_timeout: Option, - connect_timeout: Option, - pool_timeout: Option, - transaction_isolation_level: Option, - max_connection_lifetime: Option, - max_idle_connection_lifetime: Option, -} - static SQL_SERVER_DEFAULT_ISOLATION: IsolationLevel = IsolationLevel::ReadCommitted; #[async_trait] @@ -114,158 +49,6 @@ impl TransactionCapable for Mssql { } } -impl MssqlUrl { - /// Maximum number of connections the pool can have (if used together with - /// pooled Quaint). - pub fn connection_limit(&self) -> Option { - self.query_params.connection_limit() - } - - /// A duration how long one query can take. - pub fn socket_timeout(&self) -> Option { - self.query_params.socket_timeout() - } - - /// A duration how long we can try to connect to the database. - pub fn connect_timeout(&self) -> Option { - self.query_params.connect_timeout() - } - - /// A pool check_out timeout. - pub fn pool_timeout(&self) -> Option { - self.query_params.pool_timeout() - } - - /// The isolation level of a transaction. - fn transaction_isolation_level(&self) -> Option { - self.query_params.transaction_isolation_level - } - - /// Name of the database. - pub fn dbname(&self) -> &str { - self.query_params.database() - } - - /// The prefix which to use when querying database. - pub fn schema(&self) -> &str { - self.query_params.schema() - } - - /// Database hostname. - pub fn host(&self) -> &str { - self.query_params.host() - } - - /// The username to use when connecting to the database. - pub fn username(&self) -> Option<&str> { - self.query_params.user() - } - - /// The password to use when connecting to the database. - pub fn password(&self) -> Option<&str> { - self.query_params.password() - } - - /// The TLS mode to use when connecting to the database. - pub fn encrypt(&self) -> EncryptMode { - self.query_params.encrypt() - } - - /// If true, we allow invalid certificates (self-signed, or otherwise - /// dangerous) when connecting. Should be true only for development and - /// testing. - pub fn trust_server_certificate(&self) -> bool { - self.query_params.trust_server_certificate() - } - - /// Path to a custom server certificate file. - pub fn trust_server_certificate_ca(&self) -> Option<&str> { - self.query_params.trust_server_certificate_ca() - } - - /// Database port. - pub fn port(&self) -> u16 { - self.query_params.port() - } - - /// The JDBC connection string - pub fn connection_string(&self) -> &str { - &self.connection_string - } - - /// The maximum connection lifetime - pub fn max_connection_lifetime(&self) -> Option { - self.query_params.max_connection_lifetime() - } - - /// The maximum idle connection lifetime - pub fn max_idle_connection_lifetime(&self) -> Option { - self.query_params.max_idle_connection_lifetime() - } -} - -impl MssqlQueryParams { - fn port(&self) -> u16 { - self.port.unwrap_or(1433) - } - - fn host(&self) -> &str { - self.host.as_deref().unwrap_or("localhost") - } - - fn user(&self) -> Option<&str> { - self.user.as_deref() - } - - fn password(&self) -> Option<&str> { - self.password.as_deref() - } - - fn encrypt(&self) -> EncryptMode { - self.encrypt - } - - fn trust_server_certificate(&self) -> bool { - self.trust_server_certificate - } - - fn trust_server_certificate_ca(&self) -> Option<&str> { - self.trust_server_certificate_ca.as_deref() - } - - fn database(&self) -> &str { - &self.database - } - - fn schema(&self) -> &str { - &self.schema - } - - fn socket_timeout(&self) -> Option { - self.socket_timeout - } - - fn connect_timeout(&self) -> Option { - self.connect_timeout - } - - fn connection_limit(&self) -> Option { - self.connection_limit - } - - fn pool_timeout(&self) -> Option { - self.pool_timeout - } - - fn max_connection_lifetime(&self) -> Option { - self.max_connection_lifetime - } - - fn max_idle_connection_lifetime(&self) -> Option { - self.max_idle_connection_lifetime - } -} - /// A connector interface for the SQL Server database. #[derive(Debug)] pub struct Mssql { @@ -452,150 +235,6 @@ impl Queryable for Mssql { } } -impl MssqlUrl { - pub fn new(jdbc_connection_string: &str) -> crate::Result { - let query_params = Self::parse_query_params(jdbc_connection_string)?; - let connection_string = Self::with_jdbc_prefix(jdbc_connection_string); - - Ok(Self { - connection_string, - query_params, - }) - } - - fn with_jdbc_prefix(input: &str) -> String { - if input.starts_with("jdbc:sqlserver") { - input.into() - } else { - format!("jdbc:{input}") - } - } - - fn parse_query_params(input: &str) -> crate::Result { - let mut conn = JdbcString::from_str(&Self::with_jdbc_prefix(input))?; - - let host = conn.server_name().map(|server_name| match conn.instance_name() { - Some(instance_name) => format!(r#"{server_name}\{instance_name}"#), - None => server_name.to_string(), - }); - - let port = conn.port(); - let props = conn.properties_mut(); - let user = props.remove("user"); - let password = props.remove("password"); - let database = props.remove("database").unwrap_or_else(|| String::from("master")); - let schema = props.remove("schema").unwrap_or_else(|| String::from("dbo")); - - let connection_limit = props - .remove("connectionlimit") - .or_else(|| props.remove("connection_limit")) - .map(|param| param.parse()) - .transpose()?; - - let transaction_isolation_level = props - .remove("isolationlevel") - .or_else(|| props.remove("isolation_level")) - .map(|level| { - IsolationLevel::from_str(&level).map_err(|_| { - let kind = ErrorKind::database_url_is_invalid(format!("Invalid isolation level `{level}`")); - Error::builder(kind).build() - }) - }) - .transpose()?; - - let mut connect_timeout = props - .remove("logintimeout") - .or_else(|| props.remove("login_timeout")) - .or_else(|| props.remove("connecttimeout")) - .or_else(|| props.remove("connect_timeout")) - .or_else(|| props.remove("connectiontimeout")) - .or_else(|| props.remove("connection_timeout")) - .map(|param| param.parse().map(Duration::from_secs)) - .transpose()?; - - match connect_timeout { - None => connect_timeout = Some(Duration::from_secs(5)), - Some(dur) if dur.as_secs() == 0 => connect_timeout = None, - _ => (), - } - - let mut pool_timeout = props - .remove("pooltimeout") - .or_else(|| props.remove("pool_timeout")) - .map(|param| param.parse().map(Duration::from_secs)) - .transpose()?; - - match pool_timeout { - None => pool_timeout = Some(Duration::from_secs(10)), - Some(dur) if dur.as_secs() == 0 => pool_timeout = None, - _ => (), - } - - let socket_timeout = props - .remove("sockettimeout") - .or_else(|| props.remove("socket_timeout")) - .map(|param| param.parse().map(Duration::from_secs)) - .transpose()?; - - let encrypt = props - .remove("encrypt") - .map(|param| EncryptMode::from_str(¶m)) - .transpose()? - .unwrap_or(EncryptMode::On); - - let trust_server_certificate = props - .remove("trustservercertificate") - .or_else(|| props.remove("trust_server_certificate")) - .map(|param| param.parse()) - .transpose()? - .unwrap_or(false); - - let trust_server_certificate_ca: Option = props - .remove("trustservercertificateca") - .or_else(|| props.remove("trust_server_certificate_ca")); - - let mut max_connection_lifetime = props - .remove("max_connection_lifetime") - .map(|param| param.parse().map(Duration::from_secs)) - .transpose()?; - - match max_connection_lifetime { - Some(dur) if dur.as_secs() == 0 => max_connection_lifetime = None, - _ => (), - } - - let mut max_idle_connection_lifetime = props - .remove("max_idle_connection_lifetime") - .map(|param| param.parse().map(Duration::from_secs)) - .transpose()?; - - match max_idle_connection_lifetime { - None => max_idle_connection_lifetime = Some(Duration::from_secs(300)), - Some(dur) if dur.as_secs() == 0 => max_idle_connection_lifetime = None, - _ => (), - } - - Ok(MssqlQueryParams { - encrypt, - port, - host, - user, - password, - database, - schema, - trust_server_certificate, - trust_server_certificate_ca, - connection_limit, - socket_timeout, - connect_timeout, - pool_timeout, - transaction_isolation_level, - max_connection_lifetime, - max_idle_connection_lifetime, - }) - } -} - #[cfg(test)] mod tests { use crate::tests::test_api::mssql::CONN_STR; diff --git a/quaint/src/connector/mssql_wasm.rs b/quaint/src/connector/mssql_wasm.rs new file mode 100644 index 000000000000..d9f7dc27865b --- /dev/null +++ b/quaint/src/connector/mssql_wasm.rs @@ -0,0 +1,383 @@ +#![cfg_attr(target_arch = "wasm32", allow(dead_code))] + +use super::IsolationLevel; + +use crate::error::{Error, ErrorKind}; +use connection_string::JdbcString; +use std::{fmt, str::FromStr, time::Duration}; + +/// Wraps a connection url and exposes the parsing logic used by Quaint, +/// including default values. +#[derive(Debug, Clone)] +pub struct MssqlUrl { + pub(super) connection_string: String, + pub(super) query_params: MssqlQueryParams, +} + +/// TLS mode when connecting to SQL Server. +#[derive(Debug, Clone, Copy)] +pub enum EncryptMode { + /// All traffic is encrypted. + On, + /// Only the login credentials are encrypted. + Off, + /// Nothing is encrypted. + DangerPlainText, +} + +impl fmt::Display for EncryptMode { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::On => write!(f, "true"), + Self::Off => write!(f, "false"), + Self::DangerPlainText => write!(f, "DANGER_PLAINTEXT"), + } + } +} + +impl FromStr for EncryptMode { + type Err = Error; + + fn from_str(s: &str) -> crate::Result { + let mode = match s.parse::() { + Ok(true) => Self::On, + _ if s == "DANGER_PLAINTEXT" => Self::DangerPlainText, + _ => Self::Off, + }; + + Ok(mode) + } +} + +#[derive(Debug, Clone)] +pub(crate) struct MssqlQueryParams { + pub(super) encrypt: EncryptMode, + pub(super) port: Option, + pub(super) host: Option, + pub(super) user: Option, + pub(super) password: Option, + pub(super) database: String, + pub(super) schema: String, + pub(super) trust_server_certificate: bool, + pub(super) trust_server_certificate_ca: Option, + pub(super) connection_limit: Option, + pub(super) socket_timeout: Option, + pub(super) connect_timeout: Option, + pub(super) pool_timeout: Option, + pub(super) transaction_isolation_level: Option, + pub(super) max_connection_lifetime: Option, + pub(super) max_idle_connection_lifetime: Option, +} + +impl MssqlUrl { + /// Maximum number of connections the pool can have (if used together with + /// pooled Quaint). + pub fn connection_limit(&self) -> Option { + self.query_params.connection_limit() + } + + /// A duration how long one query can take. + pub fn socket_timeout(&self) -> Option { + self.query_params.socket_timeout() + } + + /// A duration how long we can try to connect to the database. + pub fn connect_timeout(&self) -> Option { + self.query_params.connect_timeout() + } + + /// A pool check_out timeout. + pub fn pool_timeout(&self) -> Option { + self.query_params.pool_timeout() + } + + /// The isolation level of a transaction. + pub(crate) fn transaction_isolation_level(&self) -> Option { + self.query_params.transaction_isolation_level + } + + /// Name of the database. + pub fn dbname(&self) -> &str { + self.query_params.database() + } + + /// The prefix which to use when querying database. + pub fn schema(&self) -> &str { + self.query_params.schema() + } + + /// Database hostname. + pub fn host(&self) -> &str { + self.query_params.host() + } + + /// The username to use when connecting to the database. + pub fn username(&self) -> Option<&str> { + self.query_params.user() + } + + /// The password to use when connecting to the database. + pub fn password(&self) -> Option<&str> { + self.query_params.password() + } + + /// The TLS mode to use when connecting to the database. + pub fn encrypt(&self) -> EncryptMode { + self.query_params.encrypt() + } + + /// If true, we allow invalid certificates (self-signed, or otherwise + /// dangerous) when connecting. Should be true only for development and + /// testing. + pub fn trust_server_certificate(&self) -> bool { + self.query_params.trust_server_certificate() + } + + /// Path to a custom server certificate file. + pub fn trust_server_certificate_ca(&self) -> Option<&str> { + self.query_params.trust_server_certificate_ca() + } + + /// Database port. + pub fn port(&self) -> u16 { + self.query_params.port() + } + + /// The JDBC connection string + pub fn connection_string(&self) -> &str { + &self.connection_string + } + + /// The maximum connection lifetime + pub fn max_connection_lifetime(&self) -> Option { + self.query_params.max_connection_lifetime() + } + + /// The maximum idle connection lifetime + pub fn max_idle_connection_lifetime(&self) -> Option { + self.query_params.max_idle_connection_lifetime() + } +} + +impl MssqlQueryParams { + fn port(&self) -> u16 { + self.port.unwrap_or(1433) + } + + fn host(&self) -> &str { + self.host.as_deref().unwrap_or("localhost") + } + + fn user(&self) -> Option<&str> { + self.user.as_deref() + } + + fn password(&self) -> Option<&str> { + self.password.as_deref() + } + + fn encrypt(&self) -> EncryptMode { + self.encrypt + } + + fn trust_server_certificate(&self) -> bool { + self.trust_server_certificate + } + + fn trust_server_certificate_ca(&self) -> Option<&str> { + self.trust_server_certificate_ca.as_deref() + } + + fn database(&self) -> &str { + &self.database + } + + fn schema(&self) -> &str { + &self.schema + } + + fn socket_timeout(&self) -> Option { + self.socket_timeout + } + + fn connect_timeout(&self) -> Option { + self.connect_timeout + } + + fn connection_limit(&self) -> Option { + self.connection_limit + } + + fn pool_timeout(&self) -> Option { + self.pool_timeout + } + + fn max_connection_lifetime(&self) -> Option { + self.max_connection_lifetime + } + + fn max_idle_connection_lifetime(&self) -> Option { + self.max_idle_connection_lifetime + } +} + +impl MssqlUrl { + pub fn new(jdbc_connection_string: &str) -> crate::Result { + let query_params = Self::parse_query_params(jdbc_connection_string)?; + let connection_string = Self::with_jdbc_prefix(jdbc_connection_string); + + Ok(Self { + connection_string, + query_params, + }) + } + + fn with_jdbc_prefix(input: &str) -> String { + if input.starts_with("jdbc:sqlserver") { + input.into() + } else { + format!("jdbc:{input}") + } + } + + fn parse_query_params(input: &str) -> crate::Result { + let mut conn = JdbcString::from_str(&Self::with_jdbc_prefix(input))?; + + let host = conn.server_name().map(|server_name| match conn.instance_name() { + Some(instance_name) => format!(r#"{server_name}\{instance_name}"#), + None => server_name.to_string(), + }); + + let port = conn.port(); + let props = conn.properties_mut(); + let user = props.remove("user"); + let password = props.remove("password"); + let database = props.remove("database").unwrap_or_else(|| String::from("master")); + let schema = props.remove("schema").unwrap_or_else(|| String::from("dbo")); + + let connection_limit = props + .remove("connectionlimit") + .or_else(|| props.remove("connection_limit")) + .map(|param| param.parse()) + .transpose()?; + + let transaction_isolation_level = props + .remove("isolationlevel") + .or_else(|| props.remove("isolation_level")) + .map(|level| { + IsolationLevel::from_str(&level).map_err(|_| { + let kind = ErrorKind::database_url_is_invalid(format!("Invalid isolation level `{level}`")); + Error::builder(kind).build() + }) + }) + .transpose()?; + + let mut connect_timeout = props + .remove("logintimeout") + .or_else(|| props.remove("login_timeout")) + .or_else(|| props.remove("connecttimeout")) + .or_else(|| props.remove("connect_timeout")) + .or_else(|| props.remove("connectiontimeout")) + .or_else(|| props.remove("connection_timeout")) + .map(|param| param.parse().map(Duration::from_secs)) + .transpose()?; + + match connect_timeout { + None => connect_timeout = Some(Duration::from_secs(5)), + Some(dur) if dur.as_secs() == 0 => connect_timeout = None, + _ => (), + } + + let mut pool_timeout = props + .remove("pooltimeout") + .or_else(|| props.remove("pool_timeout")) + .map(|param| param.parse().map(Duration::from_secs)) + .transpose()?; + + match pool_timeout { + None => pool_timeout = Some(Duration::from_secs(10)), + Some(dur) if dur.as_secs() == 0 => pool_timeout = None, + _ => (), + } + + let socket_timeout = props + .remove("sockettimeout") + .or_else(|| props.remove("socket_timeout")) + .map(|param| param.parse().map(Duration::from_secs)) + .transpose()?; + + let encrypt = props + .remove("encrypt") + .map(|param| EncryptMode::from_str(¶m)) + .transpose()? + .unwrap_or(EncryptMode::On); + + let trust_server_certificate = props + .remove("trustservercertificate") + .or_else(|| props.remove("trust_server_certificate")) + .map(|param| param.parse()) + .transpose()? + .unwrap_or(false); + + let trust_server_certificate_ca: Option = props + .remove("trustservercertificateca") + .or_else(|| props.remove("trust_server_certificate_ca")); + + let mut max_connection_lifetime = props + .remove("max_connection_lifetime") + .map(|param| param.parse().map(Duration::from_secs)) + .transpose()?; + + match max_connection_lifetime { + Some(dur) if dur.as_secs() == 0 => max_connection_lifetime = None, + _ => (), + } + + let mut max_idle_connection_lifetime = props + .remove("max_idle_connection_lifetime") + .map(|param| param.parse().map(Duration::from_secs)) + .transpose()?; + + match max_idle_connection_lifetime { + None => max_idle_connection_lifetime = Some(Duration::from_secs(300)), + Some(dur) if dur.as_secs() == 0 => max_idle_connection_lifetime = None, + _ => (), + } + + Ok(MssqlQueryParams { + encrypt, + port, + host, + user, + password, + database, + schema, + trust_server_certificate, + trust_server_certificate_ca, + connection_limit, + socket_timeout, + connect_timeout, + pool_timeout, + transaction_isolation_level, + max_connection_lifetime, + max_idle_connection_lifetime, + }) + } +} + +#[cfg(test)] +mod tests { + use crate::tests::test_api::mssql::CONN_STR; + use crate::{error::*, single::Quaint}; + + #[tokio::test] + async fn should_map_wrong_credentials_error() { + let url = CONN_STR.replace("user=SA", "user=WRONG"); + + let res = Quaint::new(url.as_str()).await; + assert!(res.is_err()); + + let err = res.unwrap_err(); + assert!(matches!(err.kind(), ErrorKind::AuthenticationFailed { user } if user == &Name::available("WRONG"))); + } +} diff --git a/quaint/src/connector/mysql.rs b/quaint/src/connector/mysql.rs index 4b6f27a583da..a9a829404e76 100644 --- a/quaint/src/connector/mysql.rs +++ b/quaint/src/connector/mysql.rs @@ -1,6 +1,7 @@ mod conversion; mod error; +pub(crate) use super::mysql_wasm::MysqlUrl; use crate::{ ast::{Query, Value}, connector::{metrics, queryable::*, ResultSet}, @@ -13,16 +14,12 @@ use mysql_async::{ self as my, prelude::{Query as _, Queryable as _}, }; -use percent_encoding::percent_decode; use std::{ - borrow::Cow, future::Future, - path::{Path, PathBuf}, sync::atomic::{AtomicBool, Ordering}, time::Duration, }; use tokio::sync::Mutex; -use url::{Host, Url}; pub use error::MysqlError; @@ -33,293 +30,11 @@ pub use mysql_async; use super::IsolationLevel; -/// A connector interface for the MySQL database. -#[derive(Debug)] -pub struct Mysql { - pub(crate) conn: Mutex, - pub(crate) url: MysqlUrl, - socket_timeout: Option, - is_healthy: AtomicBool, - statement_cache: Mutex>, -} - -/// Wraps a connection url and exposes the parsing logic used by quaint, including default values. -#[derive(Debug, Clone)] -pub struct MysqlUrl { - url: Url, - query_params: MysqlUrlQueryParams, -} - impl MysqlUrl { - /// Parse `Url` to `MysqlUrl`. Returns error for mistyped connection - /// parameters. - pub fn new(url: Url) -> Result { - let query_params = Self::parse_query_params(&url)?; - - Ok(Self { url, query_params }) - } - - /// The bare `Url` to the database. - pub fn url(&self) -> &Url { - &self.url - } - - /// The percent-decoded database username. - pub fn username(&self) -> Cow { - match percent_decode(self.url.username().as_bytes()).decode_utf8() { - Ok(username) => username, - Err(_) => { - tracing::warn!("Couldn't decode username to UTF-8, using the non-decoded version."); - - self.url.username().into() - } - } - } - - /// The percent-decoded database password. - pub fn password(&self) -> Option> { - match self - .url - .password() - .and_then(|pw| percent_decode(pw.as_bytes()).decode_utf8().ok()) - { - Some(password) => Some(password), - None => self.url.password().map(|s| s.into()), - } - } - - /// Name of the database connected. Defaults to `mysql`. - pub fn dbname(&self) -> &str { - match self.url.path_segments() { - Some(mut segments) => segments.next().unwrap_or("mysql"), - None => "mysql", - } - } - - /// The database host. If `socket` and `host` are not set, defaults to `localhost`. - pub fn host(&self) -> &str { - match (self.url.host(), self.url.host_str()) { - (Some(Host::Ipv6(_)), Some(host)) => { - // The `url` crate may return an IPv6 address in brackets, which must be stripped. - if host.starts_with('[') && host.ends_with(']') { - &host[1..host.len() - 1] - } else { - host - } - } - (_, Some(host)) => host, - _ => "localhost", - } - } - - /// If set, connected to the database through a Unix socket. - pub fn socket(&self) -> &Option { - &self.query_params.socket - } - - /// The database port, defaults to `3306`. - pub fn port(&self) -> u16 { - self.url.port().unwrap_or(3306) - } - - /// The connection timeout. - pub fn connect_timeout(&self) -> Option { - self.query_params.connect_timeout - } - - /// The pool check_out timeout - pub fn pool_timeout(&self) -> Option { - self.query_params.pool_timeout - } - - /// The socket timeout - pub fn socket_timeout(&self) -> Option { - self.query_params.socket_timeout - } - - /// Prefer socket connection - pub fn prefer_socket(&self) -> Option { - self.query_params.prefer_socket - } - - /// The maximum connection lifetime - pub fn max_connection_lifetime(&self) -> Option { - self.query_params.max_connection_lifetime - } - - /// The maximum idle connection lifetime - pub fn max_idle_connection_lifetime(&self) -> Option { - self.query_params.max_idle_connection_lifetime - } - - fn statement_cache_size(&self) -> usize { - self.query_params.statement_cache_size - } - pub(crate) fn cache(&self) -> LruCache { LruCache::new(self.query_params.statement_cache_size) } - fn parse_query_params(url: &Url) -> Result { - let mut ssl_opts = my::SslOpts::default(); - ssl_opts = ssl_opts.with_danger_accept_invalid_certs(true); - - let mut connection_limit = None; - let mut use_ssl = false; - let mut socket = None; - let mut socket_timeout = None; - let mut connect_timeout = Some(Duration::from_secs(5)); - let mut pool_timeout = Some(Duration::from_secs(10)); - let mut max_connection_lifetime = None; - let mut max_idle_connection_lifetime = Some(Duration::from_secs(300)); - let mut prefer_socket = None; - let mut statement_cache_size = 100; - let mut identity: Option<(Option, Option)> = None; - - for (k, v) in url.query_pairs() { - match k.as_ref() { - "connection_limit" => { - let as_int: usize = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - - connection_limit = Some(as_int); - } - "statement_cache_size" => { - statement_cache_size = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - } - "sslcert" => { - use_ssl = true; - ssl_opts = ssl_opts.with_root_cert_path(Some(Path::new(&*v).to_path_buf())); - } - "sslidentity" => { - use_ssl = true; - - identity = match identity { - Some((_, pw)) => Some((Some(Path::new(&*v).to_path_buf()), pw)), - None => Some((Some(Path::new(&*v).to_path_buf()), None)), - }; - } - "sslpassword" => { - use_ssl = true; - - identity = match identity { - Some((path, _)) => Some((path, Some(v.to_string()))), - None => Some((None, Some(v.to_string()))), - }; - } - "socket" => { - socket = Some(v.replace(['(', ')'], "")); - } - "socket_timeout" => { - let as_int = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - socket_timeout = Some(Duration::from_secs(as_int)); - } - "prefer_socket" => { - let as_bool = v - .parse::() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - prefer_socket = Some(as_bool) - } - "connect_timeout" => { - let as_int = v - .parse::() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - - connect_timeout = match as_int { - 0 => None, - _ => Some(Duration::from_secs(as_int)), - }; - } - "pool_timeout" => { - let as_int = v - .parse::() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - - pool_timeout = match as_int { - 0 => None, - _ => Some(Duration::from_secs(as_int)), - }; - } - "sslaccept" => { - use_ssl = true; - match v.as_ref() { - "strict" => { - ssl_opts = ssl_opts.with_danger_accept_invalid_certs(false); - } - "accept_invalid_certs" => {} - _ => { - tracing::debug!( - message = "Unsupported SSL accept mode, defaulting to `accept_invalid_certs`", - mode = &*v - ); - } - }; - } - "max_connection_lifetime" => { - let as_int = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - - if as_int == 0 { - max_connection_lifetime = None; - } else { - max_connection_lifetime = Some(Duration::from_secs(as_int)); - } - } - "max_idle_connection_lifetime" => { - let as_int = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - - if as_int == 0 { - max_idle_connection_lifetime = None; - } else { - max_idle_connection_lifetime = Some(Duration::from_secs(as_int)); - } - } - _ => { - tracing::trace!(message = "Discarding connection string param", param = &*k); - } - }; - } - - ssl_opts = match identity { - Some((Some(path), Some(pw))) => { - let identity = mysql_async::ClientIdentity::new(path).with_password(pw); - ssl_opts.with_client_identity(Some(identity)) - } - Some((Some(path), None)) => { - let identity = mysql_async::ClientIdentity::new(path); - ssl_opts.with_client_identity(Some(identity)) - } - _ => ssl_opts, - }; - - Ok(MysqlUrlQueryParams { - ssl_opts, - connection_limit, - use_ssl, - socket, - socket_timeout, - connect_timeout, - pool_timeout, - max_connection_lifetime, - max_idle_connection_lifetime, - prefer_socket, - statement_cache_size, - }) - } - - #[cfg(feature = "pooled")] - pub(crate) fn connection_limit(&self) -> Option { - self.query_params.connection_limit - } - pub(crate) fn to_opts_builder(&self) -> my::OptsBuilder { let mut config = my::OptsBuilder::default() .stmt_cache_size(Some(0)) @@ -365,6 +80,16 @@ pub(crate) struct MysqlUrlQueryParams { statement_cache_size: usize, } +/// A connector interface for the MySQL database. +#[derive(Debug)] +pub struct Mysql { + pub(crate) conn: Mutex, + pub(crate) url: MysqlUrl, + socket_timeout: Option, + is_healthy: AtomicBool, + statement_cache: Mutex>, +} + impl Mysql { /// Create a new MySQL connection using `OptsBuilder` from the `mysql` crate. pub async fn new(url: MysqlUrl) -> crate::Result { diff --git a/quaint/src/connector/mysql_wasm.rs b/quaint/src/connector/mysql_wasm.rs new file mode 100644 index 000000000000..24cd525fea33 --- /dev/null +++ b/quaint/src/connector/mysql_wasm.rs @@ -0,0 +1,318 @@ +#![cfg_attr(target_arch = "wasm32", allow(dead_code))] + +use crate::error::{Error, ErrorKind}; +use percent_encoding::percent_decode; +use std::{ + borrow::Cow, + path::{Path, PathBuf}, + time::Duration, +}; +use url::{Host, Url}; + +/// Wraps a connection url and exposes the parsing logic used by quaint, including default values. +#[derive(Debug, Clone)] +pub struct MysqlUrl { + url: Url, + pub(super) query_params: MysqlUrlQueryParams, +} + +impl MysqlUrl { + /// Parse `Url` to `MysqlUrl`. Returns error for mistyped connection + /// parameters. + pub fn new(url: Url) -> Result { + let query_params = Self::parse_query_params(&url)?; + + Ok(Self { url, query_params }) + } + + /// The bare `Url` to the database. + pub fn url(&self) -> &Url { + &self.url + } + + /// The percent-decoded database username. + pub fn username(&self) -> Cow { + match percent_decode(self.url.username().as_bytes()).decode_utf8() { + Ok(username) => username, + Err(_) => { + tracing::warn!("Couldn't decode username to UTF-8, using the non-decoded version."); + + self.url.username().into() + } + } + } + + /// The percent-decoded database password. + pub fn password(&self) -> Option> { + match self + .url + .password() + .and_then(|pw| percent_decode(pw.as_bytes()).decode_utf8().ok()) + { + Some(password) => Some(password), + None => self.url.password().map(|s| s.into()), + } + } + + /// Name of the database connected. Defaults to `mysql`. + pub fn dbname(&self) -> &str { + match self.url.path_segments() { + Some(mut segments) => segments.next().unwrap_or("mysql"), + None => "mysql", + } + } + + /// The database host. If `socket` and `host` are not set, defaults to `localhost`. + pub fn host(&self) -> &str { + match (self.url.host(), self.url.host_str()) { + (Some(Host::Ipv6(_)), Some(host)) => { + // The `url` crate may return an IPv6 address in brackets, which must be stripped. + if host.starts_with('[') && host.ends_with(']') { + &host[1..host.len() - 1] + } else { + host + } + } + (_, Some(host)) => host, + _ => "localhost", + } + } + + /// If set, connected to the database through a Unix socket. + pub fn socket(&self) -> &Option { + &self.query_params.socket + } + + /// The database port, defaults to `3306`. + pub fn port(&self) -> u16 { + self.url.port().unwrap_or(3306) + } + + /// The connection timeout. + pub fn connect_timeout(&self) -> Option { + self.query_params.connect_timeout + } + + /// The pool check_out timeout + pub fn pool_timeout(&self) -> Option { + self.query_params.pool_timeout + } + + /// The socket timeout + pub fn socket_timeout(&self) -> Option { + self.query_params.socket_timeout + } + + /// Prefer socket connection + pub fn prefer_socket(&self) -> Option { + self.query_params.prefer_socket + } + + /// The maximum connection lifetime + pub fn max_connection_lifetime(&self) -> Option { + self.query_params.max_connection_lifetime + } + + /// The maximum idle connection lifetime + pub fn max_idle_connection_lifetime(&self) -> Option { + self.query_params.max_idle_connection_lifetime + } + + pub(super) fn statement_cache_size(&self) -> usize { + self.query_params.statement_cache_size + } + + fn parse_query_params(url: &Url) -> Result { + #[cfg(feature = "mysql-connector")] + let mut ssl_opts = { + let mut ssl_opts = mysql_async::SslOpts::default(); + ssl_opts = ssl_opts.with_danger_accept_invalid_certs(true); + ssl_opts + }; + + let mut connection_limit = None; + let mut use_ssl = false; + let mut socket = None; + let mut socket_timeout = None; + let mut connect_timeout = Some(Duration::from_secs(5)); + let mut pool_timeout = Some(Duration::from_secs(10)); + let mut max_connection_lifetime = None; + let mut max_idle_connection_lifetime = Some(Duration::from_secs(300)); + let mut prefer_socket = None; + let mut statement_cache_size = 100; + let mut identity: Option<(Option, Option)> = None; + + for (k, v) in url.query_pairs() { + match k.as_ref() { + "connection_limit" => { + let as_int: usize = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + + connection_limit = Some(as_int); + } + "statement_cache_size" => { + statement_cache_size = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + } + "sslcert" => { + use_ssl = true; + + #[cfg(feature = "mysql-connector")] + { + ssl_opts = ssl_opts.with_root_cert_path(Some(Path::new(&*v).to_path_buf())); + } + } + "sslidentity" => { + use_ssl = true; + + identity = match identity { + Some((_, pw)) => Some((Some(Path::new(&*v).to_path_buf()), pw)), + None => Some((Some(Path::new(&*v).to_path_buf()), None)), + }; + } + "sslpassword" => { + use_ssl = true; + + identity = match identity { + Some((path, _)) => Some((path, Some(v.to_string()))), + None => Some((None, Some(v.to_string()))), + }; + } + "socket" => { + socket = Some(v.replace(['(', ')'], "")); + } + "socket_timeout" => { + let as_int = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + socket_timeout = Some(Duration::from_secs(as_int)); + } + "prefer_socket" => { + let as_bool = v + .parse::() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + prefer_socket = Some(as_bool) + } + "connect_timeout" => { + let as_int = v + .parse::() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + + connect_timeout = match as_int { + 0 => None, + _ => Some(Duration::from_secs(as_int)), + }; + } + "pool_timeout" => { + let as_int = v + .parse::() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + + pool_timeout = match as_int { + 0 => None, + _ => Some(Duration::from_secs(as_int)), + }; + } + "sslaccept" => { + use_ssl = true; + match v.as_ref() { + "strict" => { + #[cfg(feature = "mysql-connector")] + { + ssl_opts = ssl_opts.with_danger_accept_invalid_certs(false); + } + } + "accept_invalid_certs" => {} + _ => { + tracing::debug!( + message = "Unsupported SSL accept mode, defaulting to `accept_invalid_certs`", + mode = &*v + ); + } + }; + } + "max_connection_lifetime" => { + let as_int = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + + if as_int == 0 { + max_connection_lifetime = None; + } else { + max_connection_lifetime = Some(Duration::from_secs(as_int)); + } + } + "max_idle_connection_lifetime" => { + let as_int = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + + if as_int == 0 { + max_idle_connection_lifetime = None; + } else { + max_idle_connection_lifetime = Some(Duration::from_secs(as_int)); + } + } + _ => { + tracing::trace!(message = "Discarding connection string param", param = &*k); + } + }; + } + + // Wrapping this in a block, as attributes on expressions are still experimental + // See: https://github.com/rust-lang/rust/issues/15701 + #[cfg(feature = "mysql-connector")] + { + ssl_opts = match identity { + Some((Some(path), Some(pw))) => { + let identity = mysql_async::ClientIdentity::new(path).with_password(pw); + ssl_opts.with_client_identity(Some(identity)) + } + Some((Some(path), None)) => { + let identity = mysql_async::ClientIdentity::new(path); + ssl_opts.with_client_identity(Some(identity)) + } + _ => ssl_opts, + }; + } + + Ok(MysqlUrlQueryParams { + #[cfg(feature = "mysql-connector")] + ssl_opts, + connection_limit, + use_ssl, + socket, + socket_timeout, + connect_timeout, + pool_timeout, + max_connection_lifetime, + max_idle_connection_lifetime, + prefer_socket, + statement_cache_size, + }) + } + + #[cfg(feature = "pooled")] + pub(crate) fn connection_limit(&self) -> Option { + self.query_params.connection_limit + } +} + +#[derive(Debug, Clone)] +pub(crate) struct MysqlUrlQueryParams { + pub(crate) connection_limit: Option, + pub(crate) use_ssl: bool, + pub(crate) socket: Option, + pub(crate) socket_timeout: Option, + pub(crate) connect_timeout: Option, + pub(crate) pool_timeout: Option, + pub(crate) max_connection_lifetime: Option, + pub(crate) max_idle_connection_lifetime: Option, + pub(crate) prefer_socket: Option, + pub(crate) statement_cache_size: usize, + + #[cfg(feature = "mysql-connector")] + pub(crate) ssl_opts: mysql_async::SslOpts, +} diff --git a/quaint/src/connector/postgres.rs b/quaint/src/connector/postgres.rs index 766be38b27e4..7a83e61218f6 100644 --- a/quaint/src/connector/postgres.rs +++ b/quaint/src/connector/postgres.rs @@ -1,6 +1,9 @@ mod conversion; mod error; +use super::postgres_wasm::{Hidden, SslAcceptMode, SslParams}; +pub(crate) use super::postgres_wasm::{PostgresFlavour, PostgresUrl}; + use crate::{ ast::{Query, Value}, connector::{metrics, queryable::*, ResultSet}, @@ -11,26 +14,19 @@ use async_trait::async_trait; use futures::{future::FutureExt, lock::Mutex}; use lru_cache::LruCache; use native_tls::{Certificate, Identity, TlsConnector}; -use percent_encoding::percent_decode; use postgres_native_tls::MakeTlsConnector; use std::{ - borrow::{Borrow, Cow}, + borrow::Borrow, fmt::{Debug, Display}, fs, future::Future, sync::atomic::{AtomicBool, Ordering}, time::Duration, }; -use tokio_postgres::{ - config::{ChannelBinding, SslMode}, - Client, Config, Statement, -}; -use url::{Host, Url}; +use tokio_postgres::{config::ChannelBinding, Client, Config, Statement}; pub use error::PostgresError; -pub(crate) const DEFAULT_SCHEMA: &str = "public"; - /// The underlying postgres driver. Only available with the `expose-drivers` /// Cargo feature. #[cfg(feature = "expose-drivers")] @@ -38,15 +34,6 @@ pub use tokio_postgres; use super::{IsolationLevel, Transaction}; -#[derive(Clone)] -struct Hidden(T); - -impl Debug for Hidden { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_str("") - } -} - struct PostgresClient(Client); impl Debug for PostgresClient { @@ -65,20 +52,6 @@ pub struct PostgreSql { is_healthy: AtomicBool, } -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum SslAcceptMode { - Strict, - AcceptInvalidCerts, -} - -#[derive(Debug, Clone)] -pub struct SslParams { - certificate_file: Option, - identity_file: Option, - identity_password: Hidden>, - ssl_accept_mode: SslAcceptMode, -} - #[derive(Debug)] struct SslAuth { certificate: Hidden>, @@ -146,166 +119,7 @@ impl SslParams { } } -#[derive(Debug, Clone, Copy)] -pub enum PostgresFlavour { - Postgres, - Cockroach, - Unknown, -} - -impl PostgresFlavour { - /// Returns `true` if the postgres flavour is [`Postgres`]. - /// - /// [`Postgres`]: PostgresFlavour::Postgres - fn is_postgres(&self) -> bool { - matches!(self, Self::Postgres) - } - - /// Returns `true` if the postgres flavour is [`Cockroach`]. - /// - /// [`Cockroach`]: PostgresFlavour::Cockroach - fn is_cockroach(&self) -> bool { - matches!(self, Self::Cockroach) - } - - /// Returns `true` if the postgres flavour is [`Unknown`]. - /// - /// [`Unknown`]: PostgresFlavour::Unknown - fn is_unknown(&self) -> bool { - matches!(self, Self::Unknown) - } -} - -/// Wraps a connection url and exposes the parsing logic used by Quaint, -/// including default values. -#[derive(Debug, Clone)] -pub struct PostgresUrl { - url: Url, - query_params: PostgresUrlQueryParams, - flavour: PostgresFlavour, -} - impl PostgresUrl { - /// Parse `Url` to `PostgresUrl`. Returns error for mistyped connection - /// parameters. - pub fn new(url: Url) -> Result { - let query_params = Self::parse_query_params(&url)?; - - Ok(Self { - url, - query_params, - flavour: PostgresFlavour::Unknown, - }) - } - - /// The bare `Url` to the database. - pub fn url(&self) -> &Url { - &self.url - } - - /// The percent-decoded database username. - pub fn username(&self) -> Cow { - match percent_decode(self.url.username().as_bytes()).decode_utf8() { - Ok(username) => username, - Err(_) => { - tracing::warn!("Couldn't decode username to UTF-8, using the non-decoded version."); - - self.url.username().into() - } - } - } - - /// The database host. Taken first from the `host` query parameter, then - /// from the `host` part of the URL. For socket connections, the query - /// parameter must be used. - /// - /// If none of them are set, defaults to `localhost`. - pub fn host(&self) -> &str { - match (self.query_params.host.as_ref(), self.url.host_str(), self.url.host()) { - (Some(host), _, _) => host.as_str(), - (None, Some(""), _) => "localhost", - (None, None, _) => "localhost", - (None, Some(host), Some(Host::Ipv6(_))) => { - // The `url` crate may return an IPv6 address in brackets, which must be stripped. - if host.starts_with('[') && host.ends_with(']') { - &host[1..host.len() - 1] - } else { - host - } - } - (None, Some(host), _) => host, - } - } - - /// Name of the database connected. Defaults to `postgres`. - pub fn dbname(&self) -> &str { - match self.url.path_segments() { - Some(mut segments) => segments.next().unwrap_or("postgres"), - None => "postgres", - } - } - - /// The percent-decoded database password. - pub fn password(&self) -> Cow { - match self - .url - .password() - .and_then(|pw| percent_decode(pw.as_bytes()).decode_utf8().ok()) - { - Some(password) => password, - None => self.url.password().unwrap_or("").into(), - } - } - - /// The database port, defaults to `5432`. - pub fn port(&self) -> u16 { - self.url.port().unwrap_or(5432) - } - - /// The database schema, defaults to `public`. - pub fn schema(&self) -> &str { - self.query_params.schema.as_deref().unwrap_or(DEFAULT_SCHEMA) - } - - /// Whether the pgbouncer mode is enabled. - pub fn pg_bouncer(&self) -> bool { - self.query_params.pg_bouncer - } - - /// The connection timeout. - pub fn connect_timeout(&self) -> Option { - self.query_params.connect_timeout - } - - /// Pool check_out timeout - pub fn pool_timeout(&self) -> Option { - self.query_params.pool_timeout - } - - /// The socket timeout - pub fn socket_timeout(&self) -> Option { - self.query_params.socket_timeout - } - - /// The maximum connection lifetime - pub fn max_connection_lifetime(&self) -> Option { - self.query_params.max_connection_lifetime - } - - /// The maximum idle connection lifetime - pub fn max_idle_connection_lifetime(&self) -> Option { - self.query_params.max_idle_connection_lifetime - } - - /// The custom application name - pub fn application_name(&self) -> Option<&str> { - self.query_params.application_name.as_deref() - } - - pub fn channel_binding(&self) -> ChannelBinding { - self.query_params.channel_binding - } - pub(crate) fn cache(&self) -> LruCache { if self.query_params.pg_bouncer { LruCache::new(0) @@ -314,208 +128,8 @@ impl PostgresUrl { } } - pub(crate) fn options(&self) -> Option<&str> { - self.query_params.options.as_deref() - } - - /// Sets whether the URL points to a Postgres, Cockroach or Unknown database. - /// This is used to avoid a network roundtrip at connection to set the search path. - /// - /// The different behaviours are: - /// - Postgres: Always avoid a network roundtrip by setting the search path through client connection parameters. - /// - Cockroach: Avoid a network roundtrip if the schema name is deemed "safe" (i.e. no escape quoting required). Otherwise, set the search path through a database query. - /// - Unknown: Always add a network roundtrip by setting the search path through a database query. - pub fn set_flavour(&mut self, flavour: PostgresFlavour) { - self.flavour = flavour; - } - - fn parse_query_params(url: &Url) -> Result { - let mut connection_limit = None; - let mut schema = None; - let mut certificate_file = None; - let mut identity_file = None; - let mut identity_password = None; - let mut ssl_accept_mode = SslAcceptMode::AcceptInvalidCerts; - let mut ssl_mode = SslMode::Prefer; - let mut host = None; - let mut application_name = None; - let mut channel_binding = ChannelBinding::Prefer; - let mut socket_timeout = None; - let mut connect_timeout = Some(Duration::from_secs(5)); - let mut pool_timeout = Some(Duration::from_secs(10)); - let mut pg_bouncer = false; - let mut statement_cache_size = 100; - let mut max_connection_lifetime = None; - let mut max_idle_connection_lifetime = Some(Duration::from_secs(300)); - let mut options = None; - - for (k, v) in url.query_pairs() { - match k.as_ref() { - "pgbouncer" => { - pg_bouncer = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - } - "sslmode" => { - match v.as_ref() { - "disable" => ssl_mode = SslMode::Disable, - "prefer" => ssl_mode = SslMode::Prefer, - "require" => ssl_mode = SslMode::Require, - _ => { - tracing::debug!(message = "Unsupported SSL mode, defaulting to `prefer`", mode = &*v); - } - }; - } - "sslcert" => { - certificate_file = Some(v.to_string()); - } - "sslidentity" => { - identity_file = Some(v.to_string()); - } - "sslpassword" => { - identity_password = Some(v.to_string()); - } - "statement_cache_size" => { - statement_cache_size = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - } - "sslaccept" => { - match v.as_ref() { - "strict" => { - ssl_accept_mode = SslAcceptMode::Strict; - } - "accept_invalid_certs" => { - ssl_accept_mode = SslAcceptMode::AcceptInvalidCerts; - } - _ => { - tracing::debug!( - message = "Unsupported SSL accept mode, defaulting to `strict`", - mode = &*v - ); - - ssl_accept_mode = SslAcceptMode::Strict; - } - }; - } - "schema" => { - schema = Some(v.to_string()); - } - "connection_limit" => { - let as_int: usize = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - connection_limit = Some(as_int); - } - "host" => { - host = Some(v.to_string()); - } - "socket_timeout" => { - let as_int = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - socket_timeout = Some(Duration::from_secs(as_int)); - } - "connect_timeout" => { - let as_int = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - - if as_int == 0 { - connect_timeout = None; - } else { - connect_timeout = Some(Duration::from_secs(as_int)); - } - } - "pool_timeout" => { - let as_int = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - - if as_int == 0 { - pool_timeout = None; - } else { - pool_timeout = Some(Duration::from_secs(as_int)); - } - } - "max_connection_lifetime" => { - let as_int = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - - if as_int == 0 { - max_connection_lifetime = None; - } else { - max_connection_lifetime = Some(Duration::from_secs(as_int)); - } - } - "max_idle_connection_lifetime" => { - let as_int = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - - if as_int == 0 { - max_idle_connection_lifetime = None; - } else { - max_idle_connection_lifetime = Some(Duration::from_secs(as_int)); - } - } - "application_name" => { - application_name = Some(v.to_string()); - } - "channel_binding" => { - match v.as_ref() { - "disable" => channel_binding = ChannelBinding::Disable, - "prefer" => channel_binding = ChannelBinding::Prefer, - "require" => channel_binding = ChannelBinding::Require, - _ => { - tracing::debug!( - message = "Unsupported Channel Binding {channel_binding}, defaulting to `prefer`", - channel_binding = &*v - ); - } - }; - } - "options" => { - options = Some(v.to_string()); - } - _ => { - tracing::trace!(message = "Discarding connection string param", param = &*k); - } - }; - } - - Ok(PostgresUrlQueryParams { - ssl_params: SslParams { - certificate_file, - identity_file, - ssl_accept_mode, - identity_password: Hidden(identity_password), - }, - connection_limit, - schema, - ssl_mode, - host, - connect_timeout, - pool_timeout, - socket_timeout, - pg_bouncer, - statement_cache_size, - max_connection_lifetime, - max_idle_connection_lifetime, - application_name, - channel_binding, - options, - }) - } - - pub(crate) fn ssl_params(&self) -> &SslParams { - &self.query_params.ssl_params - } - - #[cfg(feature = "pooled")] - pub(crate) fn connection_limit(&self) -> Option { - self.query_params.connection_limit + pub fn channel_binding(&self) -> ChannelBinding { + self.query_params.channel_binding } /// On Postgres, we set the SEARCH_PATH and client-encoding through client connection parameters to save a network roundtrip on connection. @@ -569,29 +183,6 @@ impl PostgresUrl { config } - - pub fn flavour(&self) -> PostgresFlavour { - self.flavour - } -} - -#[derive(Debug, Clone)] -pub(crate) struct PostgresUrlQueryParams { - ssl_params: SslParams, - connection_limit: Option, - schema: Option, - ssl_mode: SslMode, - pg_bouncer: bool, - host: Option, - socket_timeout: Option, - connect_timeout: Option, - pool_timeout: Option, - statement_cache_size: usize, - max_connection_lifetime: Option, - max_idle_connection_lifetime: Option, - application_name: Option, - channel_binding: ChannelBinding, - options: Option, } impl PostgreSql { diff --git a/quaint/src/connector/postgres_wasm.rs b/quaint/src/connector/postgres_wasm.rs new file mode 100644 index 000000000000..4c67b98cfa42 --- /dev/null +++ b/quaint/src/connector/postgres_wasm.rs @@ -0,0 +1,612 @@ +use std::{ + borrow::Cow, + fmt::{Debug, Display}, + time::Duration, +}; + +use percent_encoding::percent_decode; +use url::{Host, Url}; + +use crate::error::{Error, ErrorKind}; + +#[cfg(feature = "postgresql-connector")] +use tokio_postgres::config::{ChannelBinding, SslMode}; + +#[derive(Clone)] +pub(crate) struct Hidden(pub(crate) T); + +impl Debug for Hidden { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str("") + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum SslAcceptMode { + Strict, + AcceptInvalidCerts, +} + +#[derive(Debug, Clone)] +pub struct SslParams { + pub(super) certificate_file: Option, + pub(super) identity_file: Option, + pub(super) identity_password: Hidden>, + pub(super) ssl_accept_mode: SslAcceptMode, +} + +#[derive(Debug, Clone, Copy)] +pub enum PostgresFlavour { + Postgres, + Cockroach, + Unknown, +} + +impl PostgresFlavour { + /// Returns `true` if the postgres flavour is [`Postgres`]. + /// + /// [`Postgres`]: PostgresFlavour::Postgres + pub(super) fn is_postgres(&self) -> bool { + matches!(self, Self::Postgres) + } + + /// Returns `true` if the postgres flavour is [`Cockroach`]. + /// + /// [`Cockroach`]: PostgresFlavour::Cockroach + pub(super) fn is_cockroach(&self) -> bool { + matches!(self, Self::Cockroach) + } + + /// Returns `true` if the postgres flavour is [`Unknown`]. + /// + /// [`Unknown`]: PostgresFlavour::Unknown + pub(super) fn is_unknown(&self) -> bool { + matches!(self, Self::Unknown) + } +} + +/// Wraps a connection url and exposes the parsing logic used by Quaint, +/// including default values. +#[derive(Debug, Clone)] +pub struct PostgresUrl { + pub(super) url: Url, + pub(super) query_params: PostgresUrlQueryParams, + pub(super) flavour: PostgresFlavour, +} + +pub(crate) const DEFAULT_SCHEMA: &str = "public"; + +impl PostgresUrl { + /// Parse `Url` to `PostgresUrl`. Returns error for mistyped connection + /// parameters. + pub fn new(url: Url) -> Result { + let query_params = Self::parse_query_params(&url)?; + + Ok(Self { + url, + query_params, + flavour: PostgresFlavour::Unknown, + }) + } + + /// The bare `Url` to the database. + pub fn url(&self) -> &Url { + &self.url + } + + /// The percent-decoded database username. + pub fn username(&self) -> Cow { + match percent_decode(self.url.username().as_bytes()).decode_utf8() { + Ok(username) => username, + Err(_) => { + tracing::warn!("Couldn't decode username to UTF-8, using the non-decoded version."); + + self.url.username().into() + } + } + } + + /// The database host. Taken first from the `host` query parameter, then + /// from the `host` part of the URL. For socket connections, the query + /// parameter must be used. + /// + /// If none of them are set, defaults to `localhost`. + pub fn host(&self) -> &str { + match (self.query_params.host.as_ref(), self.url.host_str(), self.url.host()) { + (Some(host), _, _) => host.as_str(), + (None, Some(""), _) => "localhost", + (None, None, _) => "localhost", + (None, Some(host), Some(Host::Ipv6(_))) => { + // The `url` crate may return an IPv6 address in brackets, which must be stripped. + if host.starts_with('[') && host.ends_with(']') { + &host[1..host.len() - 1] + } else { + host + } + } + (None, Some(host), _) => host, + } + } + + /// Name of the database connected. Defaults to `postgres`. + pub fn dbname(&self) -> &str { + match self.url.path_segments() { + Some(mut segments) => segments.next().unwrap_or("postgres"), + None => "postgres", + } + } + + /// The percent-decoded database password. + pub fn password(&self) -> Cow { + match self + .url + .password() + .and_then(|pw| percent_decode(pw.as_bytes()).decode_utf8().ok()) + { + Some(password) => password, + None => self.url.password().unwrap_or("").into(), + } + } + + /// The database port, defaults to `5432`. + pub fn port(&self) -> u16 { + self.url.port().unwrap_or(5432) + } + + /// The database schema, defaults to `public`. + pub fn schema(&self) -> &str { + self.query_params.schema.as_deref().unwrap_or(DEFAULT_SCHEMA) + } + + /// Whether the pgbouncer mode is enabled. + pub fn pg_bouncer(&self) -> bool { + self.query_params.pg_bouncer + } + + /// The connection timeout. + pub fn connect_timeout(&self) -> Option { + self.query_params.connect_timeout + } + + /// Pool check_out timeout + pub fn pool_timeout(&self) -> Option { + self.query_params.pool_timeout + } + + /// The socket timeout + pub fn socket_timeout(&self) -> Option { + self.query_params.socket_timeout + } + + /// The maximum connection lifetime + pub fn max_connection_lifetime(&self) -> Option { + self.query_params.max_connection_lifetime + } + + /// The maximum idle connection lifetime + pub fn max_idle_connection_lifetime(&self) -> Option { + self.query_params.max_idle_connection_lifetime + } + + /// The custom application name + pub fn application_name(&self) -> Option<&str> { + self.query_params.application_name.as_deref() + } + + pub(crate) fn options(&self) -> Option<&str> { + self.query_params.options.as_deref() + } + + /// Sets whether the URL points to a Postgres, Cockroach or Unknown database. + /// This is used to avoid a network roundtrip at connection to set the search path. + /// + /// The different behaviours are: + /// - Postgres: Always avoid a network roundtrip by setting the search path through client connection parameters. + /// - Cockroach: Avoid a network roundtrip if the schema name is deemed "safe" (i.e. no escape quoting required). Otherwise, set the search path through a database query. + /// - Unknown: Always add a network roundtrip by setting the search path through a database query. + pub fn set_flavour(&mut self, flavour: PostgresFlavour) { + self.flavour = flavour; + } + + fn parse_query_params(url: &Url) -> Result { + #[cfg(feature = "postgresql-connector")] + let mut ssl_mode = SslMode::Prefer; + #[cfg(feature = "postgresql-connector")] + let mut channel_binding = ChannelBinding::Prefer; + + let mut connection_limit = None; + let mut schema = None; + let mut certificate_file = None; + let mut identity_file = None; + let mut identity_password = None; + let mut ssl_accept_mode = SslAcceptMode::AcceptInvalidCerts; + let mut host = None; + let mut application_name = None; + let mut socket_timeout = None; + let mut connect_timeout = Some(Duration::from_secs(5)); + let mut pool_timeout = Some(Duration::from_secs(10)); + let mut pg_bouncer = false; + let mut statement_cache_size = 100; + let mut max_connection_lifetime = None; + let mut max_idle_connection_lifetime = Some(Duration::from_secs(300)); + let mut options = None; + + for (k, v) in url.query_pairs() { + match k.as_ref() { + "pgbouncer" => { + pg_bouncer = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + } + #[cfg(feature = "postgresql-connector")] + "sslmode" => { + match v.as_ref() { + "disable" => ssl_mode = SslMode::Disable, + "prefer" => ssl_mode = SslMode::Prefer, + "require" => ssl_mode = SslMode::Require, + _ => { + tracing::debug!(message = "Unsupported SSL mode, defaulting to `prefer`", mode = &*v); + } + }; + } + "sslcert" => { + certificate_file = Some(v.to_string()); + } + "sslidentity" => { + identity_file = Some(v.to_string()); + } + "sslpassword" => { + identity_password = Some(v.to_string()); + } + "statement_cache_size" => { + statement_cache_size = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + } + "sslaccept" => { + match v.as_ref() { + "strict" => { + ssl_accept_mode = SslAcceptMode::Strict; + } + "accept_invalid_certs" => { + ssl_accept_mode = SslAcceptMode::AcceptInvalidCerts; + } + _ => { + tracing::debug!( + message = "Unsupported SSL accept mode, defaulting to `strict`", + mode = &*v + ); + + ssl_accept_mode = SslAcceptMode::Strict; + } + }; + } + "schema" => { + schema = Some(v.to_string()); + } + "connection_limit" => { + let as_int: usize = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + connection_limit = Some(as_int); + } + "host" => { + host = Some(v.to_string()); + } + "socket_timeout" => { + let as_int = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + socket_timeout = Some(Duration::from_secs(as_int)); + } + "connect_timeout" => { + let as_int = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + + if as_int == 0 { + connect_timeout = None; + } else { + connect_timeout = Some(Duration::from_secs(as_int)); + } + } + "pool_timeout" => { + let as_int = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + + if as_int == 0 { + pool_timeout = None; + } else { + pool_timeout = Some(Duration::from_secs(as_int)); + } + } + "max_connection_lifetime" => { + let as_int = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + + if as_int == 0 { + max_connection_lifetime = None; + } else { + max_connection_lifetime = Some(Duration::from_secs(as_int)); + } + } + "max_idle_connection_lifetime" => { + let as_int = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + + if as_int == 0 { + max_idle_connection_lifetime = None; + } else { + max_idle_connection_lifetime = Some(Duration::from_secs(as_int)); + } + } + "application_name" => { + application_name = Some(v.to_string()); + } + #[cfg(feature = "postgresql-connector")] + "channel_binding" => { + match v.as_ref() { + "disable" => channel_binding = ChannelBinding::Disable, + "prefer" => channel_binding = ChannelBinding::Prefer, + "require" => channel_binding = ChannelBinding::Require, + _ => { + tracing::debug!( + message = "Unsupported Channel Binding {channel_binding}, defaulting to `prefer`", + channel_binding = &*v + ); + } + }; + } + "options" => { + options = Some(v.to_string()); + } + _ => { + tracing::trace!(message = "Discarding connection string param", param = &*k); + } + }; + } + + Ok(PostgresUrlQueryParams { + ssl_params: SslParams { + certificate_file, + identity_file, + ssl_accept_mode, + identity_password: Hidden(identity_password), + }, + connection_limit, + schema, + host, + connect_timeout, + pool_timeout, + socket_timeout, + pg_bouncer, + statement_cache_size, + max_connection_lifetime, + max_idle_connection_lifetime, + application_name, + options, + #[cfg(feature = "postgresql-connector")] + channel_binding, + #[cfg(feature = "postgresql-connector")] + ssl_mode, + }) + } + + pub(crate) fn ssl_params(&self) -> &SslParams { + &self.query_params.ssl_params + } + + #[cfg(feature = "pooled")] + pub(crate) fn connection_limit(&self) -> Option { + self.query_params.connection_limit + } + + pub fn flavour(&self) -> PostgresFlavour { + self.flavour + } +} + +#[derive(Debug, Clone)] +pub(crate) struct PostgresUrlQueryParams { + pub(crate) ssl_params: SslParams, + pub(crate) connection_limit: Option, + pub(crate) schema: Option, + pub(crate) pg_bouncer: bool, + pub(crate) host: Option, + pub(crate) socket_timeout: Option, + pub(crate) connect_timeout: Option, + pub(crate) pool_timeout: Option, + pub(crate) statement_cache_size: usize, + pub(crate) max_connection_lifetime: Option, + pub(crate) max_idle_connection_lifetime: Option, + pub(crate) application_name: Option, + pub(crate) options: Option, + + #[cfg(feature = "postgresql-connector")] + pub(crate) channel_binding: ChannelBinding, + + #[cfg(feature = "postgresql-connector")] + pub(crate) ssl_mode: SslMode, +} + +// A SearchPath connection parameter (Display-impl) for connection initialization. +struct CockroachSearchPath<'a>(&'a str); + +impl Display for CockroachSearchPath<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(self.0) + } +} + +// A SearchPath connection parameter (Display-impl) for connection initialization. +struct PostgresSearchPath<'a>(&'a str); + +impl Display for PostgresSearchPath<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str("\"")?; + f.write_str(self.0)?; + f.write_str("\"")?; + + Ok(()) + } +} + +// A SetSearchPath statement (Display-impl) for connection initialization. +struct SetSearchPath<'a>(Option<&'a str>); + +impl Display for SetSearchPath<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + if let Some(schema) = self.0 { + f.write_str("SET search_path = \"")?; + f.write_str(schema)?; + f.write_str("\";\n")?; + } + + Ok(()) + } +} + +/// Sorted list of CockroachDB's reserved keywords. +/// Taken from https://www.cockroachlabs.com/docs/stable/keywords-and-identifiers.html#keywords +const RESERVED_KEYWORDS: [&str; 79] = [ + "all", + "analyse", + "analyze", + "and", + "any", + "array", + "as", + "asc", + "asymmetric", + "both", + "case", + "cast", + "check", + "collate", + "column", + "concurrently", + "constraint", + "create", + "current_catalog", + "current_date", + "current_role", + "current_schema", + "current_time", + "current_timestamp", + "current_user", + "default", + "deferrable", + "desc", + "distinct", + "do", + "else", + "end", + "except", + "false", + "fetch", + "for", + "foreign", + "from", + "grant", + "group", + "having", + "in", + "initially", + "intersect", + "into", + "lateral", + "leading", + "limit", + "localtime", + "localtimestamp", + "not", + "null", + "offset", + "on", + "only", + "or", + "order", + "placing", + "primary", + "references", + "returning", + "select", + "session_user", + "some", + "symmetric", + "table", + "then", + "to", + "trailing", + "true", + "union", + "unique", + "user", + "using", + "variadic", + "when", + "where", + "window", + "with", +]; + +/// Sorted list of CockroachDB's reserved type function names. +/// Taken from https://www.cockroachlabs.com/docs/stable/keywords-and-identifiers.html#keywords +const RESERVED_TYPE_FUNCTION_NAMES: [&str; 18] = [ + "authorization", + "collation", + "cross", + "full", + "ilike", + "inner", + "is", + "isnull", + "join", + "left", + "like", + "natural", + "none", + "notnull", + "outer", + "overlaps", + "right", + "similar", +]; + +/// Returns true if a Postgres identifier is considered "safe". +/// +/// In this context, "safe" means that the value of an identifier would be the same quoted and unquoted or that it's not part of reserved keywords. In other words, that it does _not_ need to be quoted. +/// +/// Spec can be found here: https://www.postgresql.org/docs/current/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS +/// or here: https://www.cockroachlabs.com/docs/stable/keywords-and-identifiers.html#rules-for-identifiers +fn is_safe_identifier(ident: &str) -> bool { + if ident.is_empty() { + return false; + } + + // 1. Not equal any SQL keyword unless the keyword is accepted by the element's syntax. For example, name accepts Unreserved or Column Name keywords. + if RESERVED_KEYWORDS.binary_search(&ident).is_ok() || RESERVED_TYPE_FUNCTION_NAMES.binary_search(&ident).is_ok() { + return false; + } + + let mut chars = ident.chars(); + + let first = chars.next().unwrap(); + + // 2. SQL identifiers must begin with a letter (a-z, but also letters with diacritical marks and non-Latin letters) or an underscore (_). + if (!first.is_alphabetic() || !first.is_lowercase()) && first != '_' { + return false; + } + + for c in chars { + // 3. Subsequent characters in an identifier can be letters, underscores, digits (0-9), or dollar signs ($). + if (!c.is_alphabetic() || !c.is_lowercase()) && c != '_' && !c.is_ascii_digit() && c != '$' { + return false; + } + } + + true +} diff --git a/quaint/src/connector/sqlite.rs b/quaint/src/connector/sqlite.rs index 3a1ef72b4883..fc993c1eaf0e 100644 --- a/quaint/src/connector/sqlite.rs +++ b/quaint/src/connector/sqlite.rs @@ -1,6 +1,7 @@ mod conversion; mod error; +pub(crate) use super::sqlite_wasm::{SqliteParams, DEFAULT_SQLITE_SCHEMA_NAME}; pub use error::SqliteError; pub use rusqlite::{params_from_iter, version as sqlite_version}; @@ -13,11 +14,9 @@ use crate::{ visitor::{self, Visitor}, }; use async_trait::async_trait; -use std::{convert::TryFrom, path::Path, time::Duration}; +use std::convert::TryFrom; use tokio::sync::Mutex; -pub(crate) const DEFAULT_SQLITE_SCHEMA_NAME: &str = "main"; - /// The underlying sqlite driver. Only available with the `expose-drivers` Cargo feature. #[cfg(feature = "expose-drivers")] pub use rusqlite; @@ -27,105 +26,6 @@ pub struct Sqlite { pub(crate) client: Mutex, } -/// Wraps a connection url and exposes the parsing logic used by Quaint, -/// including default values. -#[derive(Debug)] -pub struct SqliteParams { - pub connection_limit: Option, - /// This is not a `PathBuf` because we need to `ATTACH` the database to the path, and this can - /// only be done with UTF-8 paths. - pub file_path: String, - pub db_name: String, - pub socket_timeout: Option, - pub max_connection_lifetime: Option, - pub max_idle_connection_lifetime: Option, -} - -impl TryFrom<&str> for SqliteParams { - type Error = Error; - - fn try_from(path: &str) -> crate::Result { - let path = if path.starts_with("file:") { - path.trim_start_matches("file:") - } else { - path.trim_start_matches("sqlite:") - }; - - let path_parts: Vec<&str> = path.split('?').collect(); - let path_str = path_parts[0]; - let path = Path::new(path_str); - - if path.is_dir() { - Err(Error::builder(ErrorKind::DatabaseUrlIsInvalid(path.to_str().unwrap().to_string())).build()) - } else { - let mut connection_limit = None; - let mut socket_timeout = None; - let mut max_connection_lifetime = None; - let mut max_idle_connection_lifetime = None; - - if path_parts.len() > 1 { - let params = path_parts.last().unwrap().split('&').map(|kv| { - let splitted: Vec<&str> = kv.split('=').collect(); - (splitted[0], splitted[1]) - }); - - for (k, v) in params { - match k { - "connection_limit" => { - let as_int: usize = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - - connection_limit = Some(as_int); - } - "socket_timeout" => { - let as_int = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - - socket_timeout = Some(Duration::from_secs(as_int)); - } - "max_connection_lifetime" => { - let as_int = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - - if as_int == 0 { - max_connection_lifetime = None; - } else { - max_connection_lifetime = Some(Duration::from_secs(as_int)); - } - } - "max_idle_connection_lifetime" => { - let as_int = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - - if as_int == 0 { - max_idle_connection_lifetime = None; - } else { - max_idle_connection_lifetime = Some(Duration::from_secs(as_int)); - } - } - _ => { - tracing::trace!(message = "Discarding connection string param", param = k); - } - }; - } - } - - Ok(Self { - connection_limit, - file_path: path_str.to_owned(), - db_name: DEFAULT_SQLITE_SCHEMA_NAME.to_owned(), - socket_timeout, - max_connection_lifetime, - max_idle_connection_lifetime, - }) - } - } -} - impl TryFrom<&str> for Sqlite { type Error = Error; diff --git a/quaint/src/connector/sqlite_wasm.rs b/quaint/src/connector/sqlite_wasm.rs new file mode 100644 index 000000000000..10c174480785 --- /dev/null +++ b/quaint/src/connector/sqlite_wasm.rs @@ -0,0 +1,103 @@ +use crate::error::{Error, ErrorKind}; +use std::{convert::TryFrom, path::Path, time::Duration}; + +pub(crate) const DEFAULT_SQLITE_SCHEMA_NAME: &str = "main"; + +/// Wraps a connection url and exposes the parsing logic used by Quaint, +/// including default values. +#[derive(Debug)] +pub struct SqliteParams { + pub connection_limit: Option, + /// This is not a `PathBuf` because we need to `ATTACH` the database to the path, and this can + /// only be done with UTF-8 paths. + pub file_path: String, + pub db_name: String, + pub socket_timeout: Option, + pub max_connection_lifetime: Option, + pub max_idle_connection_lifetime: Option, +} + +impl TryFrom<&str> for SqliteParams { + type Error = Error; + + fn try_from(path: &str) -> crate::Result { + let path = if path.starts_with("file:") { + path.trim_start_matches("file:") + } else { + path.trim_start_matches("sqlite:") + }; + + let path_parts: Vec<&str> = path.split('?').collect(); + let path_str = path_parts[0]; + let path = Path::new(path_str); + + if path.is_dir() { + Err(Error::builder(ErrorKind::DatabaseUrlIsInvalid(path.to_str().unwrap().to_string())).build()) + } else { + let mut connection_limit = None; + let mut socket_timeout = None; + let mut max_connection_lifetime = None; + let mut max_idle_connection_lifetime = None; + + if path_parts.len() > 1 { + let params = path_parts.last().unwrap().split('&').map(|kv| { + let splitted: Vec<&str> = kv.split('=').collect(); + (splitted[0], splitted[1]) + }); + + for (k, v) in params { + match k { + "connection_limit" => { + let as_int: usize = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + + connection_limit = Some(as_int); + } + "socket_timeout" => { + let as_int = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + + socket_timeout = Some(Duration::from_secs(as_int)); + } + "max_connection_lifetime" => { + let as_int = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + + if as_int == 0 { + max_connection_lifetime = None; + } else { + max_connection_lifetime = Some(Duration::from_secs(as_int)); + } + } + "max_idle_connection_lifetime" => { + let as_int = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + + if as_int == 0 { + max_idle_connection_lifetime = None; + } else { + max_idle_connection_lifetime = Some(Duration::from_secs(as_int)); + } + } + _ => { + tracing::trace!(message = "Discarding connection string param", param = k); + } + }; + } + } + + Ok(Self { + connection_limit, + file_path: path_str.to_owned(), + db_name: DEFAULT_SQLITE_SCHEMA_NAME.to_owned(), + socket_timeout, + max_connection_lifetime, + max_idle_connection_lifetime, + }) + } + } +} diff --git a/quaint/src/error.rs b/quaint/src/error.rs index 705bb6b37ee0..785fcc22ffe3 100644 --- a/quaint/src/error.rs +++ b/quaint/src/error.rs @@ -6,9 +6,9 @@ use thiserror::Error; #[cfg(feature = "pooled")] use std::time::Duration; -pub use crate::connector::mysql::MysqlError; -pub use crate::connector::postgres::PostgresError; -pub use crate::connector::sqlite::SqliteError; +// pub use crate::connector::mysql::MysqlError; +// pub use crate::connector::postgres::PostgresError; +// pub use crate::connector::sqlite::SqliteError; #[derive(Debug, PartialEq, Eq)] pub enum DatabaseConstraint { diff --git a/quaint/src/pooled/manager.rs b/quaint/src/pooled/manager.rs index c0aa8c93b75d..c31fd44fbcae 100644 --- a/quaint/src/pooled/manager.rs +++ b/quaint/src/pooled/manager.rs @@ -1,8 +1,8 @@ -#[cfg(feature = "mssql")] +#[cfg(feature = "mssql-connector")] use crate::connector::MssqlUrl; -#[cfg(feature = "mysql")] +#[cfg(feature = "mysql-connector")] use crate::connector::MysqlUrl; -#[cfg(feature = "postgresql")] +#[cfg(feature = "postgresql-connector")] use crate::connector::PostgresUrl; use crate::{ ast, @@ -97,7 +97,7 @@ impl Manager for QuaintManager { async fn connect(&self) -> crate::Result { let conn = match self { - #[cfg(feature = "sqlite")] + #[cfg(feature = "sqlite-connector")] QuaintManager::Sqlite { url, .. } => { use crate::connector::Sqlite; @@ -106,19 +106,19 @@ impl Manager for QuaintManager { Ok(Box::new(conn) as Self::Connection) } - #[cfg(feature = "mysql")] + #[cfg(feature = "mysql-connector")] QuaintManager::Mysql { url } => { use crate::connector::Mysql; Ok(Box::new(Mysql::new(url.clone()).await?) as Self::Connection) } - #[cfg(feature = "postgresql")] + #[cfg(feature = "postgresql-connector")] QuaintManager::Postgres { url } => { use crate::connector::PostgreSql; Ok(Box::new(PostgreSql::new(url.clone()).await?) as Self::Connection) } - #[cfg(feature = "mssql")] + #[cfg(feature = "mssql-connector")] QuaintManager::Mssql { url } => { use crate::connector::Mssql; Ok(Box::new(Mssql::new(url.clone()).await?) as Self::Connection) @@ -146,7 +146,7 @@ mod tests { use crate::pooled::Quaint; #[tokio::test] - #[cfg(feature = "mysql")] + #[cfg(feature = "mysql-connector")] async fn mysql_default_connection_limit() { let conn_string = std::env::var("TEST_MYSQL").expect("TEST_MYSQL connection string not set."); @@ -156,7 +156,7 @@ mod tests { } #[tokio::test] - #[cfg(feature = "mysql")] + #[cfg(feature = "mysql-connector")] async fn mysql_custom_connection_limit() { let conn_string = format!( "{}?connection_limit=10", @@ -169,7 +169,7 @@ mod tests { } #[tokio::test] - #[cfg(feature = "postgresql")] + #[cfg(feature = "postgresql-connector")] async fn psql_default_connection_limit() { let conn_string = std::env::var("TEST_PSQL").expect("TEST_PSQL connection string not set."); @@ -179,7 +179,7 @@ mod tests { } #[tokio::test] - #[cfg(feature = "postgresql")] + #[cfg(feature = "postgresql-connector")] async fn psql_custom_connection_limit() { let conn_string = format!( "{}?connection_limit=10", @@ -192,7 +192,7 @@ mod tests { } #[tokio::test] - #[cfg(feature = "mssql")] + #[cfg(feature = "mssql-connector")] async fn mssql_default_connection_limit() { let conn_string = std::env::var("TEST_MSSQL").expect("TEST_MSSQL connection string not set."); @@ -202,7 +202,7 @@ mod tests { } #[tokio::test] - #[cfg(feature = "mssql")] + #[cfg(feature = "mssql-connector")] async fn mssql_custom_connection_limit() { let conn_string = format!( "{};connectionLimit=10", @@ -215,7 +215,7 @@ mod tests { } #[tokio::test] - #[cfg(feature = "sqlite")] + #[cfg(feature = "sqlite-connector")] async fn test_default_connection_limit() { let conn_string = "file:db/test.db".to_string(); let pool = Quaint::builder(&conn_string).unwrap().build(); @@ -224,7 +224,7 @@ mod tests { } #[tokio::test] - #[cfg(feature = "sqlite")] + #[cfg(feature = "sqlite-connector")] async fn test_custom_connection_limit() { let conn_string = "file:db/test.db?connection_limit=10".to_string(); let pool = Quaint::builder(&conn_string).unwrap().build(); diff --git a/quaint/src/single.rs b/quaint/src/single.rs index 82042f58010b..2f234e40fd74 100644 --- a/quaint/src/single.rs +++ b/quaint/src/single.rs @@ -130,27 +130,27 @@ impl Quaint { #[allow(unreachable_code)] pub async fn new(url_str: &str) -> crate::Result { let inner = match url_str { - #[cfg(feature = "sqlite")] + #[cfg(feature = "sqlite-connector")] s if s.starts_with("file") => { let params = connector::SqliteParams::try_from(s)?; let sqlite = connector::Sqlite::new(¶ms.file_path)?; Arc::new(sqlite) as Arc } - #[cfg(feature = "mysql")] + #[cfg(feature = "mysql-connector")] s if s.starts_with("mysql") => { let url = connector::MysqlUrl::new(url::Url::parse(s)?)?; let mysql = connector::Mysql::new(url).await?; Arc::new(mysql) as Arc } - #[cfg(feature = "postgresql")] + #[cfg(feature = "postgresql-connector")] s if s.starts_with("postgres") || s.starts_with("postgresql") => { let url = connector::PostgresUrl::new(url::Url::parse(s)?)?; let psql = connector::PostgreSql::new(url).await?; Arc::new(psql) as Arc } - #[cfg(feature = "mssql")] + #[cfg(feature = "mssql-connector")] s if s.starts_with("jdbc:sqlserver") | s.starts_with("sqlserver") => { let url = connector::MssqlUrl::new(s)?; let psql = connector::Mssql::new(url).await?; @@ -166,7 +166,7 @@ impl Quaint { Ok(Self { inner, connection_info }) } - #[cfg(feature = "sqlite")] + #[cfg(feature = "sqlite-connector")] /// Open a new SQLite database in memory. pub fn new_in_memory() -> crate::Result { Ok(Quaint { From 055e696e40adb7294da2337ab86ded2d333ef5a8 Mon Sep 17 00:00:00 2001 From: jkomyno Date: Fri, 10 Nov 2023 15:02:09 +0100 Subject: [PATCH 02/34] feat(quaint): split postgres connector into native and wasm submodules --- quaint/src/connector.rs | 23 +- quaint/src/connector/postgres.rs | 1187 +---------------- .../postgres/{ => native}/conversion.rs | 0 .../{ => native}/conversion/decimal.rs | 0 quaint/src/connector/postgres/native/error.rs | 126 ++ quaint/src/connector/postgres/native/mod.rs | 1184 ++++++++++++++++ quaint/src/connector/postgres/wasm/common.rs | 612 +++++++++ .../connector/postgres/{ => wasm}/error.rs | 124 +- quaint/src/connector/postgres/wasm/mod.rs | 6 + quaint/src/error.rs | 2 +- 10 files changed, 1950 insertions(+), 1314 deletions(-) rename quaint/src/connector/postgres/{ => native}/conversion.rs (100%) rename quaint/src/connector/postgres/{ => native}/conversion/decimal.rs (100%) create mode 100644 quaint/src/connector/postgres/native/error.rs create mode 100644 quaint/src/connector/postgres/native/mod.rs create mode 100644 quaint/src/connector/postgres/wasm/common.rs rename quaint/src/connector/postgres/{ => wasm}/error.rs (66%) create mode 100644 quaint/src/connector/postgres/wasm/mod.rs diff --git a/quaint/src/connector.rs b/quaint/src/connector.rs index 898aac8fcb46..71bba2d098ed 100644 --- a/quaint/src/connector.rs +++ b/quaint/src/connector.rs @@ -31,10 +31,10 @@ pub(crate) mod mssql_wasm; pub(crate) mod mysql; #[cfg(feature = "mysql")] pub(crate) mod mysql_wasm; -#[cfg(feature = "postgresql-connector")] -pub(crate) mod postgres; -#[cfg(feature = "postgresql")] -pub(crate) mod postgres_wasm; +// #[cfg(feature = "postgresql-connector")] +// pub(crate) mod postgres; +// #[cfg(feature = "postgresql")] +// pub(crate) mod postgres_wasm; #[cfg(feature = "sqlite-connector")] pub(crate) mod sqlite; #[cfg(feature = "sqlite")] @@ -44,10 +44,10 @@ pub(crate) mod sqlite_wasm; pub use self::mysql::*; #[cfg(feature = "mysql")] pub use self::mysql_wasm::*; -#[cfg(feature = "postgresql-connector")] -pub use self::postgres::*; -#[cfg(feature = "postgresql")] -pub use self::postgres_wasm::*; +// #[cfg(feature = "postgresql-connector")] +// pub use self::postgres::*; +// #[cfg(feature = "postgresql")] +// pub use self::postgres_wasm::*; #[cfg(feature = "mssql-connector")] pub use mssql::*; #[cfg(feature = "mssql")] @@ -70,3 +70,10 @@ pub use transaction::*; pub(crate) use type_identifier::*; pub use self::metrics::query; + +#[cfg(feature = "postgresql")] +pub(crate) mod postgres; +#[cfg(feature = "postgresql-connector")] +pub use postgres::native::*; +#[cfg(feature = "postgresql")] +pub use postgres::wasm::*; diff --git a/quaint/src/connector/postgres.rs b/quaint/src/connector/postgres.rs index 7a83e61218f6..9f4d4d496f2b 100644 --- a/quaint/src/connector/postgres.rs +++ b/quaint/src/connector/postgres.rs @@ -1,1184 +1,7 @@ -mod conversion; -mod error; +pub use wasm::error::PostgresError; -use super::postgres_wasm::{Hidden, SslAcceptMode, SslParams}; -pub(crate) use super::postgres_wasm::{PostgresFlavour, PostgresUrl}; +#[cfg(feature = "postgresql")] +pub(crate) mod wasm; -use crate::{ - ast::{Query, Value}, - connector::{metrics, queryable::*, ResultSet}, - error::{Error, ErrorKind}, - visitor::{self, Visitor}, -}; -use async_trait::async_trait; -use futures::{future::FutureExt, lock::Mutex}; -use lru_cache::LruCache; -use native_tls::{Certificate, Identity, TlsConnector}; -use postgres_native_tls::MakeTlsConnector; -use std::{ - borrow::Borrow, - fmt::{Debug, Display}, - fs, - future::Future, - sync::atomic::{AtomicBool, Ordering}, - time::Duration, -}; -use tokio_postgres::{config::ChannelBinding, Client, Config, Statement}; - -pub use error::PostgresError; - -/// The underlying postgres driver. Only available with the `expose-drivers` -/// Cargo feature. -#[cfg(feature = "expose-drivers")] -pub use tokio_postgres; - -use super::{IsolationLevel, Transaction}; - -struct PostgresClient(Client); - -impl Debug for PostgresClient { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_str("PostgresClient") - } -} - -/// A connector interface for the PostgreSQL database. -#[derive(Debug)] -pub struct PostgreSql { - client: PostgresClient, - pg_bouncer: bool, - socket_timeout: Option, - statement_cache: Mutex>, - is_healthy: AtomicBool, -} - -#[derive(Debug)] -struct SslAuth { - certificate: Hidden>, - identity: Hidden>, - ssl_accept_mode: SslAcceptMode, -} - -impl Default for SslAuth { - fn default() -> Self { - Self { - certificate: Hidden(None), - identity: Hidden(None), - ssl_accept_mode: SslAcceptMode::AcceptInvalidCerts, - } - } -} - -impl SslAuth { - fn certificate(&mut self, certificate: Certificate) -> &mut Self { - self.certificate = Hidden(Some(certificate)); - self - } - - fn identity(&mut self, identity: Identity) -> &mut Self { - self.identity = Hidden(Some(identity)); - self - } - - fn accept_mode(&mut self, mode: SslAcceptMode) -> &mut Self { - self.ssl_accept_mode = mode; - self - } -} - -impl SslParams { - async fn into_auth(self) -> crate::Result { - let mut auth = SslAuth::default(); - auth.accept_mode(self.ssl_accept_mode); - - if let Some(ref cert_file) = self.certificate_file { - let cert = fs::read(cert_file).map_err(|err| { - Error::builder(ErrorKind::TlsError { - message: format!("cert file not found ({err})"), - }) - .build() - })?; - - auth.certificate(Certificate::from_pem(&cert)?); - } - - if let Some(ref identity_file) = self.identity_file { - let db = fs::read(identity_file).map_err(|err| { - Error::builder(ErrorKind::TlsError { - message: format!("identity file not found ({err})"), - }) - .build() - })?; - let password = self.identity_password.0.as_deref().unwrap_or(""); - let identity = Identity::from_pkcs12(&db, password)?; - - auth.identity(identity); - } - - Ok(auth) - } -} - -impl PostgresUrl { - pub(crate) fn cache(&self) -> LruCache { - if self.query_params.pg_bouncer { - LruCache::new(0) - } else { - LruCache::new(self.query_params.statement_cache_size) - } - } - - pub fn channel_binding(&self) -> ChannelBinding { - self.query_params.channel_binding - } - - /// On Postgres, we set the SEARCH_PATH and client-encoding through client connection parameters to save a network roundtrip on connection. - /// We can't always do it for CockroachDB because it does not expect quotes for unsafe identifiers (https://github.com/cockroachdb/cockroach/issues/101328), which might change once the issue is fixed. - /// To circumvent that problem, we only set the SEARCH_PATH through client connection parameters for Cockroach when the identifier is safe, so that the quoting does not matter. - fn set_search_path(&self, config: &mut Config) { - // PGBouncer does not support the search_path connection parameter. - // https://www.pgbouncer.org/config.html#ignore_startup_parameters - if self.query_params.pg_bouncer { - return; - } - - if let Some(schema) = &self.query_params.schema { - if self.flavour().is_cockroach() && is_safe_identifier(schema) { - config.search_path(CockroachSearchPath(schema).to_string()); - } - - if self.flavour().is_postgres() { - config.search_path(PostgresSearchPath(schema).to_string()); - } - } - } - - pub(crate) fn to_config(&self) -> Config { - let mut config = Config::new(); - - config.user(self.username().borrow()); - config.password(self.password().borrow() as &str); - config.host(self.host()); - config.port(self.port()); - config.dbname(self.dbname()); - config.pgbouncer_mode(self.query_params.pg_bouncer); - - if let Some(options) = self.options() { - config.options(options); - } - - if let Some(application_name) = self.application_name() { - config.application_name(application_name); - } - - if let Some(connect_timeout) = self.query_params.connect_timeout { - config.connect_timeout(connect_timeout); - } - - self.set_search_path(&mut config); - - config.ssl_mode(self.query_params.ssl_mode); - - config.channel_binding(self.query_params.channel_binding); - - config - } -} - -impl PostgreSql { - /// Create a new connection to the database. - pub async fn new(url: PostgresUrl) -> crate::Result { - let config = url.to_config(); - - let mut tls_builder = TlsConnector::builder(); - - { - let ssl_params = url.ssl_params(); - let auth = ssl_params.to_owned().into_auth().await?; - - if let Some(certificate) = auth.certificate.0 { - tls_builder.add_root_certificate(certificate); - } - - tls_builder.danger_accept_invalid_certs(auth.ssl_accept_mode == SslAcceptMode::AcceptInvalidCerts); - - if let Some(identity) = auth.identity.0 { - tls_builder.identity(identity); - } - } - - let tls = MakeTlsConnector::new(tls_builder.build()?); - let (client, conn) = super::timeout::connect(url.connect_timeout(), config.connect(tls)).await?; - - tokio::spawn(conn.map(|r| match r { - Ok(_) => (), - Err(e) => { - tracing::error!("Error in PostgreSQL connection: {:?}", e); - } - })); - - // On Postgres, we set the SEARCH_PATH and client-encoding through client connection parameters to save a network roundtrip on connection. - // We can't always do it for CockroachDB because it does not expect quotes for unsafe identifiers (https://github.com/cockroachdb/cockroach/issues/101328), which might change once the issue is fixed. - // To circumvent that problem, we only set the SEARCH_PATH through client connection parameters for Cockroach when the identifier is safe, so that the quoting does not matter. - // Finally, to ensure backward compatibility, we keep sending a database query in case the flavour is set to Unknown. - if let Some(schema) = &url.query_params.schema { - // PGBouncer does not support the search_path connection parameter. - // https://www.pgbouncer.org/config.html#ignore_startup_parameters - if url.query_params.pg_bouncer - || url.flavour().is_unknown() - || (url.flavour().is_cockroach() && !is_safe_identifier(schema)) - { - let session_variables = format!( - r##"{set_search_path}"##, - set_search_path = SetSearchPath(url.query_params.schema.as_deref()) - ); - - client.simple_query(session_variables.as_str()).await?; - } - } - - Ok(Self { - client: PostgresClient(client), - socket_timeout: url.query_params.socket_timeout, - pg_bouncer: url.query_params.pg_bouncer, - statement_cache: Mutex::new(url.cache()), - is_healthy: AtomicBool::new(true), - }) - } - - /// The underlying tokio_postgres::Client. Only available with the - /// `expose-drivers` Cargo feature. This is a lower level API when you need - /// to get into database specific features. - #[cfg(feature = "expose-drivers")] - pub fn client(&self) -> &tokio_postgres::Client { - &self.client.0 - } - - async fn fetch_cached(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - let mut cache = self.statement_cache.lock().await; - let capacity = cache.capacity(); - let stored = cache.len(); - - match cache.get_mut(sql) { - Some(stmt) => { - tracing::trace!( - message = "CACHE HIT!", - query = sql, - capacity = capacity, - stored = stored, - ); - - Ok(stmt.clone()) // arc'd - } - None => { - tracing::trace!( - message = "CACHE MISS!", - query = sql, - capacity = capacity, - stored = stored, - ); - - let param_types = conversion::params_to_types(params); - let stmt = self.perform_io(self.client.0.prepare_typed(sql, ¶m_types)).await?; - - cache.insert(sql.to_string(), stmt.clone()); - - Ok(stmt) - } - } - } - - async fn perform_io(&self, fut: F) -> crate::Result - where - F: Future>, - { - match super::timeout::socket(self.socket_timeout, fut).await { - Err(e) if e.is_closed() => { - self.is_healthy.store(false, Ordering::SeqCst); - Err(e) - } - res => res, - } - } - - fn check_bind_variables_len(&self, params: &[Value<'_>]) -> crate::Result<()> { - if params.len() > i16::MAX as usize { - // tokio_postgres would return an error here. Let's avoid calling the driver - // and return an error early. - let kind = ErrorKind::QueryInvalidInput(format!( - "too many bind variables in prepared statement, expected maximum of {}, received {}", - i16::MAX, - params.len() - )); - Err(Error::builder(kind).build()) - } else { - Ok(()) - } - } -} - -// A SearchPath connection parameter (Display-impl) for connection initialization. -struct CockroachSearchPath<'a>(&'a str); - -impl Display for CockroachSearchPath<'_> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_str(self.0) - } -} - -// A SearchPath connection parameter (Display-impl) for connection initialization. -struct PostgresSearchPath<'a>(&'a str); - -impl Display for PostgresSearchPath<'_> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_str("\"")?; - f.write_str(self.0)?; - f.write_str("\"")?; - - Ok(()) - } -} - -// A SetSearchPath statement (Display-impl) for connection initialization. -struct SetSearchPath<'a>(Option<&'a str>); - -impl Display for SetSearchPath<'_> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - if let Some(schema) = self.0 { - f.write_str("SET search_path = \"")?; - f.write_str(schema)?; - f.write_str("\";\n")?; - } - - Ok(()) - } -} - -impl_default_TransactionCapable!(PostgreSql); - -#[async_trait] -impl Queryable for PostgreSql { - async fn query(&self, q: Query<'_>) -> crate::Result { - let (sql, params) = visitor::Postgres::build(q)?; - - self.query_raw(sql.as_str(), ¶ms[..]).await - } - - async fn query_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - self.check_bind_variables_len(params)?; - - metrics::query("postgres.query_raw", sql, params, move || async move { - let stmt = self.fetch_cached(sql, &[]).await?; - - if stmt.params().len() != params.len() { - let kind = ErrorKind::IncorrectNumberOfParameters { - expected: stmt.params().len(), - actual: params.len(), - }; - - return Err(Error::builder(kind).build()); - } - - let rows = self - .perform_io(self.client.0.query(&stmt, conversion::conv_params(params).as_slice())) - .await?; - - let mut result = ResultSet::new(stmt.to_column_names(), Vec::new()); - - for row in rows { - result.rows.push(row.get_result_row()?); - } - - Ok(result) - }) - .await - } - - async fn query_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - self.check_bind_variables_len(params)?; - - metrics::query("postgres.query_raw", sql, params, move || async move { - let stmt = self.fetch_cached(sql, params).await?; - - if stmt.params().len() != params.len() { - let kind = ErrorKind::IncorrectNumberOfParameters { - expected: stmt.params().len(), - actual: params.len(), - }; - - return Err(Error::builder(kind).build()); - } - - let rows = self - .perform_io(self.client.0.query(&stmt, conversion::conv_params(params).as_slice())) - .await?; - - let mut result = ResultSet::new(stmt.to_column_names(), Vec::new()); - - for row in rows { - result.rows.push(row.get_result_row()?); - } - - Ok(result) - }) - .await - } - - async fn execute(&self, q: Query<'_>) -> crate::Result { - let (sql, params) = visitor::Postgres::build(q)?; - - self.execute_raw(sql.as_str(), ¶ms[..]).await - } - - async fn execute_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - self.check_bind_variables_len(params)?; - - metrics::query("postgres.execute_raw", sql, params, move || async move { - let stmt = self.fetch_cached(sql, &[]).await?; - - if stmt.params().len() != params.len() { - let kind = ErrorKind::IncorrectNumberOfParameters { - expected: stmt.params().len(), - actual: params.len(), - }; - - return Err(Error::builder(kind).build()); - } - - let changes = self - .perform_io(self.client.0.execute(&stmt, conversion::conv_params(params).as_slice())) - .await?; - - Ok(changes) - }) - .await - } - - async fn execute_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - self.check_bind_variables_len(params)?; - - metrics::query("postgres.execute_raw", sql, params, move || async move { - let stmt = self.fetch_cached(sql, params).await?; - - if stmt.params().len() != params.len() { - let kind = ErrorKind::IncorrectNumberOfParameters { - expected: stmt.params().len(), - actual: params.len(), - }; - - return Err(Error::builder(kind).build()); - } - - let changes = self - .perform_io(self.client.0.execute(&stmt, conversion::conv_params(params).as_slice())) - .await?; - - Ok(changes) - }) - .await - } - - async fn raw_cmd(&self, cmd: &str) -> crate::Result<()> { - metrics::query("postgres.raw_cmd", cmd, &[], move || async move { - self.perform_io(self.client.0.simple_query(cmd)).await?; - Ok(()) - }) - .await - } - - async fn version(&self) -> crate::Result> { - let query = r#"SELECT version()"#; - let rows = self.query_raw(query, &[]).await?; - - let version_string = rows - .get(0) - .and_then(|row| row.get("version").and_then(|version| version.to_string())); - - Ok(version_string) - } - - fn is_healthy(&self) -> bool { - self.is_healthy.load(Ordering::SeqCst) - } - - async fn server_reset_query(&self, tx: &dyn Transaction) -> crate::Result<()> { - if self.pg_bouncer { - tx.raw_cmd("DEALLOCATE ALL").await - } else { - Ok(()) - } - } - - async fn set_tx_isolation_level(&self, isolation_level: IsolationLevel) -> crate::Result<()> { - if matches!(isolation_level, IsolationLevel::Snapshot) { - return Err(Error::builder(ErrorKind::invalid_isolation_level(&isolation_level)).build()); - } - - self.raw_cmd(&format!("SET TRANSACTION ISOLATION LEVEL {isolation_level}")) - .await?; - - Ok(()) - } - - fn requires_isolation_first(&self) -> bool { - false - } -} - -/// Sorted list of CockroachDB's reserved keywords. -/// Taken from https://www.cockroachlabs.com/docs/stable/keywords-and-identifiers.html#keywords -const RESERVED_KEYWORDS: [&str; 79] = [ - "all", - "analyse", - "analyze", - "and", - "any", - "array", - "as", - "asc", - "asymmetric", - "both", - "case", - "cast", - "check", - "collate", - "column", - "concurrently", - "constraint", - "create", - "current_catalog", - "current_date", - "current_role", - "current_schema", - "current_time", - "current_timestamp", - "current_user", - "default", - "deferrable", - "desc", - "distinct", - "do", - "else", - "end", - "except", - "false", - "fetch", - "for", - "foreign", - "from", - "grant", - "group", - "having", - "in", - "initially", - "intersect", - "into", - "lateral", - "leading", - "limit", - "localtime", - "localtimestamp", - "not", - "null", - "offset", - "on", - "only", - "or", - "order", - "placing", - "primary", - "references", - "returning", - "select", - "session_user", - "some", - "symmetric", - "table", - "then", - "to", - "trailing", - "true", - "union", - "unique", - "user", - "using", - "variadic", - "when", - "where", - "window", - "with", -]; - -/// Sorted list of CockroachDB's reserved type function names. -/// Taken from https://www.cockroachlabs.com/docs/stable/keywords-and-identifiers.html#keywords -const RESERVED_TYPE_FUNCTION_NAMES: [&str; 18] = [ - "authorization", - "collation", - "cross", - "full", - "ilike", - "inner", - "is", - "isnull", - "join", - "left", - "like", - "natural", - "none", - "notnull", - "outer", - "overlaps", - "right", - "similar", -]; - -/// Returns true if a Postgres identifier is considered "safe". -/// -/// In this context, "safe" means that the value of an identifier would be the same quoted and unquoted or that it's not part of reserved keywords. In other words, that it does _not_ need to be quoted. -/// -/// Spec can be found here: https://www.postgresql.org/docs/current/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS -/// or here: https://www.cockroachlabs.com/docs/stable/keywords-and-identifiers.html#rules-for-identifiers -fn is_safe_identifier(ident: &str) -> bool { - if ident.is_empty() { - return false; - } - - // 1. Not equal any SQL keyword unless the keyword is accepted by the element's syntax. For example, name accepts Unreserved or Column Name keywords. - if RESERVED_KEYWORDS.binary_search(&ident).is_ok() || RESERVED_TYPE_FUNCTION_NAMES.binary_search(&ident).is_ok() { - return false; - } - - let mut chars = ident.chars(); - - let first = chars.next().unwrap(); - - // 2. SQL identifiers must begin with a letter (a-z, but also letters with diacritical marks and non-Latin letters) or an underscore (_). - if (!first.is_alphabetic() || !first.is_lowercase()) && first != '_' { - return false; - } - - for c in chars { - // 3. Subsequent characters in an identifier can be letters, underscores, digits (0-9), or dollar signs ($). - if (!c.is_alphabetic() || !c.is_lowercase()) && c != '_' && !c.is_ascii_digit() && c != '$' { - return false; - } - } - - true -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::tests::test_api::postgres::CONN_STR; - use crate::tests::test_api::CRDB_CONN_STR; - use crate::{connector::Queryable, error::*, single::Quaint}; - use url::Url; - - #[test] - fn should_parse_socket_url() { - let url = PostgresUrl::new(Url::parse("postgresql:///dbname?host=/var/run/psql.sock").unwrap()).unwrap(); - assert_eq!("dbname", url.dbname()); - assert_eq!("/var/run/psql.sock", url.host()); - } - - #[test] - fn should_parse_escaped_url() { - let url = PostgresUrl::new(Url::parse("postgresql:///dbname?host=%2Fvar%2Frun%2Fpostgresql").unwrap()).unwrap(); - assert_eq!("dbname", url.dbname()); - assert_eq!("/var/run/postgresql", url.host()); - } - - #[test] - fn should_allow_changing_of_cache_size() { - let url = - PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo?statement_cache_size=420").unwrap()).unwrap(); - assert_eq!(420, url.cache().capacity()); - } - - #[test] - fn should_have_default_cache_size() { - let url = PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo").unwrap()).unwrap(); - assert_eq!(100, url.cache().capacity()); - } - - #[test] - fn should_have_application_name() { - let url = - PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo?application_name=test").unwrap()).unwrap(); - assert_eq!(Some("test"), url.application_name()); - } - - #[test] - fn should_have_channel_binding() { - let url = - PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo?channel_binding=require").unwrap()).unwrap(); - assert_eq!(ChannelBinding::Require, url.channel_binding()); - } - - #[test] - fn should_have_default_channel_binding() { - let url = - PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo?channel_binding=invalid").unwrap()).unwrap(); - assert_eq!(ChannelBinding::Prefer, url.channel_binding()); - - let url = PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo").unwrap()).unwrap(); - assert_eq!(ChannelBinding::Prefer, url.channel_binding()); - } - - #[test] - fn should_not_enable_caching_with_pgbouncer() { - let url = PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo?pgbouncer=true").unwrap()).unwrap(); - assert_eq!(0, url.cache().capacity()); - } - - #[test] - fn should_parse_default_host() { - let url = PostgresUrl::new(Url::parse("postgresql:///dbname").unwrap()).unwrap(); - assert_eq!("dbname", url.dbname()); - assert_eq!("localhost", url.host()); - } - - #[test] - fn should_parse_ipv6_host() { - let url = PostgresUrl::new(Url::parse("postgresql://[2001:db8:1234::ffff]:5432/dbname").unwrap()).unwrap(); - assert_eq!("2001:db8:1234::ffff", url.host()); - } - - #[test] - fn should_handle_options_field() { - let url = PostgresUrl::new(Url::parse("postgresql:///localhost:5432?options=--cluster%3Dmy_cluster").unwrap()) - .unwrap(); - - assert_eq!("--cluster=my_cluster", url.options().unwrap()); - } - - #[tokio::test] - async fn test_custom_search_path_pg() { - async fn test_path(schema_name: &str) -> Option { - let mut url = Url::parse(&CONN_STR).unwrap(); - url.query_pairs_mut().append_pair("schema", schema_name); - - let mut pg_url = PostgresUrl::new(url).unwrap(); - pg_url.set_flavour(PostgresFlavour::Postgres); - - let client = PostgreSql::new(pg_url).await.unwrap(); - - let result_set = client.query_raw("SHOW search_path", &[]).await.unwrap(); - let row = result_set.first().unwrap(); - - row[0].typed.to_string() - } - - // Safe - assert_eq!(test_path("hello").await.as_deref(), Some("\"hello\"")); - assert_eq!(test_path("_hello").await.as_deref(), Some("\"_hello\"")); - assert_eq!(test_path("àbracadabra").await.as_deref(), Some("\"àbracadabra\"")); - assert_eq!(test_path("h3ll0").await.as_deref(), Some("\"h3ll0\"")); - assert_eq!(test_path("héllo").await.as_deref(), Some("\"héllo\"")); - assert_eq!(test_path("héll0$").await.as_deref(), Some("\"héll0$\"")); - assert_eq!(test_path("héll_0$").await.as_deref(), Some("\"héll_0$\"")); - - // Not safe - assert_eq!(test_path("Hello").await.as_deref(), Some("\"Hello\"")); - assert_eq!(test_path("hEllo").await.as_deref(), Some("\"hEllo\"")); - assert_eq!(test_path("$hello").await.as_deref(), Some("\"$hello\"")); - assert_eq!(test_path("hello!").await.as_deref(), Some("\"hello!\"")); - assert_eq!(test_path("hello#").await.as_deref(), Some("\"hello#\"")); - assert_eq!(test_path("he llo").await.as_deref(), Some("\"he llo\"")); - assert_eq!(test_path(" hello").await.as_deref(), Some("\" hello\"")); - assert_eq!(test_path("he-llo").await.as_deref(), Some("\"he-llo\"")); - assert_eq!(test_path("hÉllo").await.as_deref(), Some("\"hÉllo\"")); - assert_eq!(test_path("1337").await.as_deref(), Some("\"1337\"")); - assert_eq!(test_path("_HELLO").await.as_deref(), Some("\"_HELLO\"")); - assert_eq!(test_path("HELLO").await.as_deref(), Some("\"HELLO\"")); - assert_eq!(test_path("HELLO$").await.as_deref(), Some("\"HELLO$\"")); - assert_eq!(test_path("ÀBRACADABRA").await.as_deref(), Some("\"ÀBRACADABRA\"")); - - for ident in RESERVED_KEYWORDS { - assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str())); - } - - for ident in RESERVED_TYPE_FUNCTION_NAMES { - assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str())); - } - } - - #[tokio::test] - async fn test_custom_search_path_pg_pgbouncer() { - async fn test_path(schema_name: &str) -> Option { - let mut url = Url::parse(&CONN_STR).unwrap(); - url.query_pairs_mut().append_pair("schema", schema_name); - url.query_pairs_mut().append_pair("pbbouncer", "true"); - - let mut pg_url = PostgresUrl::new(url).unwrap(); - pg_url.set_flavour(PostgresFlavour::Postgres); - - let client = PostgreSql::new(pg_url).await.unwrap(); - - let result_set = client.query_raw("SHOW search_path", &[]).await.unwrap(); - let row = result_set.first().unwrap(); - - row[0].typed.to_string() - } - - // Safe - assert_eq!(test_path("hello").await.as_deref(), Some("\"hello\"")); - assert_eq!(test_path("_hello").await.as_deref(), Some("\"_hello\"")); - assert_eq!(test_path("àbracadabra").await.as_deref(), Some("\"àbracadabra\"")); - assert_eq!(test_path("h3ll0").await.as_deref(), Some("\"h3ll0\"")); - assert_eq!(test_path("héllo").await.as_deref(), Some("\"héllo\"")); - assert_eq!(test_path("héll0$").await.as_deref(), Some("\"héll0$\"")); - assert_eq!(test_path("héll_0$").await.as_deref(), Some("\"héll_0$\"")); - - // Not safe - assert_eq!(test_path("Hello").await.as_deref(), Some("\"Hello\"")); - assert_eq!(test_path("hEllo").await.as_deref(), Some("\"hEllo\"")); - assert_eq!(test_path("$hello").await.as_deref(), Some("\"$hello\"")); - assert_eq!(test_path("hello!").await.as_deref(), Some("\"hello!\"")); - assert_eq!(test_path("hello#").await.as_deref(), Some("\"hello#\"")); - assert_eq!(test_path("he llo").await.as_deref(), Some("\"he llo\"")); - assert_eq!(test_path(" hello").await.as_deref(), Some("\" hello\"")); - assert_eq!(test_path("he-llo").await.as_deref(), Some("\"he-llo\"")); - assert_eq!(test_path("hÉllo").await.as_deref(), Some("\"hÉllo\"")); - assert_eq!(test_path("1337").await.as_deref(), Some("\"1337\"")); - assert_eq!(test_path("_HELLO").await.as_deref(), Some("\"_HELLO\"")); - assert_eq!(test_path("HELLO").await.as_deref(), Some("\"HELLO\"")); - assert_eq!(test_path("HELLO$").await.as_deref(), Some("\"HELLO$\"")); - assert_eq!(test_path("ÀBRACADABRA").await.as_deref(), Some("\"ÀBRACADABRA\"")); - - for ident in RESERVED_KEYWORDS { - assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str())); - } - - for ident in RESERVED_TYPE_FUNCTION_NAMES { - assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str())); - } - } - - #[tokio::test] - async fn test_custom_search_path_crdb() { - async fn test_path(schema_name: &str) -> Option { - let mut url = Url::parse(&CRDB_CONN_STR).unwrap(); - url.query_pairs_mut().append_pair("schema", schema_name); - - let mut pg_url = PostgresUrl::new(url).unwrap(); - pg_url.set_flavour(PostgresFlavour::Cockroach); - - let client = PostgreSql::new(pg_url).await.unwrap(); - - let result_set = client.query_raw("SHOW search_path", &[]).await.unwrap(); - let row = result_set.first().unwrap(); - - row[0].typed.to_string() - } - - // Safe - assert_eq!(test_path("hello").await.as_deref(), Some("hello")); - assert_eq!(test_path("_hello").await.as_deref(), Some("_hello")); - assert_eq!(test_path("àbracadabra").await.as_deref(), Some("àbracadabra")); - assert_eq!(test_path("h3ll0").await.as_deref(), Some("h3ll0")); - assert_eq!(test_path("héllo").await.as_deref(), Some("héllo")); - assert_eq!(test_path("héll0$").await.as_deref(), Some("héll0$")); - assert_eq!(test_path("héll_0$").await.as_deref(), Some("héll_0$")); - - // Not safe - assert_eq!(test_path("Hello").await.as_deref(), Some("\"Hello\"")); - assert_eq!(test_path("hEllo").await.as_deref(), Some("\"hEllo\"")); - assert_eq!(test_path("$hello").await.as_deref(), Some("\"$hello\"")); - assert_eq!(test_path("hello!").await.as_deref(), Some("\"hello!\"")); - assert_eq!(test_path("hello#").await.as_deref(), Some("\"hello#\"")); - assert_eq!(test_path("he llo").await.as_deref(), Some("\"he llo\"")); - assert_eq!(test_path(" hello").await.as_deref(), Some("\" hello\"")); - assert_eq!(test_path("he-llo").await.as_deref(), Some("\"he-llo\"")); - assert_eq!(test_path("hÉllo").await.as_deref(), Some("\"hÉllo\"")); - assert_eq!(test_path("1337").await.as_deref(), Some("\"1337\"")); - assert_eq!(test_path("_HELLO").await.as_deref(), Some("\"_HELLO\"")); - assert_eq!(test_path("HELLO").await.as_deref(), Some("\"HELLO\"")); - assert_eq!(test_path("HELLO$").await.as_deref(), Some("\"HELLO$\"")); - assert_eq!(test_path("ÀBRACADABRA").await.as_deref(), Some("\"ÀBRACADABRA\"")); - - for ident in RESERVED_KEYWORDS { - assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str())); - } - - for ident in RESERVED_TYPE_FUNCTION_NAMES { - assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str())); - } - } - - #[tokio::test] - async fn test_custom_search_path_unknown_pg() { - async fn test_path(schema_name: &str) -> Option { - let mut url = Url::parse(&CONN_STR).unwrap(); - url.query_pairs_mut().append_pair("schema", schema_name); - - let mut pg_url = PostgresUrl::new(url).unwrap(); - pg_url.set_flavour(PostgresFlavour::Unknown); - - let client = PostgreSql::new(pg_url).await.unwrap(); - - let result_set = client.query_raw("SHOW search_path", &[]).await.unwrap(); - let row = result_set.first().unwrap(); - - row[0].typed.to_string() - } - - // Safe - assert_eq!(test_path("hello").await.as_deref(), Some("hello")); - assert_eq!(test_path("_hello").await.as_deref(), Some("_hello")); - assert_eq!(test_path("àbracadabra").await.as_deref(), Some("\"àbracadabra\"")); - assert_eq!(test_path("h3ll0").await.as_deref(), Some("h3ll0")); - assert_eq!(test_path("héllo").await.as_deref(), Some("\"héllo\"")); - assert_eq!(test_path("héll0$").await.as_deref(), Some("\"héll0$\"")); - assert_eq!(test_path("héll_0$").await.as_deref(), Some("\"héll_0$\"")); - - // Not safe - assert_eq!(test_path("Hello").await.as_deref(), Some("\"Hello\"")); - assert_eq!(test_path("hEllo").await.as_deref(), Some("\"hEllo\"")); - assert_eq!(test_path("$hello").await.as_deref(), Some("\"$hello\"")); - assert_eq!(test_path("hello!").await.as_deref(), Some("\"hello!\"")); - assert_eq!(test_path("hello#").await.as_deref(), Some("\"hello#\"")); - assert_eq!(test_path("he llo").await.as_deref(), Some("\"he llo\"")); - assert_eq!(test_path(" hello").await.as_deref(), Some("\" hello\"")); - assert_eq!(test_path("he-llo").await.as_deref(), Some("\"he-llo\"")); - assert_eq!(test_path("hÉllo").await.as_deref(), Some("\"hÉllo\"")); - assert_eq!(test_path("1337").await.as_deref(), Some("\"1337\"")); - assert_eq!(test_path("_HELLO").await.as_deref(), Some("\"_HELLO\"")); - assert_eq!(test_path("HELLO").await.as_deref(), Some("\"HELLO\"")); - assert_eq!(test_path("HELLO$").await.as_deref(), Some("\"HELLO$\"")); - assert_eq!(test_path("ÀBRACADABRA").await.as_deref(), Some("\"ÀBRACADABRA\"")); - - for ident in RESERVED_KEYWORDS { - assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str())); - } - - for ident in RESERVED_TYPE_FUNCTION_NAMES { - assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str())); - } - } - - #[tokio::test] - async fn test_custom_search_path_unknown_crdb() { - async fn test_path(schema_name: &str) -> Option { - let mut url = Url::parse(&CONN_STR).unwrap(); - url.query_pairs_mut().append_pair("schema", schema_name); - - let mut pg_url = PostgresUrl::new(url).unwrap(); - pg_url.set_flavour(PostgresFlavour::Unknown); - - let client = PostgreSql::new(pg_url).await.unwrap(); - - let result_set = client.query_raw("SHOW search_path", &[]).await.unwrap(); - let row = result_set.first().unwrap(); - - row[0].typed.to_string() - } - - // Safe - assert_eq!(test_path("hello").await.as_deref(), Some("hello")); - assert_eq!(test_path("_hello").await.as_deref(), Some("_hello")); - assert_eq!(test_path("àbracadabra").await.as_deref(), Some("\"àbracadabra\"")); - assert_eq!(test_path("h3ll0").await.as_deref(), Some("h3ll0")); - assert_eq!(test_path("héllo").await.as_deref(), Some("\"héllo\"")); - assert_eq!(test_path("héll0$").await.as_deref(), Some("\"héll0$\"")); - assert_eq!(test_path("héll_0$").await.as_deref(), Some("\"héll_0$\"")); - - // Not safe - assert_eq!(test_path("Hello").await.as_deref(), Some("\"Hello\"")); - assert_eq!(test_path("hEllo").await.as_deref(), Some("\"hEllo\"")); - assert_eq!(test_path("$hello").await.as_deref(), Some("\"$hello\"")); - assert_eq!(test_path("hello!").await.as_deref(), Some("\"hello!\"")); - assert_eq!(test_path("hello#").await.as_deref(), Some("\"hello#\"")); - assert_eq!(test_path("he llo").await.as_deref(), Some("\"he llo\"")); - assert_eq!(test_path(" hello").await.as_deref(), Some("\" hello\"")); - assert_eq!(test_path("he-llo").await.as_deref(), Some("\"he-llo\"")); - assert_eq!(test_path("hÉllo").await.as_deref(), Some("\"hÉllo\"")); - assert_eq!(test_path("1337").await.as_deref(), Some("\"1337\"")); - assert_eq!(test_path("_HELLO").await.as_deref(), Some("\"_HELLO\"")); - assert_eq!(test_path("HELLO").await.as_deref(), Some("\"HELLO\"")); - assert_eq!(test_path("HELLO$").await.as_deref(), Some("\"HELLO$\"")); - assert_eq!(test_path("ÀBRACADABRA").await.as_deref(), Some("\"ÀBRACADABRA\"")); - - for ident in RESERVED_KEYWORDS { - assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str())); - } - - for ident in RESERVED_TYPE_FUNCTION_NAMES { - assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str())); - } - } - - #[tokio::test] - async fn should_map_nonexisting_database_error() { - let mut url = Url::parse(&CONN_STR).unwrap(); - url.set_path("/this_does_not_exist"); - - let res = Quaint::new(url.as_str()).await; - - assert!(res.is_err()); - - match res { - Ok(_) => unreachable!(), - Err(e) => match e.kind() { - ErrorKind::DatabaseDoesNotExist { db_name } => { - assert_eq!(Some("3D000"), e.original_code()); - assert_eq!( - Some("database \"this_does_not_exist\" does not exist"), - e.original_message() - ); - assert_eq!(&Name::available("this_does_not_exist"), db_name) - } - kind => panic!("Expected `DatabaseDoesNotExist`, got {:?}", kind), - }, - } - } - - #[tokio::test] - async fn should_map_wrong_credentials_error() { - let mut url = Url::parse(&CONN_STR).unwrap(); - url.set_username("WRONG").unwrap(); - - let res = Quaint::new(url.as_str()).await; - assert!(res.is_err()); - - let err = res.unwrap_err(); - assert!(matches!(err.kind(), ErrorKind::AuthenticationFailed { user } if user == &Name::available("WRONG"))); - } - - #[tokio::test] - async fn should_map_tls_errors() { - let mut url = Url::parse(&CONN_STR).expect("parsing url"); - url.set_query(Some("sslmode=require&sslaccept=strict")); - - let res = Quaint::new(url.as_str()).await; - - assert!(res.is_err()); - - match res { - Ok(_) => unreachable!(), - Err(e) => match e.kind() { - ErrorKind::TlsError { .. } => (), - other => panic!("{:#?}", other), - }, - } - } - - #[tokio::test] - async fn should_map_incorrect_parameters_error() { - let url = Url::parse(&CONN_STR).unwrap(); - let conn = Quaint::new(url.as_str()).await.unwrap(); - - let res = conn.query_raw("SELECT $1", &[Value::int32(1), Value::int32(2)]).await; - - assert!(res.is_err()); - - match res { - Ok(_) => unreachable!(), - Err(e) => match e.kind() { - ErrorKind::IncorrectNumberOfParameters { expected, actual } => { - assert_eq!(1, *expected); - assert_eq!(2, *actual); - } - other => panic!("{:#?}", other), - }, - } - } - - #[test] - fn test_safe_ident() { - // Safe - assert!(is_safe_identifier("hello")); - assert!(is_safe_identifier("_hello")); - assert!(is_safe_identifier("àbracadabra")); - assert!(is_safe_identifier("h3ll0")); - assert!(is_safe_identifier("héllo")); - assert!(is_safe_identifier("héll0$")); - assert!(is_safe_identifier("héll_0$")); - assert!(is_safe_identifier("disconnect_security_must_honor_connect_scope_one2m")); - - // Not safe - assert!(!is_safe_identifier("")); - assert!(!is_safe_identifier("Hello")); - assert!(!is_safe_identifier("hEllo")); - assert!(!is_safe_identifier("$hello")); - assert!(!is_safe_identifier("hello!")); - assert!(!is_safe_identifier("hello#")); - assert!(!is_safe_identifier("he llo")); - assert!(!is_safe_identifier(" hello")); - assert!(!is_safe_identifier("he-llo")); - assert!(!is_safe_identifier("hÉllo")); - assert!(!is_safe_identifier("1337")); - assert!(!is_safe_identifier("_HELLO")); - assert!(!is_safe_identifier("HELLO")); - assert!(!is_safe_identifier("HELLO$")); - assert!(!is_safe_identifier("ÀBRACADABRA")); - - for ident in RESERVED_KEYWORDS { - assert!(!is_safe_identifier(ident)); - } - - for ident in RESERVED_TYPE_FUNCTION_NAMES { - assert!(!is_safe_identifier(ident)); - } - } - - #[test] - fn search_path_pgbouncer_should_be_set_with_query() { - let mut url = Url::parse(&CONN_STR).unwrap(); - url.query_pairs_mut().append_pair("schema", "hello"); - url.query_pairs_mut().append_pair("pgbouncer", "true"); - - let mut pg_url = PostgresUrl::new(url).unwrap(); - pg_url.set_flavour(PostgresFlavour::Postgres); - - let config = pg_url.to_config(); - - // PGBouncer does not support the `search_path` connection parameter. - // When `pgbouncer=true`, config.search_path should be None, - // And the `search_path` should be set via a db query after connection. - assert_eq!(config.get_search_path(), None); - } - - #[test] - fn search_path_pg_should_be_set_with_param() { - let mut url = Url::parse(&CONN_STR).unwrap(); - url.query_pairs_mut().append_pair("schema", "hello"); - - let mut pg_url = PostgresUrl::new(url).unwrap(); - pg_url.set_flavour(PostgresFlavour::Postgres); - - let config = pg_url.to_config(); - - // Postgres supports setting the search_path via a connection parameter. - assert_eq!(config.get_search_path(), Some(&"\"hello\"".to_owned())); - } - - #[test] - fn search_path_crdb_safe_ident_should_be_set_with_param() { - let mut url = Url::parse(&CONN_STR).unwrap(); - url.query_pairs_mut().append_pair("schema", "hello"); - - let mut pg_url = PostgresUrl::new(url).unwrap(); - pg_url.set_flavour(PostgresFlavour::Cockroach); - - let config = pg_url.to_config(); - - // CRDB supports setting the search_path via a connection parameter if the identifier is safe. - assert_eq!(config.get_search_path(), Some(&"hello".to_owned())); - } - - #[test] - fn search_path_crdb_unsafe_ident_should_be_set_with_query() { - let mut url = Url::parse(&CONN_STR).unwrap(); - url.query_pairs_mut().append_pair("schema", "HeLLo"); - - let mut pg_url = PostgresUrl::new(url).unwrap(); - pg_url.set_flavour(PostgresFlavour::Cockroach); - - let config = pg_url.to_config(); - - // CRDB does NOT support setting the search_path via a connection parameter if the identifier is unsafe. - assert_eq!(config.get_search_path(), None); - } -} +#[cfg(feature = "postgresql-connector")] +pub(crate) mod native; diff --git a/quaint/src/connector/postgres/conversion.rs b/quaint/src/connector/postgres/native/conversion.rs similarity index 100% rename from quaint/src/connector/postgres/conversion.rs rename to quaint/src/connector/postgres/native/conversion.rs diff --git a/quaint/src/connector/postgres/conversion/decimal.rs b/quaint/src/connector/postgres/native/conversion/decimal.rs similarity index 100% rename from quaint/src/connector/postgres/conversion/decimal.rs rename to quaint/src/connector/postgres/native/conversion/decimal.rs diff --git a/quaint/src/connector/postgres/native/error.rs b/quaint/src/connector/postgres/native/error.rs new file mode 100644 index 000000000000..ec3b18483746 --- /dev/null +++ b/quaint/src/connector/postgres/native/error.rs @@ -0,0 +1,126 @@ +use tokio_postgres::error::DbError; + +use crate::{ + connector::error::PostgresError, + error::{Error, ErrorKind}, +}; + +impl From<&DbError> for PostgresError { + fn from(value: &DbError) -> Self { + PostgresError { + code: value.code().code().to_string(), + severity: value.severity().to_string(), + message: value.message().to_string(), + detail: value.detail().map(ToString::to_string), + column: value.column().map(ToString::to_string), + hint: value.hint().map(ToString::to_string), + } + } +} + +impl From for Error { + fn from(e: tokio_postgres::error::Error) -> Error { + if e.is_closed() { + return Error::builder(ErrorKind::ConnectionClosed).build(); + } + + if let Some(db_error) = e.as_db_error() { + return PostgresError::from(db_error).into(); + } + + if let Some(tls_error) = try_extracting_tls_error(&e) { + return tls_error; + } + + // Same for IO errors. + if let Some(io_error) = try_extracting_io_error(&e) { + return io_error; + } + + if let Some(uuid_error) = try_extracting_uuid_error(&e) { + return uuid_error; + } + + let reason = format!("{e}"); + let code = e.code().map(|c| c.code()); + + match reason.as_str() { + "error connecting to server: timed out" => { + let mut builder = Error::builder(ErrorKind::ConnectTimeout); + + if let Some(code) = code { + builder.set_original_code(code); + }; + + builder.set_original_message(reason); + builder.build() + } // sigh... + // https://github.com/sfackler/rust-postgres/blob/0c84ed9f8201f4e5b4803199a24afa2c9f3723b2/tokio-postgres/src/connect_tls.rs#L37 + "error performing TLS handshake: server does not support TLS" => { + let mut builder = Error::builder(ErrorKind::TlsError { + message: reason.clone(), + }); + + if let Some(code) = code { + builder.set_original_code(code); + }; + + builder.set_original_message(reason); + builder.build() + } // double sigh + _ => { + let code = code.map(|c| c.to_string()); + let mut builder = Error::builder(ErrorKind::QueryError(e.into())); + + if let Some(code) = code { + builder.set_original_code(code); + }; + + builder.set_original_message(reason); + builder.build() + } + } + } +} + +fn try_extracting_uuid_error(err: &tokio_postgres::error::Error) -> Option { + use std::error::Error as _; + + err.source() + .and_then(|err| err.downcast_ref::()) + .map(|err| ErrorKind::UUIDError(format!("{err}"))) + .map(|kind| Error::builder(kind).build()) +} + +fn try_extracting_tls_error(err: &tokio_postgres::error::Error) -> Option { + use std::error::Error; + + err.source() + .and_then(|err| err.downcast_ref::()) + .map(|err| err.into()) +} + +fn try_extracting_io_error(err: &tokio_postgres::error::Error) -> Option { + use std::error::Error as _; + + err.source() + .and_then(|err| err.downcast_ref::()) + .map(|err| ErrorKind::ConnectionError(Box::new(std::io::Error::new(err.kind(), format!("{err}"))))) + .map(|kind| Error::builder(kind).build()) +} + +impl From for Error { + fn from(e: native_tls::Error) -> Error { + Error::from(&e) + } +} + +impl From<&native_tls::Error> for Error { + fn from(e: &native_tls::Error) -> Error { + let kind = ErrorKind::TlsError { + message: format!("{e}"), + }; + + Error::builder(kind).build() + } +} diff --git a/quaint/src/connector/postgres/native/mod.rs b/quaint/src/connector/postgres/native/mod.rs new file mode 100644 index 000000000000..8f1645ca4123 --- /dev/null +++ b/quaint/src/connector/postgres/native/mod.rs @@ -0,0 +1,1184 @@ +///! Definitions for the Postgres connector. +/// This module is not compatible with wasm32-* targets. +/// This module is only available with the `postgresql-connector` feature. +mod conversion; +mod error; + +use crate::connector::postgres::wasm::common::{Hidden, SslAcceptMode, SslParams}; +pub(crate) use crate::connector::postgres::wasm::common::{PostgresFlavour, PostgresUrl}; +use crate::connector::{timeout, IsolationLevel, Transaction}; + +use crate::{ + ast::{Query, Value}, + connector::{metrics, queryable::*, ResultSet}, + error::{Error, ErrorKind}, + visitor::{self, Visitor}, +}; +use async_trait::async_trait; +use futures::{future::FutureExt, lock::Mutex}; +use lru_cache::LruCache; +use native_tls::{Certificate, Identity, TlsConnector}; +use postgres_native_tls::MakeTlsConnector; +use std::{ + borrow::Borrow, + fmt::{Debug, Display}, + fs, + future::Future, + sync::atomic::{AtomicBool, Ordering}, + time::Duration, +}; +use tokio_postgres::{config::ChannelBinding, Client, Config, Statement}; + +/// The underlying postgres driver. Only available with the `expose-drivers` +/// Cargo feature. +#[cfg(feature = "expose-drivers")] +pub use tokio_postgres; + +struct PostgresClient(Client); + +impl Debug for PostgresClient { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str("PostgresClient") + } +} + +/// A connector interface for the PostgreSQL database. +#[derive(Debug)] +pub struct PostgreSql { + client: PostgresClient, + pg_bouncer: bool, + socket_timeout: Option, + statement_cache: Mutex>, + is_healthy: AtomicBool, +} + +#[derive(Debug)] +struct SslAuth { + certificate: Hidden>, + identity: Hidden>, + ssl_accept_mode: SslAcceptMode, +} + +impl Default for SslAuth { + fn default() -> Self { + Self { + certificate: Hidden(None), + identity: Hidden(None), + ssl_accept_mode: SslAcceptMode::AcceptInvalidCerts, + } + } +} + +impl SslAuth { + fn certificate(&mut self, certificate: Certificate) -> &mut Self { + self.certificate = Hidden(Some(certificate)); + self + } + + fn identity(&mut self, identity: Identity) -> &mut Self { + self.identity = Hidden(Some(identity)); + self + } + + fn accept_mode(&mut self, mode: SslAcceptMode) -> &mut Self { + self.ssl_accept_mode = mode; + self + } +} + +impl SslParams { + async fn into_auth(self) -> crate::Result { + let mut auth = SslAuth::default(); + auth.accept_mode(self.ssl_accept_mode); + + if let Some(ref cert_file) = self.certificate_file { + let cert = fs::read(cert_file).map_err(|err| { + Error::builder(ErrorKind::TlsError { + message: format!("cert file not found ({err})"), + }) + .build() + })?; + + auth.certificate(Certificate::from_pem(&cert)?); + } + + if let Some(ref identity_file) = self.identity_file { + let db = fs::read(identity_file).map_err(|err| { + Error::builder(ErrorKind::TlsError { + message: format!("identity file not found ({err})"), + }) + .build() + })?; + let password = self.identity_password.0.as_deref().unwrap_or(""); + let identity = Identity::from_pkcs12(&db, password)?; + + auth.identity(identity); + } + + Ok(auth) + } +} + +impl PostgresUrl { + pub(crate) fn cache(&self) -> LruCache { + if self.query_params.pg_bouncer { + LruCache::new(0) + } else { + LruCache::new(self.query_params.statement_cache_size) + } + } + + pub fn channel_binding(&self) -> ChannelBinding { + self.query_params.channel_binding + } + + /// On Postgres, we set the SEARCH_PATH and client-encoding through client connection parameters to save a network roundtrip on connection. + /// We can't always do it for CockroachDB because it does not expect quotes for unsafe identifiers (https://github.com/cockroachdb/cockroach/issues/101328), which might change once the issue is fixed. + /// To circumvent that problem, we only set the SEARCH_PATH through client connection parameters for Cockroach when the identifier is safe, so that the quoting does not matter. + fn set_search_path(&self, config: &mut Config) { + // PGBouncer does not support the search_path connection parameter. + // https://www.pgbouncer.org/config.html#ignore_startup_parameters + if self.query_params.pg_bouncer { + return; + } + + if let Some(schema) = &self.query_params.schema { + if self.flavour().is_cockroach() && is_safe_identifier(schema) { + config.search_path(CockroachSearchPath(schema).to_string()); + } + + if self.flavour().is_postgres() { + config.search_path(PostgresSearchPath(schema).to_string()); + } + } + } + + pub(crate) fn to_config(&self) -> Config { + let mut config = Config::new(); + + config.user(self.username().borrow()); + config.password(self.password().borrow() as &str); + config.host(self.host()); + config.port(self.port()); + config.dbname(self.dbname()); + config.pgbouncer_mode(self.query_params.pg_bouncer); + + if let Some(options) = self.options() { + config.options(options); + } + + if let Some(application_name) = self.application_name() { + config.application_name(application_name); + } + + if let Some(connect_timeout) = self.query_params.connect_timeout { + config.connect_timeout(connect_timeout); + } + + self.set_search_path(&mut config); + + config.ssl_mode(self.query_params.ssl_mode); + + config.channel_binding(self.query_params.channel_binding); + + config + } +} + +impl PostgreSql { + /// Create a new connection to the database. + pub async fn new(url: PostgresUrl) -> crate::Result { + let config = url.to_config(); + + let mut tls_builder = TlsConnector::builder(); + + { + let ssl_params = url.ssl_params(); + let auth = ssl_params.to_owned().into_auth().await?; + + if let Some(certificate) = auth.certificate.0 { + tls_builder.add_root_certificate(certificate); + } + + tls_builder.danger_accept_invalid_certs(auth.ssl_accept_mode == SslAcceptMode::AcceptInvalidCerts); + + if let Some(identity) = auth.identity.0 { + tls_builder.identity(identity); + } + } + + let tls = MakeTlsConnector::new(tls_builder.build()?); + let (client, conn) = timeout::connect(url.connect_timeout(), config.connect(tls)).await?; + + tokio::spawn(conn.map(|r| match r { + Ok(_) => (), + Err(e) => { + tracing::error!("Error in PostgreSQL connection: {:?}", e); + } + })); + + // On Postgres, we set the SEARCH_PATH and client-encoding through client connection parameters to save a network roundtrip on connection. + // We can't always do it for CockroachDB because it does not expect quotes for unsafe identifiers (https://github.com/cockroachdb/cockroach/issues/101328), which might change once the issue is fixed. + // To circumvent that problem, we only set the SEARCH_PATH through client connection parameters for Cockroach when the identifier is safe, so that the quoting does not matter. + // Finally, to ensure backward compatibility, we keep sending a database query in case the flavour is set to Unknown. + if let Some(schema) = &url.query_params.schema { + // PGBouncer does not support the search_path connection parameter. + // https://www.pgbouncer.org/config.html#ignore_startup_parameters + if url.query_params.pg_bouncer + || url.flavour().is_unknown() + || (url.flavour().is_cockroach() && !is_safe_identifier(schema)) + { + let session_variables = format!( + r##"{set_search_path}"##, + set_search_path = SetSearchPath(url.query_params.schema.as_deref()) + ); + + client.simple_query(session_variables.as_str()).await?; + } + } + + Ok(Self { + client: PostgresClient(client), + socket_timeout: url.query_params.socket_timeout, + pg_bouncer: url.query_params.pg_bouncer, + statement_cache: Mutex::new(url.cache()), + is_healthy: AtomicBool::new(true), + }) + } + + /// The underlying tokio_postgres::Client. Only available with the + /// `expose-drivers` Cargo feature. This is a lower level API when you need + /// to get into database specific features. + #[cfg(feature = "expose-drivers")] + pub fn client(&self) -> &tokio_postgres::Client { + &self.client.0 + } + + async fn fetch_cached(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { + let mut cache = self.statement_cache.lock().await; + let capacity = cache.capacity(); + let stored = cache.len(); + + match cache.get_mut(sql) { + Some(stmt) => { + tracing::trace!( + message = "CACHE HIT!", + query = sql, + capacity = capacity, + stored = stored, + ); + + Ok(stmt.clone()) // arc'd + } + None => { + tracing::trace!( + message = "CACHE MISS!", + query = sql, + capacity = capacity, + stored = stored, + ); + + let param_types = conversion::params_to_types(params); + let stmt = self.perform_io(self.client.0.prepare_typed(sql, ¶m_types)).await?; + + cache.insert(sql.to_string(), stmt.clone()); + + Ok(stmt) + } + } + } + + async fn perform_io(&self, fut: F) -> crate::Result + where + F: Future>, + { + match timeout::socket(self.socket_timeout, fut).await { + Err(e) if e.is_closed() => { + self.is_healthy.store(false, Ordering::SeqCst); + Err(e) + } + res => res, + } + } + + fn check_bind_variables_len(&self, params: &[Value<'_>]) -> crate::Result<()> { + if params.len() > i16::MAX as usize { + // tokio_postgres would return an error here. Let's avoid calling the driver + // and return an error early. + let kind = ErrorKind::QueryInvalidInput(format!( + "too many bind variables in prepared statement, expected maximum of {}, received {}", + i16::MAX, + params.len() + )); + Err(Error::builder(kind).build()) + } else { + Ok(()) + } + } +} + +// A SearchPath connection parameter (Display-impl) for connection initialization. +struct CockroachSearchPath<'a>(&'a str); + +impl Display for CockroachSearchPath<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(self.0) + } +} + +// A SearchPath connection parameter (Display-impl) for connection initialization. +struct PostgresSearchPath<'a>(&'a str); + +impl Display for PostgresSearchPath<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str("\"")?; + f.write_str(self.0)?; + f.write_str("\"")?; + + Ok(()) + } +} + +// A SetSearchPath statement (Display-impl) for connection initialization. +struct SetSearchPath<'a>(Option<&'a str>); + +impl Display for SetSearchPath<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + if let Some(schema) = self.0 { + f.write_str("SET search_path = \"")?; + f.write_str(schema)?; + f.write_str("\";\n")?; + } + + Ok(()) + } +} + +impl_default_TransactionCapable!(PostgreSql); + +#[async_trait] +impl Queryable for PostgreSql { + async fn query(&self, q: Query<'_>) -> crate::Result { + let (sql, params) = visitor::Postgres::build(q)?; + + self.query_raw(sql.as_str(), ¶ms[..]).await + } + + async fn query_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { + self.check_bind_variables_len(params)?; + + metrics::query("postgres.query_raw", sql, params, move || async move { + let stmt = self.fetch_cached(sql, &[]).await?; + + if stmt.params().len() != params.len() { + let kind = ErrorKind::IncorrectNumberOfParameters { + expected: stmt.params().len(), + actual: params.len(), + }; + + return Err(Error::builder(kind).build()); + } + + let rows = self + .perform_io(self.client.0.query(&stmt, conversion::conv_params(params).as_slice())) + .await?; + + let mut result = ResultSet::new(stmt.to_column_names(), Vec::new()); + + for row in rows { + result.rows.push(row.get_result_row()?); + } + + Ok(result) + }) + .await + } + + async fn query_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { + self.check_bind_variables_len(params)?; + + metrics::query("postgres.query_raw", sql, params, move || async move { + let stmt = self.fetch_cached(sql, params).await?; + + if stmt.params().len() != params.len() { + let kind = ErrorKind::IncorrectNumberOfParameters { + expected: stmt.params().len(), + actual: params.len(), + }; + + return Err(Error::builder(kind).build()); + } + + let rows = self + .perform_io(self.client.0.query(&stmt, conversion::conv_params(params).as_slice())) + .await?; + + let mut result = ResultSet::new(stmt.to_column_names(), Vec::new()); + + for row in rows { + result.rows.push(row.get_result_row()?); + } + + Ok(result) + }) + .await + } + + async fn execute(&self, q: Query<'_>) -> crate::Result { + let (sql, params) = visitor::Postgres::build(q)?; + + self.execute_raw(sql.as_str(), ¶ms[..]).await + } + + async fn execute_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { + self.check_bind_variables_len(params)?; + + metrics::query("postgres.execute_raw", sql, params, move || async move { + let stmt = self.fetch_cached(sql, &[]).await?; + + if stmt.params().len() != params.len() { + let kind = ErrorKind::IncorrectNumberOfParameters { + expected: stmt.params().len(), + actual: params.len(), + }; + + return Err(Error::builder(kind).build()); + } + + let changes = self + .perform_io(self.client.0.execute(&stmt, conversion::conv_params(params).as_slice())) + .await?; + + Ok(changes) + }) + .await + } + + async fn execute_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { + self.check_bind_variables_len(params)?; + + metrics::query("postgres.execute_raw", sql, params, move || async move { + let stmt = self.fetch_cached(sql, params).await?; + + if stmt.params().len() != params.len() { + let kind = ErrorKind::IncorrectNumberOfParameters { + expected: stmt.params().len(), + actual: params.len(), + }; + + return Err(Error::builder(kind).build()); + } + + let changes = self + .perform_io(self.client.0.execute(&stmt, conversion::conv_params(params).as_slice())) + .await?; + + Ok(changes) + }) + .await + } + + async fn raw_cmd(&self, cmd: &str) -> crate::Result<()> { + metrics::query("postgres.raw_cmd", cmd, &[], move || async move { + self.perform_io(self.client.0.simple_query(cmd)).await?; + Ok(()) + }) + .await + } + + async fn version(&self) -> crate::Result> { + let query = r#"SELECT version()"#; + let rows = self.query_raw(query, &[]).await?; + + let version_string = rows + .get(0) + .and_then(|row| row.get("version").and_then(|version| version.to_string())); + + Ok(version_string) + } + + fn is_healthy(&self) -> bool { + self.is_healthy.load(Ordering::SeqCst) + } + + async fn server_reset_query(&self, tx: &dyn Transaction) -> crate::Result<()> { + if self.pg_bouncer { + tx.raw_cmd("DEALLOCATE ALL").await + } else { + Ok(()) + } + } + + async fn set_tx_isolation_level(&self, isolation_level: IsolationLevel) -> crate::Result<()> { + if matches!(isolation_level, IsolationLevel::Snapshot) { + return Err(Error::builder(ErrorKind::invalid_isolation_level(&isolation_level)).build()); + } + + self.raw_cmd(&format!("SET TRANSACTION ISOLATION LEVEL {isolation_level}")) + .await?; + + Ok(()) + } + + fn requires_isolation_first(&self) -> bool { + false + } +} + +/// Sorted list of CockroachDB's reserved keywords. +/// Taken from https://www.cockroachlabs.com/docs/stable/keywords-and-identifiers.html#keywords +const RESERVED_KEYWORDS: [&str; 79] = [ + "all", + "analyse", + "analyze", + "and", + "any", + "array", + "as", + "asc", + "asymmetric", + "both", + "case", + "cast", + "check", + "collate", + "column", + "concurrently", + "constraint", + "create", + "current_catalog", + "current_date", + "current_role", + "current_schema", + "current_time", + "current_timestamp", + "current_user", + "default", + "deferrable", + "desc", + "distinct", + "do", + "else", + "end", + "except", + "false", + "fetch", + "for", + "foreign", + "from", + "grant", + "group", + "having", + "in", + "initially", + "intersect", + "into", + "lateral", + "leading", + "limit", + "localtime", + "localtimestamp", + "not", + "null", + "offset", + "on", + "only", + "or", + "order", + "placing", + "primary", + "references", + "returning", + "select", + "session_user", + "some", + "symmetric", + "table", + "then", + "to", + "trailing", + "true", + "union", + "unique", + "user", + "using", + "variadic", + "when", + "where", + "window", + "with", +]; + +/// Sorted list of CockroachDB's reserved type function names. +/// Taken from https://www.cockroachlabs.com/docs/stable/keywords-and-identifiers.html#keywords +const RESERVED_TYPE_FUNCTION_NAMES: [&str; 18] = [ + "authorization", + "collation", + "cross", + "full", + "ilike", + "inner", + "is", + "isnull", + "join", + "left", + "like", + "natural", + "none", + "notnull", + "outer", + "overlaps", + "right", + "similar", +]; + +/// Returns true if a Postgres identifier is considered "safe". +/// +/// In this context, "safe" means that the value of an identifier would be the same quoted and unquoted or that it's not part of reserved keywords. In other words, that it does _not_ need to be quoted. +/// +/// Spec can be found here: https://www.postgresql.org/docs/current/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS +/// or here: https://www.cockroachlabs.com/docs/stable/keywords-and-identifiers.html#rules-for-identifiers +fn is_safe_identifier(ident: &str) -> bool { + if ident.is_empty() { + return false; + } + + // 1. Not equal any SQL keyword unless the keyword is accepted by the element's syntax. For example, name accepts Unreserved or Column Name keywords. + if RESERVED_KEYWORDS.binary_search(&ident).is_ok() || RESERVED_TYPE_FUNCTION_NAMES.binary_search(&ident).is_ok() { + return false; + } + + let mut chars = ident.chars(); + + let first = chars.next().unwrap(); + + // 2. SQL identifiers must begin with a letter (a-z, but also letters with diacritical marks and non-Latin letters) or an underscore (_). + if (!first.is_alphabetic() || !first.is_lowercase()) && first != '_' { + return false; + } + + for c in chars { + // 3. Subsequent characters in an identifier can be letters, underscores, digits (0-9), or dollar signs ($). + if (!c.is_alphabetic() || !c.is_lowercase()) && c != '_' && !c.is_ascii_digit() && c != '$' { + return false; + } + } + + true +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::tests::test_api::postgres::CONN_STR; + use crate::tests::test_api::CRDB_CONN_STR; + use crate::{connector::Queryable, error::*, single::Quaint}; + use url::Url; + + #[test] + fn should_parse_socket_url() { + let url = PostgresUrl::new(Url::parse("postgresql:///dbname?host=/var/run/psql.sock").unwrap()).unwrap(); + assert_eq!("dbname", url.dbname()); + assert_eq!("/var/run/psql.sock", url.host()); + } + + #[test] + fn should_parse_escaped_url() { + let url = PostgresUrl::new(Url::parse("postgresql:///dbname?host=%2Fvar%2Frun%2Fpostgresql").unwrap()).unwrap(); + assert_eq!("dbname", url.dbname()); + assert_eq!("/var/run/postgresql", url.host()); + } + + #[test] + fn should_allow_changing_of_cache_size() { + let url = + PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo?statement_cache_size=420").unwrap()).unwrap(); + assert_eq!(420, url.cache().capacity()); + } + + #[test] + fn should_have_default_cache_size() { + let url = PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo").unwrap()).unwrap(); + assert_eq!(100, url.cache().capacity()); + } + + #[test] + fn should_have_application_name() { + let url = + PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo?application_name=test").unwrap()).unwrap(); + assert_eq!(Some("test"), url.application_name()); + } + + #[test] + fn should_have_channel_binding() { + let url = + PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo?channel_binding=require").unwrap()).unwrap(); + assert_eq!(ChannelBinding::Require, url.channel_binding()); + } + + #[test] + fn should_have_default_channel_binding() { + let url = + PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo?channel_binding=invalid").unwrap()).unwrap(); + assert_eq!(ChannelBinding::Prefer, url.channel_binding()); + + let url = PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo").unwrap()).unwrap(); + assert_eq!(ChannelBinding::Prefer, url.channel_binding()); + } + + #[test] + fn should_not_enable_caching_with_pgbouncer() { + let url = PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo?pgbouncer=true").unwrap()).unwrap(); + assert_eq!(0, url.cache().capacity()); + } + + #[test] + fn should_parse_default_host() { + let url = PostgresUrl::new(Url::parse("postgresql:///dbname").unwrap()).unwrap(); + assert_eq!("dbname", url.dbname()); + assert_eq!("localhost", url.host()); + } + + #[test] + fn should_parse_ipv6_host() { + let url = PostgresUrl::new(Url::parse("postgresql://[2001:db8:1234::ffff]:5432/dbname").unwrap()).unwrap(); + assert_eq!("2001:db8:1234::ffff", url.host()); + } + + #[test] + fn should_handle_options_field() { + let url = PostgresUrl::new(Url::parse("postgresql:///localhost:5432?options=--cluster%3Dmy_cluster").unwrap()) + .unwrap(); + + assert_eq!("--cluster=my_cluster", url.options().unwrap()); + } + + #[tokio::test] + async fn test_custom_search_path_pg() { + async fn test_path(schema_name: &str) -> Option { + let mut url = Url::parse(&CONN_STR).unwrap(); + url.query_pairs_mut().append_pair("schema", schema_name); + + let mut pg_url = PostgresUrl::new(url).unwrap(); + pg_url.set_flavour(PostgresFlavour::Postgres); + + let client = PostgreSql::new(pg_url).await.unwrap(); + + let result_set = client.query_raw("SHOW search_path", &[]).await.unwrap(); + let row = result_set.first().unwrap(); + + row[0].typed.to_string() + } + + // Safe + assert_eq!(test_path("hello").await.as_deref(), Some("\"hello\"")); + assert_eq!(test_path("_hello").await.as_deref(), Some("\"_hello\"")); + assert_eq!(test_path("àbracadabra").await.as_deref(), Some("\"àbracadabra\"")); + assert_eq!(test_path("h3ll0").await.as_deref(), Some("\"h3ll0\"")); + assert_eq!(test_path("héllo").await.as_deref(), Some("\"héllo\"")); + assert_eq!(test_path("héll0$").await.as_deref(), Some("\"héll0$\"")); + assert_eq!(test_path("héll_0$").await.as_deref(), Some("\"héll_0$\"")); + + // Not safe + assert_eq!(test_path("Hello").await.as_deref(), Some("\"Hello\"")); + assert_eq!(test_path("hEllo").await.as_deref(), Some("\"hEllo\"")); + assert_eq!(test_path("$hello").await.as_deref(), Some("\"$hello\"")); + assert_eq!(test_path("hello!").await.as_deref(), Some("\"hello!\"")); + assert_eq!(test_path("hello#").await.as_deref(), Some("\"hello#\"")); + assert_eq!(test_path("he llo").await.as_deref(), Some("\"he llo\"")); + assert_eq!(test_path(" hello").await.as_deref(), Some("\" hello\"")); + assert_eq!(test_path("he-llo").await.as_deref(), Some("\"he-llo\"")); + assert_eq!(test_path("hÉllo").await.as_deref(), Some("\"hÉllo\"")); + assert_eq!(test_path("1337").await.as_deref(), Some("\"1337\"")); + assert_eq!(test_path("_HELLO").await.as_deref(), Some("\"_HELLO\"")); + assert_eq!(test_path("HELLO").await.as_deref(), Some("\"HELLO\"")); + assert_eq!(test_path("HELLO$").await.as_deref(), Some("\"HELLO$\"")); + assert_eq!(test_path("ÀBRACADABRA").await.as_deref(), Some("\"ÀBRACADABRA\"")); + + for ident in RESERVED_KEYWORDS { + assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str())); + } + + for ident in RESERVED_TYPE_FUNCTION_NAMES { + assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str())); + } + } + + #[tokio::test] + async fn test_custom_search_path_pg_pgbouncer() { + async fn test_path(schema_name: &str) -> Option { + let mut url = Url::parse(&CONN_STR).unwrap(); + url.query_pairs_mut().append_pair("schema", schema_name); + url.query_pairs_mut().append_pair("pbbouncer", "true"); + + let mut pg_url = PostgresUrl::new(url).unwrap(); + pg_url.set_flavour(PostgresFlavour::Postgres); + + let client = PostgreSql::new(pg_url).await.unwrap(); + + let result_set = client.query_raw("SHOW search_path", &[]).await.unwrap(); + let row = result_set.first().unwrap(); + + row[0].typed.to_string() + } + + // Safe + assert_eq!(test_path("hello").await.as_deref(), Some("\"hello\"")); + assert_eq!(test_path("_hello").await.as_deref(), Some("\"_hello\"")); + assert_eq!(test_path("àbracadabra").await.as_deref(), Some("\"àbracadabra\"")); + assert_eq!(test_path("h3ll0").await.as_deref(), Some("\"h3ll0\"")); + assert_eq!(test_path("héllo").await.as_deref(), Some("\"héllo\"")); + assert_eq!(test_path("héll0$").await.as_deref(), Some("\"héll0$\"")); + assert_eq!(test_path("héll_0$").await.as_deref(), Some("\"héll_0$\"")); + + // Not safe + assert_eq!(test_path("Hello").await.as_deref(), Some("\"Hello\"")); + assert_eq!(test_path("hEllo").await.as_deref(), Some("\"hEllo\"")); + assert_eq!(test_path("$hello").await.as_deref(), Some("\"$hello\"")); + assert_eq!(test_path("hello!").await.as_deref(), Some("\"hello!\"")); + assert_eq!(test_path("hello#").await.as_deref(), Some("\"hello#\"")); + assert_eq!(test_path("he llo").await.as_deref(), Some("\"he llo\"")); + assert_eq!(test_path(" hello").await.as_deref(), Some("\" hello\"")); + assert_eq!(test_path("he-llo").await.as_deref(), Some("\"he-llo\"")); + assert_eq!(test_path("hÉllo").await.as_deref(), Some("\"hÉllo\"")); + assert_eq!(test_path("1337").await.as_deref(), Some("\"1337\"")); + assert_eq!(test_path("_HELLO").await.as_deref(), Some("\"_HELLO\"")); + assert_eq!(test_path("HELLO").await.as_deref(), Some("\"HELLO\"")); + assert_eq!(test_path("HELLO$").await.as_deref(), Some("\"HELLO$\"")); + assert_eq!(test_path("ÀBRACADABRA").await.as_deref(), Some("\"ÀBRACADABRA\"")); + + for ident in RESERVED_KEYWORDS { + assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str())); + } + + for ident in RESERVED_TYPE_FUNCTION_NAMES { + assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str())); + } + } + + #[tokio::test] + async fn test_custom_search_path_crdb() { + async fn test_path(schema_name: &str) -> Option { + let mut url = Url::parse(&CRDB_CONN_STR).unwrap(); + url.query_pairs_mut().append_pair("schema", schema_name); + + let mut pg_url = PostgresUrl::new(url).unwrap(); + pg_url.set_flavour(PostgresFlavour::Cockroach); + + let client = PostgreSql::new(pg_url).await.unwrap(); + + let result_set = client.query_raw("SHOW search_path", &[]).await.unwrap(); + let row = result_set.first().unwrap(); + + row[0].typed.to_string() + } + + // Safe + assert_eq!(test_path("hello").await.as_deref(), Some("hello")); + assert_eq!(test_path("_hello").await.as_deref(), Some("_hello")); + assert_eq!(test_path("àbracadabra").await.as_deref(), Some("àbracadabra")); + assert_eq!(test_path("h3ll0").await.as_deref(), Some("h3ll0")); + assert_eq!(test_path("héllo").await.as_deref(), Some("héllo")); + assert_eq!(test_path("héll0$").await.as_deref(), Some("héll0$")); + assert_eq!(test_path("héll_0$").await.as_deref(), Some("héll_0$")); + + // Not safe + assert_eq!(test_path("Hello").await.as_deref(), Some("\"Hello\"")); + assert_eq!(test_path("hEllo").await.as_deref(), Some("\"hEllo\"")); + assert_eq!(test_path("$hello").await.as_deref(), Some("\"$hello\"")); + assert_eq!(test_path("hello!").await.as_deref(), Some("\"hello!\"")); + assert_eq!(test_path("hello#").await.as_deref(), Some("\"hello#\"")); + assert_eq!(test_path("he llo").await.as_deref(), Some("\"he llo\"")); + assert_eq!(test_path(" hello").await.as_deref(), Some("\" hello\"")); + assert_eq!(test_path("he-llo").await.as_deref(), Some("\"he-llo\"")); + assert_eq!(test_path("hÉllo").await.as_deref(), Some("\"hÉllo\"")); + assert_eq!(test_path("1337").await.as_deref(), Some("\"1337\"")); + assert_eq!(test_path("_HELLO").await.as_deref(), Some("\"_HELLO\"")); + assert_eq!(test_path("HELLO").await.as_deref(), Some("\"HELLO\"")); + assert_eq!(test_path("HELLO$").await.as_deref(), Some("\"HELLO$\"")); + assert_eq!(test_path("ÀBRACADABRA").await.as_deref(), Some("\"ÀBRACADABRA\"")); + + for ident in RESERVED_KEYWORDS { + assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str())); + } + + for ident in RESERVED_TYPE_FUNCTION_NAMES { + assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str())); + } + } + + #[tokio::test] + async fn test_custom_search_path_unknown_pg() { + async fn test_path(schema_name: &str) -> Option { + let mut url = Url::parse(&CONN_STR).unwrap(); + url.query_pairs_mut().append_pair("schema", schema_name); + + let mut pg_url = PostgresUrl::new(url).unwrap(); + pg_url.set_flavour(PostgresFlavour::Unknown); + + let client = PostgreSql::new(pg_url).await.unwrap(); + + let result_set = client.query_raw("SHOW search_path", &[]).await.unwrap(); + let row = result_set.first().unwrap(); + + row[0].typed.to_string() + } + + // Safe + assert_eq!(test_path("hello").await.as_deref(), Some("hello")); + assert_eq!(test_path("_hello").await.as_deref(), Some("_hello")); + assert_eq!(test_path("àbracadabra").await.as_deref(), Some("\"àbracadabra\"")); + assert_eq!(test_path("h3ll0").await.as_deref(), Some("h3ll0")); + assert_eq!(test_path("héllo").await.as_deref(), Some("\"héllo\"")); + assert_eq!(test_path("héll0$").await.as_deref(), Some("\"héll0$\"")); + assert_eq!(test_path("héll_0$").await.as_deref(), Some("\"héll_0$\"")); + + // Not safe + assert_eq!(test_path("Hello").await.as_deref(), Some("\"Hello\"")); + assert_eq!(test_path("hEllo").await.as_deref(), Some("\"hEllo\"")); + assert_eq!(test_path("$hello").await.as_deref(), Some("\"$hello\"")); + assert_eq!(test_path("hello!").await.as_deref(), Some("\"hello!\"")); + assert_eq!(test_path("hello#").await.as_deref(), Some("\"hello#\"")); + assert_eq!(test_path("he llo").await.as_deref(), Some("\"he llo\"")); + assert_eq!(test_path(" hello").await.as_deref(), Some("\" hello\"")); + assert_eq!(test_path("he-llo").await.as_deref(), Some("\"he-llo\"")); + assert_eq!(test_path("hÉllo").await.as_deref(), Some("\"hÉllo\"")); + assert_eq!(test_path("1337").await.as_deref(), Some("\"1337\"")); + assert_eq!(test_path("_HELLO").await.as_deref(), Some("\"_HELLO\"")); + assert_eq!(test_path("HELLO").await.as_deref(), Some("\"HELLO\"")); + assert_eq!(test_path("HELLO$").await.as_deref(), Some("\"HELLO$\"")); + assert_eq!(test_path("ÀBRACADABRA").await.as_deref(), Some("\"ÀBRACADABRA\"")); + + for ident in RESERVED_KEYWORDS { + assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str())); + } + + for ident in RESERVED_TYPE_FUNCTION_NAMES { + assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str())); + } + } + + #[tokio::test] + async fn test_custom_search_path_unknown_crdb() { + async fn test_path(schema_name: &str) -> Option { + let mut url = Url::parse(&CONN_STR).unwrap(); + url.query_pairs_mut().append_pair("schema", schema_name); + + let mut pg_url = PostgresUrl::new(url).unwrap(); + pg_url.set_flavour(PostgresFlavour::Unknown); + + let client = PostgreSql::new(pg_url).await.unwrap(); + + let result_set = client.query_raw("SHOW search_path", &[]).await.unwrap(); + let row = result_set.first().unwrap(); + + row[0].typed.to_string() + } + + // Safe + assert_eq!(test_path("hello").await.as_deref(), Some("hello")); + assert_eq!(test_path("_hello").await.as_deref(), Some("_hello")); + assert_eq!(test_path("àbracadabra").await.as_deref(), Some("\"àbracadabra\"")); + assert_eq!(test_path("h3ll0").await.as_deref(), Some("h3ll0")); + assert_eq!(test_path("héllo").await.as_deref(), Some("\"héllo\"")); + assert_eq!(test_path("héll0$").await.as_deref(), Some("\"héll0$\"")); + assert_eq!(test_path("héll_0$").await.as_deref(), Some("\"héll_0$\"")); + + // Not safe + assert_eq!(test_path("Hello").await.as_deref(), Some("\"Hello\"")); + assert_eq!(test_path("hEllo").await.as_deref(), Some("\"hEllo\"")); + assert_eq!(test_path("$hello").await.as_deref(), Some("\"$hello\"")); + assert_eq!(test_path("hello!").await.as_deref(), Some("\"hello!\"")); + assert_eq!(test_path("hello#").await.as_deref(), Some("\"hello#\"")); + assert_eq!(test_path("he llo").await.as_deref(), Some("\"he llo\"")); + assert_eq!(test_path(" hello").await.as_deref(), Some("\" hello\"")); + assert_eq!(test_path("he-llo").await.as_deref(), Some("\"he-llo\"")); + assert_eq!(test_path("hÉllo").await.as_deref(), Some("\"hÉllo\"")); + assert_eq!(test_path("1337").await.as_deref(), Some("\"1337\"")); + assert_eq!(test_path("_HELLO").await.as_deref(), Some("\"_HELLO\"")); + assert_eq!(test_path("HELLO").await.as_deref(), Some("\"HELLO\"")); + assert_eq!(test_path("HELLO$").await.as_deref(), Some("\"HELLO$\"")); + assert_eq!(test_path("ÀBRACADABRA").await.as_deref(), Some("\"ÀBRACADABRA\"")); + + for ident in RESERVED_KEYWORDS { + assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str())); + } + + for ident in RESERVED_TYPE_FUNCTION_NAMES { + assert_eq!(test_path(ident).await.as_deref(), Some(format!("\"{ident}\"").as_str())); + } + } + + #[tokio::test] + async fn should_map_nonexisting_database_error() { + let mut url = Url::parse(&CONN_STR).unwrap(); + url.set_path("/this_does_not_exist"); + + let res = Quaint::new(url.as_str()).await; + + assert!(res.is_err()); + + match res { + Ok(_) => unreachable!(), + Err(e) => match e.kind() { + ErrorKind::DatabaseDoesNotExist { db_name } => { + assert_eq!(Some("3D000"), e.original_code()); + assert_eq!( + Some("database \"this_does_not_exist\" does not exist"), + e.original_message() + ); + assert_eq!(&Name::available("this_does_not_exist"), db_name) + } + kind => panic!("Expected `DatabaseDoesNotExist`, got {:?}", kind), + }, + } + } + + #[tokio::test] + async fn should_map_wrong_credentials_error() { + let mut url = Url::parse(&CONN_STR).unwrap(); + url.set_username("WRONG").unwrap(); + + let res = Quaint::new(url.as_str()).await; + assert!(res.is_err()); + + let err = res.unwrap_err(); + assert!(matches!(err.kind(), ErrorKind::AuthenticationFailed { user } if user == &Name::available("WRONG"))); + } + + #[tokio::test] + async fn should_map_tls_errors() { + let mut url = Url::parse(&CONN_STR).expect("parsing url"); + url.set_query(Some("sslmode=require&sslaccept=strict")); + + let res = Quaint::new(url.as_str()).await; + + assert!(res.is_err()); + + match res { + Ok(_) => unreachable!(), + Err(e) => match e.kind() { + ErrorKind::TlsError { .. } => (), + other => panic!("{:#?}", other), + }, + } + } + + #[tokio::test] + async fn should_map_incorrect_parameters_error() { + let url = Url::parse(&CONN_STR).unwrap(); + let conn = Quaint::new(url.as_str()).await.unwrap(); + + let res = conn.query_raw("SELECT $1", &[Value::int32(1), Value::int32(2)]).await; + + assert!(res.is_err()); + + match res { + Ok(_) => unreachable!(), + Err(e) => match e.kind() { + ErrorKind::IncorrectNumberOfParameters { expected, actual } => { + assert_eq!(1, *expected); + assert_eq!(2, *actual); + } + other => panic!("{:#?}", other), + }, + } + } + + #[test] + fn test_safe_ident() { + // Safe + assert!(is_safe_identifier("hello")); + assert!(is_safe_identifier("_hello")); + assert!(is_safe_identifier("àbracadabra")); + assert!(is_safe_identifier("h3ll0")); + assert!(is_safe_identifier("héllo")); + assert!(is_safe_identifier("héll0$")); + assert!(is_safe_identifier("héll_0$")); + assert!(is_safe_identifier("disconnect_security_must_honor_connect_scope_one2m")); + + // Not safe + assert!(!is_safe_identifier("")); + assert!(!is_safe_identifier("Hello")); + assert!(!is_safe_identifier("hEllo")); + assert!(!is_safe_identifier("$hello")); + assert!(!is_safe_identifier("hello!")); + assert!(!is_safe_identifier("hello#")); + assert!(!is_safe_identifier("he llo")); + assert!(!is_safe_identifier(" hello")); + assert!(!is_safe_identifier("he-llo")); + assert!(!is_safe_identifier("hÉllo")); + assert!(!is_safe_identifier("1337")); + assert!(!is_safe_identifier("_HELLO")); + assert!(!is_safe_identifier("HELLO")); + assert!(!is_safe_identifier("HELLO$")); + assert!(!is_safe_identifier("ÀBRACADABRA")); + + for ident in RESERVED_KEYWORDS { + assert!(!is_safe_identifier(ident)); + } + + for ident in RESERVED_TYPE_FUNCTION_NAMES { + assert!(!is_safe_identifier(ident)); + } + } + + #[test] + fn search_path_pgbouncer_should_be_set_with_query() { + let mut url = Url::parse(&CONN_STR).unwrap(); + url.query_pairs_mut().append_pair("schema", "hello"); + url.query_pairs_mut().append_pair("pgbouncer", "true"); + + let mut pg_url = PostgresUrl::new(url).unwrap(); + pg_url.set_flavour(PostgresFlavour::Postgres); + + let config = pg_url.to_config(); + + // PGBouncer does not support the `search_path` connection parameter. + // When `pgbouncer=true`, config.search_path should be None, + // And the `search_path` should be set via a db query after connection. + assert_eq!(config.get_search_path(), None); + } + + #[test] + fn search_path_pg_should_be_set_with_param() { + let mut url = Url::parse(&CONN_STR).unwrap(); + url.query_pairs_mut().append_pair("schema", "hello"); + + let mut pg_url = PostgresUrl::new(url).unwrap(); + pg_url.set_flavour(PostgresFlavour::Postgres); + + let config = pg_url.to_config(); + + // Postgres supports setting the search_path via a connection parameter. + assert_eq!(config.get_search_path(), Some(&"\"hello\"".to_owned())); + } + + #[test] + fn search_path_crdb_safe_ident_should_be_set_with_param() { + let mut url = Url::parse(&CONN_STR).unwrap(); + url.query_pairs_mut().append_pair("schema", "hello"); + + let mut pg_url = PostgresUrl::new(url).unwrap(); + pg_url.set_flavour(PostgresFlavour::Cockroach); + + let config = pg_url.to_config(); + + // CRDB supports setting the search_path via a connection parameter if the identifier is safe. + assert_eq!(config.get_search_path(), Some(&"hello".to_owned())); + } + + #[test] + fn search_path_crdb_unsafe_ident_should_be_set_with_query() { + let mut url = Url::parse(&CONN_STR).unwrap(); + url.query_pairs_mut().append_pair("schema", "HeLLo"); + + let mut pg_url = PostgresUrl::new(url).unwrap(); + pg_url.set_flavour(PostgresFlavour::Cockroach); + + let config = pg_url.to_config(); + + // CRDB does NOT support setting the search_path via a connection parameter if the identifier is unsafe. + assert_eq!(config.get_search_path(), None); + } +} diff --git a/quaint/src/connector/postgres/wasm/common.rs b/quaint/src/connector/postgres/wasm/common.rs new file mode 100644 index 000000000000..46d327c0183d --- /dev/null +++ b/quaint/src/connector/postgres/wasm/common.rs @@ -0,0 +1,612 @@ +use std::{ + borrow::Cow, + fmt::{Debug, Display}, + time::Duration, +}; + +use percent_encoding::percent_decode; +use url::{Host, Url}; + +use crate::error::{Error, ErrorKind}; + +#[cfg(feature = "postgresql-connector")] +use tokio_postgres::config::{ChannelBinding, SslMode}; + +#[derive(Clone)] +pub(crate) struct Hidden(pub(crate) T); + +impl Debug for Hidden { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str("") + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum SslAcceptMode { + Strict, + AcceptInvalidCerts, +} + +#[derive(Debug, Clone)] +pub struct SslParams { + pub(crate) certificate_file: Option, + pub(crate) identity_file: Option, + pub(crate) identity_password: Hidden>, + pub(crate) ssl_accept_mode: SslAcceptMode, +} + +#[derive(Debug, Clone, Copy)] +pub enum PostgresFlavour { + Postgres, + Cockroach, + Unknown, +} + +impl PostgresFlavour { + /// Returns `true` if the postgres flavour is [`Postgres`]. + /// + /// [`Postgres`]: PostgresFlavour::Postgres + pub(crate) fn is_postgres(&self) -> bool { + matches!(self, Self::Postgres) + } + + /// Returns `true` if the postgres flavour is [`Cockroach`]. + /// + /// [`Cockroach`]: PostgresFlavour::Cockroach + pub(crate) fn is_cockroach(&self) -> bool { + matches!(self, Self::Cockroach) + } + + /// Returns `true` if the postgres flavour is [`Unknown`]. + /// + /// [`Unknown`]: PostgresFlavour::Unknown + pub(crate) fn is_unknown(&self) -> bool { + matches!(self, Self::Unknown) + } +} + +/// Wraps a connection url and exposes the parsing logic used by Quaint, +/// including default values. +#[derive(Debug, Clone)] +pub struct PostgresUrl { + pub(crate) url: Url, + pub(crate) query_params: PostgresUrlQueryParams, + pub(crate) flavour: PostgresFlavour, +} + +pub(crate) const DEFAULT_SCHEMA: &str = "public"; + +impl PostgresUrl { + /// Parse `Url` to `PostgresUrl`. Returns error for mistyped connection + /// parameters. + pub fn new(url: Url) -> Result { + let query_params = Self::parse_query_params(&url)?; + + Ok(Self { + url, + query_params, + flavour: PostgresFlavour::Unknown, + }) + } + + /// The bare `Url` to the database. + pub fn url(&self) -> &Url { + &self.url + } + + /// The percent-decoded database username. + pub fn username(&self) -> Cow { + match percent_decode(self.url.username().as_bytes()).decode_utf8() { + Ok(username) => username, + Err(_) => { + tracing::warn!("Couldn't decode username to UTF-8, using the non-decoded version."); + + self.url.username().into() + } + } + } + + /// The database host. Taken first from the `host` query parameter, then + /// from the `host` part of the URL. For socket connections, the query + /// parameter must be used. + /// + /// If none of them are set, defaults to `localhost`. + pub fn host(&self) -> &str { + match (self.query_params.host.as_ref(), self.url.host_str(), self.url.host()) { + (Some(host), _, _) => host.as_str(), + (None, Some(""), _) => "localhost", + (None, None, _) => "localhost", + (None, Some(host), Some(Host::Ipv6(_))) => { + // The `url` crate may return an IPv6 address in brackets, which must be stripped. + if host.starts_with('[') && host.ends_with(']') { + &host[1..host.len() - 1] + } else { + host + } + } + (None, Some(host), _) => host, + } + } + + /// Name of the database connected. Defaults to `postgres`. + pub fn dbname(&self) -> &str { + match self.url.path_segments() { + Some(mut segments) => segments.next().unwrap_or("postgres"), + None => "postgres", + } + } + + /// The percent-decoded database password. + pub fn password(&self) -> Cow { + match self + .url + .password() + .and_then(|pw| percent_decode(pw.as_bytes()).decode_utf8().ok()) + { + Some(password) => password, + None => self.url.password().unwrap_or("").into(), + } + } + + /// The database port, defaults to `5432`. + pub fn port(&self) -> u16 { + self.url.port().unwrap_or(5432) + } + + /// The database schema, defaults to `public`. + pub fn schema(&self) -> &str { + self.query_params.schema.as_deref().unwrap_or(DEFAULT_SCHEMA) + } + + /// Whether the pgbouncer mode is enabled. + pub fn pg_bouncer(&self) -> bool { + self.query_params.pg_bouncer + } + + /// The connection timeout. + pub fn connect_timeout(&self) -> Option { + self.query_params.connect_timeout + } + + /// Pool check_out timeout + pub fn pool_timeout(&self) -> Option { + self.query_params.pool_timeout + } + + /// The socket timeout + pub fn socket_timeout(&self) -> Option { + self.query_params.socket_timeout + } + + /// The maximum connection lifetime + pub fn max_connection_lifetime(&self) -> Option { + self.query_params.max_connection_lifetime + } + + /// The maximum idle connection lifetime + pub fn max_idle_connection_lifetime(&self) -> Option { + self.query_params.max_idle_connection_lifetime + } + + /// The custom application name + pub fn application_name(&self) -> Option<&str> { + self.query_params.application_name.as_deref() + } + + pub(crate) fn options(&self) -> Option<&str> { + self.query_params.options.as_deref() + } + + /// Sets whether the URL points to a Postgres, Cockroach or Unknown database. + /// This is used to avoid a network roundtrip at connection to set the search path. + /// + /// The different behaviours are: + /// - Postgres: Always avoid a network roundtrip by setting the search path through client connection parameters. + /// - Cockroach: Avoid a network roundtrip if the schema name is deemed "safe" (i.e. no escape quoting required). Otherwise, set the search path through a database query. + /// - Unknown: Always add a network roundtrip by setting the search path through a database query. + pub fn set_flavour(&mut self, flavour: PostgresFlavour) { + self.flavour = flavour; + } + + fn parse_query_params(url: &Url) -> Result { + #[cfg(feature = "postgresql-connector")] + let mut ssl_mode = SslMode::Prefer; + #[cfg(feature = "postgresql-connector")] + let mut channel_binding = ChannelBinding::Prefer; + + let mut connection_limit = None; + let mut schema = None; + let mut certificate_file = None; + let mut identity_file = None; + let mut identity_password = None; + let mut ssl_accept_mode = SslAcceptMode::AcceptInvalidCerts; + let mut host = None; + let mut application_name = None; + let mut socket_timeout = None; + let mut connect_timeout = Some(Duration::from_secs(5)); + let mut pool_timeout = Some(Duration::from_secs(10)); + let mut pg_bouncer = false; + let mut statement_cache_size = 100; + let mut max_connection_lifetime = None; + let mut max_idle_connection_lifetime = Some(Duration::from_secs(300)); + let mut options = None; + + for (k, v) in url.query_pairs() { + match k.as_ref() { + "pgbouncer" => { + pg_bouncer = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + } + #[cfg(feature = "postgresql-connector")] + "sslmode" => { + match v.as_ref() { + "disable" => ssl_mode = SslMode::Disable, + "prefer" => ssl_mode = SslMode::Prefer, + "require" => ssl_mode = SslMode::Require, + _ => { + tracing::debug!(message = "Unsupported SSL mode, defaulting to `prefer`", mode = &*v); + } + }; + } + "sslcert" => { + certificate_file = Some(v.to_string()); + } + "sslidentity" => { + identity_file = Some(v.to_string()); + } + "sslpassword" => { + identity_password = Some(v.to_string()); + } + "statement_cache_size" => { + statement_cache_size = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + } + "sslaccept" => { + match v.as_ref() { + "strict" => { + ssl_accept_mode = SslAcceptMode::Strict; + } + "accept_invalid_certs" => { + ssl_accept_mode = SslAcceptMode::AcceptInvalidCerts; + } + _ => { + tracing::debug!( + message = "Unsupported SSL accept mode, defaulting to `strict`", + mode = &*v + ); + + ssl_accept_mode = SslAcceptMode::Strict; + } + }; + } + "schema" => { + schema = Some(v.to_string()); + } + "connection_limit" => { + let as_int: usize = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + connection_limit = Some(as_int); + } + "host" => { + host = Some(v.to_string()); + } + "socket_timeout" => { + let as_int = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + socket_timeout = Some(Duration::from_secs(as_int)); + } + "connect_timeout" => { + let as_int = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + + if as_int == 0 { + connect_timeout = None; + } else { + connect_timeout = Some(Duration::from_secs(as_int)); + } + } + "pool_timeout" => { + let as_int = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + + if as_int == 0 { + pool_timeout = None; + } else { + pool_timeout = Some(Duration::from_secs(as_int)); + } + } + "max_connection_lifetime" => { + let as_int = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + + if as_int == 0 { + max_connection_lifetime = None; + } else { + max_connection_lifetime = Some(Duration::from_secs(as_int)); + } + } + "max_idle_connection_lifetime" => { + let as_int = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + + if as_int == 0 { + max_idle_connection_lifetime = None; + } else { + max_idle_connection_lifetime = Some(Duration::from_secs(as_int)); + } + } + "application_name" => { + application_name = Some(v.to_string()); + } + #[cfg(feature = "postgresql-connector")] + "channel_binding" => { + match v.as_ref() { + "disable" => channel_binding = ChannelBinding::Disable, + "prefer" => channel_binding = ChannelBinding::Prefer, + "require" => channel_binding = ChannelBinding::Require, + _ => { + tracing::debug!( + message = "Unsupported Channel Binding {channel_binding}, defaulting to `prefer`", + channel_binding = &*v + ); + } + }; + } + "options" => { + options = Some(v.to_string()); + } + _ => { + tracing::trace!(message = "Discarding connection string param", param = &*k); + } + }; + } + + Ok(PostgresUrlQueryParams { + ssl_params: SslParams { + certificate_file, + identity_file, + ssl_accept_mode, + identity_password: Hidden(identity_password), + }, + connection_limit, + schema, + host, + connect_timeout, + pool_timeout, + socket_timeout, + pg_bouncer, + statement_cache_size, + max_connection_lifetime, + max_idle_connection_lifetime, + application_name, + options, + #[cfg(feature = "postgresql-connector")] + channel_binding, + #[cfg(feature = "postgresql-connector")] + ssl_mode, + }) + } + + pub(crate) fn ssl_params(&self) -> &SslParams { + &self.query_params.ssl_params + } + + #[cfg(feature = "pooled")] + pub(crate) fn connection_limit(&self) -> Option { + self.query_params.connection_limit + } + + pub fn flavour(&self) -> PostgresFlavour { + self.flavour + } +} + +#[derive(Debug, Clone)] +pub(crate) struct PostgresUrlQueryParams { + pub(crate) ssl_params: SslParams, + pub(crate) connection_limit: Option, + pub(crate) schema: Option, + pub(crate) pg_bouncer: bool, + pub(crate) host: Option, + pub(crate) socket_timeout: Option, + pub(crate) connect_timeout: Option, + pub(crate) pool_timeout: Option, + pub(crate) statement_cache_size: usize, + pub(crate) max_connection_lifetime: Option, + pub(crate) max_idle_connection_lifetime: Option, + pub(crate) application_name: Option, + pub(crate) options: Option, + + #[cfg(feature = "postgresql-connector")] + pub(crate) channel_binding: ChannelBinding, + + #[cfg(feature = "postgresql-connector")] + pub(crate) ssl_mode: SslMode, +} + +// A SearchPath connection parameter (Display-impl) for connection initialization. +struct CockroachSearchPath<'a>(&'a str); + +impl Display for CockroachSearchPath<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(self.0) + } +} + +// A SearchPath connection parameter (Display-impl) for connection initialization. +struct PostgresSearchPath<'a>(&'a str); + +impl Display for PostgresSearchPath<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str("\"")?; + f.write_str(self.0)?; + f.write_str("\"")?; + + Ok(()) + } +} + +// A SetSearchPath statement (Display-impl) for connection initialization. +struct SetSearchPath<'a>(Option<&'a str>); + +impl Display for SetSearchPath<'_> { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + if let Some(schema) = self.0 { + f.write_str("SET search_path = \"")?; + f.write_str(schema)?; + f.write_str("\";\n")?; + } + + Ok(()) + } +} + +/// Sorted list of CockroachDB's reserved keywords. +/// Taken from https://www.cockroachlabs.com/docs/stable/keywords-and-identifiers.html#keywords +const RESERVED_KEYWORDS: [&str; 79] = [ + "all", + "analyse", + "analyze", + "and", + "any", + "array", + "as", + "asc", + "asymmetric", + "both", + "case", + "cast", + "check", + "collate", + "column", + "concurrently", + "constraint", + "create", + "current_catalog", + "current_date", + "current_role", + "current_schema", + "current_time", + "current_timestamp", + "current_user", + "default", + "deferrable", + "desc", + "distinct", + "do", + "else", + "end", + "except", + "false", + "fetch", + "for", + "foreign", + "from", + "grant", + "group", + "having", + "in", + "initially", + "intersect", + "into", + "lateral", + "leading", + "limit", + "localtime", + "localtimestamp", + "not", + "null", + "offset", + "on", + "only", + "or", + "order", + "placing", + "primary", + "references", + "returning", + "select", + "session_user", + "some", + "symmetric", + "table", + "then", + "to", + "trailing", + "true", + "union", + "unique", + "user", + "using", + "variadic", + "when", + "where", + "window", + "with", +]; + +/// Sorted list of CockroachDB's reserved type function names. +/// Taken from https://www.cockroachlabs.com/docs/stable/keywords-and-identifiers.html#keywords +const RESERVED_TYPE_FUNCTION_NAMES: [&str; 18] = [ + "authorization", + "collation", + "cross", + "full", + "ilike", + "inner", + "is", + "isnull", + "join", + "left", + "like", + "natural", + "none", + "notnull", + "outer", + "overlaps", + "right", + "similar", +]; + +/// Returns true if a Postgres identifier is considered "safe". +/// +/// In this context, "safe" means that the value of an identifier would be the same quoted and unquoted or that it's not part of reserved keywords. In other words, that it does _not_ need to be quoted. +/// +/// Spec can be found here: https://www.postgresql.org/docs/current/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS +/// or here: https://www.cockroachlabs.com/docs/stable/keywords-and-identifiers.html#rules-for-identifiers +fn is_safe_identifier(ident: &str) -> bool { + if ident.is_empty() { + return false; + } + + // 1. Not equal any SQL keyword unless the keyword is accepted by the element's syntax. For example, name accepts Unreserved or Column Name keywords. + if RESERVED_KEYWORDS.binary_search(&ident).is_ok() || RESERVED_TYPE_FUNCTION_NAMES.binary_search(&ident).is_ok() { + return false; + } + + let mut chars = ident.chars(); + + let first = chars.next().unwrap(); + + // 2. SQL identifiers must begin with a letter (a-z, but also letters with diacritical marks and non-Latin letters) or an underscore (_). + if (!first.is_alphabetic() || !first.is_lowercase()) && first != '_' { + return false; + } + + for c in chars { + // 3. Subsequent characters in an identifier can be letters, underscores, digits (0-9), or dollar signs ($). + if (!c.is_alphabetic() || !c.is_lowercase()) && c != '_' && !c.is_ascii_digit() && c != '$' { + return false; + } + } + + true +} diff --git a/quaint/src/connector/postgres/error.rs b/quaint/src/connector/postgres/wasm/error.rs similarity index 66% rename from quaint/src/connector/postgres/error.rs rename to quaint/src/connector/postgres/wasm/error.rs index d4e5ec7837fe..ab6ec7b07847 100644 --- a/quaint/src/connector/postgres/error.rs +++ b/quaint/src/connector/postgres/wasm/error.rs @@ -1,7 +1,5 @@ use std::fmt::{Display, Formatter}; -use tokio_postgres::error::DbError; - use crate::error::{DatabaseConstraint, Error, ErrorKind, Name}; #[derive(Debug)] @@ -17,7 +15,7 @@ pub struct PostgresError { impl std::error::Error for PostgresError {} impl Display for PostgresError { - // copy of DbError::fmt + // copy of tokio_postgres::error::DbError::fmt fn fmt(&self, fmt: &mut Formatter<'_>) -> std::fmt::Result { write!(fmt, "{}: {}", self.severity, self.message)?; if let Some(detail) = &self.detail { @@ -30,19 +28,6 @@ impl Display for PostgresError { } } -impl From<&DbError> for PostgresError { - fn from(value: &DbError) -> Self { - PostgresError { - code: value.code().code().to_string(), - severity: value.severity().to_string(), - message: value.message().to_string(), - detail: value.detail().map(ToString::to_string), - column: value.column().map(ToString::to_string), - hint: value.hint().map(ToString::to_string), - } - } -} - impl From for Error { fn from(value: PostgresError) -> Self { match value.code.as_str() { @@ -245,110 +230,3 @@ impl From for Error { } } } - -impl From for Error { - fn from(e: tokio_postgres::error::Error) -> Error { - if e.is_closed() { - return Error::builder(ErrorKind::ConnectionClosed).build(); - } - - if let Some(db_error) = e.as_db_error() { - return PostgresError::from(db_error).into(); - } - - if let Some(tls_error) = try_extracting_tls_error(&e) { - return tls_error; - } - - // Same for IO errors. - if let Some(io_error) = try_extracting_io_error(&e) { - return io_error; - } - - if let Some(uuid_error) = try_extracting_uuid_error(&e) { - return uuid_error; - } - - let reason = format!("{e}"); - let code = e.code().map(|c| c.code()); - - match reason.as_str() { - "error connecting to server: timed out" => { - let mut builder = Error::builder(ErrorKind::ConnectTimeout); - - if let Some(code) = code { - builder.set_original_code(code); - }; - - builder.set_original_message(reason); - builder.build() - } // sigh... - // https://github.com/sfackler/rust-postgres/blob/0c84ed9f8201f4e5b4803199a24afa2c9f3723b2/tokio-postgres/src/connect_tls.rs#L37 - "error performing TLS handshake: server does not support TLS" => { - let mut builder = Error::builder(ErrorKind::TlsError { - message: reason.clone(), - }); - - if let Some(code) = code { - builder.set_original_code(code); - }; - - builder.set_original_message(reason); - builder.build() - } // double sigh - _ => { - let code = code.map(|c| c.to_string()); - let mut builder = Error::builder(ErrorKind::QueryError(e.into())); - - if let Some(code) = code { - builder.set_original_code(code); - }; - - builder.set_original_message(reason); - builder.build() - } - } - } -} - -fn try_extracting_uuid_error(err: &tokio_postgres::error::Error) -> Option { - use std::error::Error as _; - - err.source() - .and_then(|err| err.downcast_ref::()) - .map(|err| ErrorKind::UUIDError(format!("{err}"))) - .map(|kind| Error::builder(kind).build()) -} - -fn try_extracting_tls_error(err: &tokio_postgres::error::Error) -> Option { - use std::error::Error; - - err.source() - .and_then(|err| err.downcast_ref::()) - .map(|err| err.into()) -} - -fn try_extracting_io_error(err: &tokio_postgres::error::Error) -> Option { - use std::error::Error as _; - - err.source() - .and_then(|err| err.downcast_ref::()) - .map(|err| ErrorKind::ConnectionError(Box::new(std::io::Error::new(err.kind(), format!("{err}"))))) - .map(|kind| Error::builder(kind).build()) -} - -impl From for Error { - fn from(e: native_tls::Error) -> Error { - Error::from(&e) - } -} - -impl From<&native_tls::Error> for Error { - fn from(e: &native_tls::Error) -> Error { - let kind = ErrorKind::TlsError { - message: format!("{e}"), - }; - - Error::builder(kind).build() - } -} diff --git a/quaint/src/connector/postgres/wasm/mod.rs b/quaint/src/connector/postgres/wasm/mod.rs new file mode 100644 index 000000000000..5b330861a199 --- /dev/null +++ b/quaint/src/connector/postgres/wasm/mod.rs @@ -0,0 +1,6 @@ +///! Wasm-compatible definitions for the Postgres connector. +/// /// This module is only available with the `postgresql` feature. +pub(crate) mod common; +pub mod error; + +pub use common::PostgresUrl; diff --git a/quaint/src/error.rs b/quaint/src/error.rs index 785fcc22ffe3..73bf5c405c66 100644 --- a/quaint/src/error.rs +++ b/quaint/src/error.rs @@ -7,7 +7,7 @@ use thiserror::Error; use std::time::Duration; // pub use crate::connector::mysql::MysqlError; -// pub use crate::connector::postgres::PostgresError; +pub use crate::connector::postgres::PostgresError; // pub use crate::connector::sqlite::SqliteError; #[derive(Debug, PartialEq, Eq)] From 12c6ebb15a31d42b5be0301892a122526b6521c7 Mon Sep 17 00:00:00 2001 From: jkomyno Date: Fri, 10 Nov 2023 15:49:39 +0100 Subject: [PATCH 03/34] feat(quaint): split mysql connector into native and wasm submodules --- quaint/src/connector.rs | 25 +- quaint/src/connector/mysql.rs | 398 +----------- .../mysql/{ => native}/conversion.rs | 0 quaint/src/connector/mysql/native/error.rs | 36 ++ quaint/src/connector/mysql/native/mod.rs | 392 +++++++++++ quaint/src/connector/mysql/wasm/common.rs | 316 +++++++++ .../src/connector/mysql/{ => wasm}/error.rs | 65 +- quaint/src/connector/mysql/wasm/mod.rs | 6 + quaint/src/connector/postgres.rs | 1 + quaint/src/connector/postgres/native/error.rs | 2 +- quaint/src/connector/postgres_wasm.rs | 612 ------------------ quaint/src/error.rs | 2 +- 12 files changed, 792 insertions(+), 1063 deletions(-) rename quaint/src/connector/mysql/{ => native}/conversion.rs (100%) create mode 100644 quaint/src/connector/mysql/native/error.rs create mode 100644 quaint/src/connector/mysql/native/mod.rs create mode 100644 quaint/src/connector/mysql/wasm/common.rs rename quaint/src/connector/mysql/{ => wasm}/error.rs (81%) create mode 100644 quaint/src/connector/mysql/wasm/mod.rs delete mode 100644 quaint/src/connector/postgres_wasm.rs diff --git a/quaint/src/connector.rs b/quaint/src/connector.rs index 71bba2d098ed..d0e4d7e57bdc 100644 --- a/quaint/src/connector.rs +++ b/quaint/src/connector.rs @@ -27,10 +27,10 @@ mod type_identifier; pub(crate) mod mssql; #[cfg(feature = "mssql")] pub(crate) mod mssql_wasm; -#[cfg(feature = "mysql-connector")] -pub(crate) mod mysql; -#[cfg(feature = "mysql")] -pub(crate) mod mysql_wasm; +// #[cfg(feature = "mysql-connector")] +// pub(crate) mod mysql; +// #[cfg(feature = "mysql")] +// pub(crate) mod mysql_wasm; // #[cfg(feature = "postgresql-connector")] // pub(crate) mod postgres; // #[cfg(feature = "postgresql")] @@ -40,10 +40,10 @@ pub(crate) mod sqlite; #[cfg(feature = "sqlite")] pub(crate) mod sqlite_wasm; -#[cfg(feature = "mysql-connector")] -pub use self::mysql::*; -#[cfg(feature = "mysql")] -pub use self::mysql_wasm::*; +// #[cfg(feature = "mysql-connector")] +// pub use self::mysql::*; +// #[cfg(feature = "mysql")] +// pub use self::mysql_wasm::*; // #[cfg(feature = "postgresql-connector")] // pub use self::postgres::*; // #[cfg(feature = "postgresql")] @@ -76,4 +76,11 @@ pub(crate) mod postgres; #[cfg(feature = "postgresql-connector")] pub use postgres::native::*; #[cfg(feature = "postgresql")] -pub use postgres::wasm::*; +pub use postgres::wasm::common::*; + +#[cfg(feature = "mysql")] +pub(crate) mod mysql; +#[cfg(feature = "mysql-connector")] +pub use mysql::native::*; +#[cfg(feature = "mysql")] +pub use mysql::wasm::common::*; diff --git a/quaint/src/connector/mysql.rs b/quaint/src/connector/mysql.rs index a9a829404e76..1794cc738b1e 100644 --- a/quaint/src/connector/mysql.rs +++ b/quaint/src/connector/mysql.rs @@ -1,394 +1,8 @@ -mod conversion; -mod error; +pub use wasm::common::MysqlUrl; +pub use wasm::error::MysqlError; -pub(crate) use super::mysql_wasm::MysqlUrl; -use crate::{ - ast::{Query, Value}, - connector::{metrics, queryable::*, ResultSet}, - error::{Error, ErrorKind}, - visitor::{self, Visitor}, -}; -use async_trait::async_trait; -use lru_cache::LruCache; -use mysql_async::{ - self as my, - prelude::{Query as _, Queryable as _}, -}; -use std::{ - future::Future, - sync::atomic::{AtomicBool, Ordering}, - time::Duration, -}; -use tokio::sync::Mutex; +#[cfg(feature = "mysql")] +pub(crate) mod wasm; -pub use error::MysqlError; - -/// The underlying MySQL driver. Only available with the `expose-drivers` -/// Cargo feature. -#[cfg(feature = "expose-drivers")] -pub use mysql_async; - -use super::IsolationLevel; - -impl MysqlUrl { - pub(crate) fn cache(&self) -> LruCache { - LruCache::new(self.query_params.statement_cache_size) - } - - pub(crate) fn to_opts_builder(&self) -> my::OptsBuilder { - let mut config = my::OptsBuilder::default() - .stmt_cache_size(Some(0)) - .user(Some(self.username())) - .pass(self.password()) - .db_name(Some(self.dbname())); - - match self.socket() { - Some(ref socket) => { - config = config.socket(Some(socket)); - } - None => { - config = config.ip_or_hostname(self.host()).tcp_port(self.port()); - } - } - - config = config.conn_ttl(Some(Duration::from_secs(5))); - - if self.query_params.use_ssl { - config = config.ssl_opts(Some(self.query_params.ssl_opts.clone())); - } - - if self.query_params.prefer_socket.is_some() { - config = config.prefer_socket(self.query_params.prefer_socket); - } - - config - } -} - -#[derive(Debug, Clone)] -pub(crate) struct MysqlUrlQueryParams { - ssl_opts: my::SslOpts, - connection_limit: Option, - use_ssl: bool, - socket: Option, - socket_timeout: Option, - connect_timeout: Option, - pool_timeout: Option, - max_connection_lifetime: Option, - max_idle_connection_lifetime: Option, - prefer_socket: Option, - statement_cache_size: usize, -} - -/// A connector interface for the MySQL database. -#[derive(Debug)] -pub struct Mysql { - pub(crate) conn: Mutex, - pub(crate) url: MysqlUrl, - socket_timeout: Option, - is_healthy: AtomicBool, - statement_cache: Mutex>, -} - -impl Mysql { - /// Create a new MySQL connection using `OptsBuilder` from the `mysql` crate. - pub async fn new(url: MysqlUrl) -> crate::Result { - let conn = super::timeout::connect(url.connect_timeout(), my::Conn::new(url.to_opts_builder())).await?; - - Ok(Self { - socket_timeout: url.query_params.socket_timeout, - conn: Mutex::new(conn), - statement_cache: Mutex::new(url.cache()), - url, - is_healthy: AtomicBool::new(true), - }) - } - - /// The underlying mysql_async::Conn. Only available with the - /// `expose-drivers` Cargo feature. This is a lower level API when you need - /// to get into database specific features. - #[cfg(feature = "expose-drivers")] - pub fn conn(&self) -> &Mutex { - &self.conn - } - - async fn perform_io(&self, op: U) -> crate::Result - where - F: Future>, - U: FnOnce() -> F, - { - match super::timeout::socket(self.socket_timeout, op()).await { - Err(e) if e.is_closed() => { - self.is_healthy.store(false, Ordering::SeqCst); - Err(e) - } - res => Ok(res?), - } - } - - async fn prepared(&self, sql: &str, op: U) -> crate::Result - where - F: Future>, - U: Fn(my::Statement) -> F, - { - if self.url.statement_cache_size() == 0 { - self.perform_io(|| async move { - let stmt = { - let mut conn = self.conn.lock().await; - conn.prep(sql).await? - }; - - let res = op(stmt.clone()).await; - - { - let mut conn = self.conn.lock().await; - conn.close(stmt).await?; - } - - res - }) - .await - } else { - self.perform_io(|| async move { - let stmt = self.fetch_cached(sql).await?; - op(stmt).await - }) - .await - } - } - - async fn fetch_cached(&self, sql: &str) -> crate::Result { - let mut cache = self.statement_cache.lock().await; - let capacity = cache.capacity(); - let stored = cache.len(); - - match cache.get_mut(sql) { - Some(stmt) => { - tracing::trace!( - message = "CACHE HIT!", - query = sql, - capacity = capacity, - stored = stored, - ); - - Ok(stmt.clone()) // arc'd - } - None => { - tracing::trace!( - message = "CACHE MISS!", - query = sql, - capacity = capacity, - stored = stored, - ); - - let mut conn = self.conn.lock().await; - if cache.capacity() == cache.len() { - if let Some((_, stmt)) = cache.remove_lru() { - conn.close(stmt).await?; - } - } - - let stmt = conn.prep(sql).await?; - cache.insert(sql.to_string(), stmt.clone()); - - Ok(stmt) - } - } - } -} - -impl_default_TransactionCapable!(Mysql); - -#[async_trait] -impl Queryable for Mysql { - async fn query(&self, q: Query<'_>) -> crate::Result { - let (sql, params) = visitor::Mysql::build(q)?; - self.query_raw(&sql, ¶ms).await - } - - async fn query_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - metrics::query("mysql.query_raw", sql, params, move || async move { - self.prepared(sql, |stmt| async move { - let mut conn = self.conn.lock().await; - let rows: Vec = conn.exec(&stmt, conversion::conv_params(params)?).await?; - let columns = stmt.columns().iter().map(|s| s.name_str().into_owned()).collect(); - - let last_id = conn.last_insert_id(); - let mut result_set = ResultSet::new(columns, Vec::new()); - - for mut row in rows { - result_set.rows.push(row.take_result_row()?); - } - - if let Some(id) = last_id { - result_set.set_last_insert_id(id); - }; - - Ok(result_set) - }) - .await - }) - .await - } - - async fn query_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - self.query_raw(sql, params).await - } - - async fn execute(&self, q: Query<'_>) -> crate::Result { - let (sql, params) = visitor::Mysql::build(q)?; - self.execute_raw(&sql, ¶ms).await - } - - async fn execute_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - metrics::query("mysql.execute_raw", sql, params, move || async move { - self.prepared(sql, |stmt| async move { - let mut conn = self.conn.lock().await; - conn.exec_drop(stmt, conversion::conv_params(params)?).await?; - - Ok(conn.affected_rows()) - }) - .await - }) - .await - } - - async fn execute_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - self.execute_raw(sql, params).await - } - - async fn raw_cmd(&self, cmd: &str) -> crate::Result<()> { - metrics::query("mysql.raw_cmd", cmd, &[], move || async move { - self.perform_io(|| async move { - let mut conn = self.conn.lock().await; - let mut result = cmd.run(&mut *conn).await?; - - loop { - result.map(drop).await?; - - if result.is_empty() { - result.map(drop).await?; - break; - } - } - - Ok(()) - }) - .await - }) - .await - } - - async fn version(&self) -> crate::Result> { - let query = r#"SELECT @@GLOBAL.version version"#; - let rows = super::timeout::socket(self.socket_timeout, self.query_raw(query, &[])).await?; - - let version_string = rows - .get(0) - .and_then(|row| row.get("version").and_then(|version| version.typed.to_string())); - - Ok(version_string) - } - - fn is_healthy(&self) -> bool { - self.is_healthy.load(Ordering::SeqCst) - } - - async fn set_tx_isolation_level(&self, isolation_level: IsolationLevel) -> crate::Result<()> { - if matches!(isolation_level, IsolationLevel::Snapshot) { - return Err(Error::builder(ErrorKind::invalid_isolation_level(&isolation_level)).build()); - } - - self.raw_cmd(&format!("SET TRANSACTION ISOLATION LEVEL {isolation_level}")) - .await?; - - Ok(()) - } - - fn requires_isolation_first(&self) -> bool { - true - } -} - -#[cfg(test)] -mod tests { - use super::MysqlUrl; - use crate::tests::test_api::mysql::CONN_STR; - use crate::{error::*, single::Quaint}; - use url::Url; - - #[test] - fn should_parse_socket_url() { - let url = MysqlUrl::new(Url::parse("mysql://root@localhost/dbname?socket=(/tmp/mysql.sock)").unwrap()).unwrap(); - assert_eq!("dbname", url.dbname()); - assert_eq!(&Some(String::from("/tmp/mysql.sock")), url.socket()); - } - - #[test] - fn should_parse_prefer_socket() { - let url = - MysqlUrl::new(Url::parse("mysql://root:root@localhost:3307/testdb?prefer_socket=false").unwrap()).unwrap(); - assert!(!url.prefer_socket().unwrap()); - } - - #[test] - fn should_parse_sslaccept() { - let url = - MysqlUrl::new(Url::parse("mysql://root:root@localhost:3307/testdb?sslaccept=strict").unwrap()).unwrap(); - assert!(url.query_params.use_ssl); - assert!(!url.query_params.ssl_opts.skip_domain_validation()); - assert!(!url.query_params.ssl_opts.accept_invalid_certs()); - } - - #[test] - fn should_parse_ipv6_host() { - let url = MysqlUrl::new(Url::parse("mysql://[2001:db8:1234::ffff]:5432/testdb").unwrap()).unwrap(); - assert_eq!("2001:db8:1234::ffff", url.host()); - } - - #[test] - fn should_allow_changing_of_cache_size() { - let url = MysqlUrl::new(Url::parse("mysql:///root:root@localhost:3307/foo?statement_cache_size=420").unwrap()) - .unwrap(); - assert_eq!(420, url.cache().capacity()); - } - - #[test] - fn should_have_default_cache_size() { - let url = MysqlUrl::new(Url::parse("mysql:///root:root@localhost:3307/foo").unwrap()).unwrap(); - assert_eq!(100, url.cache().capacity()); - } - - #[tokio::test] - async fn should_map_nonexisting_database_error() { - let mut url = Url::parse(&CONN_STR).unwrap(); - url.set_username("root").unwrap(); - url.set_path("/this_does_not_exist"); - - let url = url.as_str().to_string(); - let res = Quaint::new(&url).await; - - let err = res.unwrap_err(); - - match err.kind() { - ErrorKind::DatabaseDoesNotExist { db_name } => { - assert_eq!(Some("1049"), err.original_code()); - assert_eq!(Some("Unknown database \'this_does_not_exist\'"), err.original_message()); - assert_eq!(&Name::available("this_does_not_exist"), db_name) - } - e => panic!("Expected `DatabaseDoesNotExist`, got {:?}", e), - } - } - - #[tokio::test] - async fn should_map_wrong_credentials_error() { - let mut url = Url::parse(&CONN_STR).unwrap(); - url.set_username("WRONG").unwrap(); - - let res = Quaint::new(url.as_str()).await; - assert!(res.is_err()); - - let err = res.unwrap_err(); - assert!(matches!(err.kind(), ErrorKind::AuthenticationFailed { user } if user == &Name::available("WRONG"))); - } -} +#[cfg(feature = "mysql-connector")] +pub(crate) mod native; diff --git a/quaint/src/connector/mysql/conversion.rs b/quaint/src/connector/mysql/native/conversion.rs similarity index 100% rename from quaint/src/connector/mysql/conversion.rs rename to quaint/src/connector/mysql/native/conversion.rs diff --git a/quaint/src/connector/mysql/native/error.rs b/quaint/src/connector/mysql/native/error.rs new file mode 100644 index 000000000000..e00ff1e0aa74 --- /dev/null +++ b/quaint/src/connector/mysql/native/error.rs @@ -0,0 +1,36 @@ +use crate::{ + connector::mysql::wasm::error::MysqlError, + error::{Error, ErrorKind}, +}; +use mysql_async as my; + +impl From<&my::ServerError> for MysqlError { + fn from(value: &my::ServerError) -> Self { + MysqlError { + code: value.code, + message: value.message.to_owned(), + state: value.state.to_owned(), + } + } +} + +impl From for Error { + fn from(e: my::Error) -> Error { + match e { + my::Error::Io(my::IoError::Tls(err)) => Error::builder(ErrorKind::TlsError { + message: err.to_string(), + }) + .build(), + my::Error::Io(my::IoError::Io(err)) if err.kind() == std::io::ErrorKind::UnexpectedEof => { + Error::builder(ErrorKind::ConnectionClosed).build() + } + my::Error::Io(io_error) => Error::builder(ErrorKind::ConnectionError(io_error.into())).build(), + my::Error::Driver(e) => Error::builder(ErrorKind::QueryError(e.into())).build(), + my::Error::Server(ref server_error) => { + let mysql_error: MysqlError = server_error.into(); + mysql_error.into() + } + e => Error::builder(ErrorKind::QueryError(e.into())).build(), + } + } +} diff --git a/quaint/src/connector/mysql/native/mod.rs b/quaint/src/connector/mysql/native/mod.rs new file mode 100644 index 000000000000..1a9652b628f8 --- /dev/null +++ b/quaint/src/connector/mysql/native/mod.rs @@ -0,0 +1,392 @@ +mod conversion; +mod error; + +pub(crate) use crate::connector::mysql::wasm::common::MysqlUrl; +use crate::connector::{timeout, IsolationLevel}; + +use crate::{ + ast::{Query, Value}, + connector::{metrics, queryable::*, ResultSet}, + error::{Error, ErrorKind}, + visitor::{self, Visitor}, +}; +use async_trait::async_trait; +use lru_cache::LruCache; +use mysql_async::{ + self as my, + prelude::{Query as _, Queryable as _}, +}; +use std::{ + future::Future, + sync::atomic::{AtomicBool, Ordering}, + time::Duration, +}; +use tokio::sync::Mutex; + +/// The underlying MySQL driver. Only available with the `expose-drivers` +/// Cargo feature. +#[cfg(feature = "expose-drivers")] +pub use mysql_async; + +impl MysqlUrl { + pub(crate) fn cache(&self) -> LruCache { + LruCache::new(self.query_params.statement_cache_size) + } + + pub(crate) fn to_opts_builder(&self) -> my::OptsBuilder { + let mut config = my::OptsBuilder::default() + .stmt_cache_size(Some(0)) + .user(Some(self.username())) + .pass(self.password()) + .db_name(Some(self.dbname())); + + match self.socket() { + Some(ref socket) => { + config = config.socket(Some(socket)); + } + None => { + config = config.ip_or_hostname(self.host()).tcp_port(self.port()); + } + } + + config = config.conn_ttl(Some(Duration::from_secs(5))); + + if self.query_params.use_ssl { + config = config.ssl_opts(Some(self.query_params.ssl_opts.clone())); + } + + if self.query_params.prefer_socket.is_some() { + config = config.prefer_socket(self.query_params.prefer_socket); + } + + config + } +} + +#[derive(Debug, Clone)] +pub(crate) struct MysqlUrlQueryParams { + ssl_opts: my::SslOpts, + connection_limit: Option, + use_ssl: bool, + socket: Option, + socket_timeout: Option, + connect_timeout: Option, + pool_timeout: Option, + max_connection_lifetime: Option, + max_idle_connection_lifetime: Option, + prefer_socket: Option, + statement_cache_size: usize, +} + +/// A connector interface for the MySQL database. +#[derive(Debug)] +pub struct Mysql { + pub(crate) conn: Mutex, + pub(crate) url: MysqlUrl, + socket_timeout: Option, + is_healthy: AtomicBool, + statement_cache: Mutex>, +} + +impl Mysql { + /// Create a new MySQL connection using `OptsBuilder` from the `mysql` crate. + pub async fn new(url: MysqlUrl) -> crate::Result { + let conn = timeout::connect(url.connect_timeout(), my::Conn::new(url.to_opts_builder())).await?; + + Ok(Self { + socket_timeout: url.query_params.socket_timeout, + conn: Mutex::new(conn), + statement_cache: Mutex::new(url.cache()), + url, + is_healthy: AtomicBool::new(true), + }) + } + + /// The underlying mysql_async::Conn. Only available with the + /// `expose-drivers` Cargo feature. This is a lower level API when you need + /// to get into database specific features. + #[cfg(feature = "expose-drivers")] + pub fn conn(&self) -> &Mutex { + &self.conn + } + + async fn perform_io(&self, op: U) -> crate::Result + where + F: Future>, + U: FnOnce() -> F, + { + match timeout::socket(self.socket_timeout, op()).await { + Err(e) if e.is_closed() => { + self.is_healthy.store(false, Ordering::SeqCst); + Err(e) + } + res => Ok(res?), + } + } + + async fn prepared(&self, sql: &str, op: U) -> crate::Result + where + F: Future>, + U: Fn(my::Statement) -> F, + { + if self.url.statement_cache_size() == 0 { + self.perform_io(|| async move { + let stmt = { + let mut conn = self.conn.lock().await; + conn.prep(sql).await? + }; + + let res = op(stmt.clone()).await; + + { + let mut conn = self.conn.lock().await; + conn.close(stmt).await?; + } + + res + }) + .await + } else { + self.perform_io(|| async move { + let stmt = self.fetch_cached(sql).await?; + op(stmt).await + }) + .await + } + } + + async fn fetch_cached(&self, sql: &str) -> crate::Result { + let mut cache = self.statement_cache.lock().await; + let capacity = cache.capacity(); + let stored = cache.len(); + + match cache.get_mut(sql) { + Some(stmt) => { + tracing::trace!( + message = "CACHE HIT!", + query = sql, + capacity = capacity, + stored = stored, + ); + + Ok(stmt.clone()) // arc'd + } + None => { + tracing::trace!( + message = "CACHE MISS!", + query = sql, + capacity = capacity, + stored = stored, + ); + + let mut conn = self.conn.lock().await; + if cache.capacity() == cache.len() { + if let Some((_, stmt)) = cache.remove_lru() { + conn.close(stmt).await?; + } + } + + let stmt = conn.prep(sql).await?; + cache.insert(sql.to_string(), stmt.clone()); + + Ok(stmt) + } + } + } +} + +impl_default_TransactionCapable!(Mysql); + +#[async_trait] +impl Queryable for Mysql { + async fn query(&self, q: Query<'_>) -> crate::Result { + let (sql, params) = visitor::Mysql::build(q)?; + self.query_raw(&sql, ¶ms).await + } + + async fn query_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { + metrics::query("mysql.query_raw", sql, params, move || async move { + self.prepared(sql, |stmt| async move { + let mut conn = self.conn.lock().await; + let rows: Vec = conn.exec(&stmt, conversion::conv_params(params)?).await?; + let columns = stmt.columns().iter().map(|s| s.name_str().into_owned()).collect(); + + let last_id = conn.last_insert_id(); + let mut result_set = ResultSet::new(columns, Vec::new()); + + for mut row in rows { + result_set.rows.push(row.take_result_row()?); + } + + if let Some(id) = last_id { + result_set.set_last_insert_id(id); + }; + + Ok(result_set) + }) + .await + }) + .await + } + + async fn query_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { + self.query_raw(sql, params).await + } + + async fn execute(&self, q: Query<'_>) -> crate::Result { + let (sql, params) = visitor::Mysql::build(q)?; + self.execute_raw(&sql, ¶ms).await + } + + async fn execute_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { + metrics::query("mysql.execute_raw", sql, params, move || async move { + self.prepared(sql, |stmt| async move { + let mut conn = self.conn.lock().await; + conn.exec_drop(stmt, conversion::conv_params(params)?).await?; + + Ok(conn.affected_rows()) + }) + .await + }) + .await + } + + async fn execute_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { + self.execute_raw(sql, params).await + } + + async fn raw_cmd(&self, cmd: &str) -> crate::Result<()> { + metrics::query("mysql.raw_cmd", cmd, &[], move || async move { + self.perform_io(|| async move { + let mut conn = self.conn.lock().await; + let mut result = cmd.run(&mut *conn).await?; + + loop { + result.map(drop).await?; + + if result.is_empty() { + result.map(drop).await?; + break; + } + } + + Ok(()) + }) + .await + }) + .await + } + + async fn version(&self) -> crate::Result> { + let query = r#"SELECT @@GLOBAL.version version"#; + let rows = timeout::socket(self.socket_timeout, self.query_raw(query, &[])).await?; + + let version_string = rows + .get(0) + .and_then(|row| row.get("version").and_then(|version| version.typed.to_string())); + + Ok(version_string) + } + + fn is_healthy(&self) -> bool { + self.is_healthy.load(Ordering::SeqCst) + } + + async fn set_tx_isolation_level(&self, isolation_level: IsolationLevel) -> crate::Result<()> { + if matches!(isolation_level, IsolationLevel::Snapshot) { + return Err(Error::builder(ErrorKind::invalid_isolation_level(&isolation_level)).build()); + } + + self.raw_cmd(&format!("SET TRANSACTION ISOLATION LEVEL {isolation_level}")) + .await?; + + Ok(()) + } + + fn requires_isolation_first(&self) -> bool { + true + } +} + +#[cfg(test)] +mod tests { + use super::MysqlUrl; + use crate::tests::test_api::mysql::CONN_STR; + use crate::{error::*, single::Quaint}; + use url::Url; + + #[test] + fn should_parse_socket_url() { + let url = MysqlUrl::new(Url::parse("mysql://root@localhost/dbname?socket=(/tmp/mysql.sock)").unwrap()).unwrap(); + assert_eq!("dbname", url.dbname()); + assert_eq!(&Some(String::from("/tmp/mysql.sock")), url.socket()); + } + + #[test] + fn should_parse_prefer_socket() { + let url = + MysqlUrl::new(Url::parse("mysql://root:root@localhost:3307/testdb?prefer_socket=false").unwrap()).unwrap(); + assert!(!url.prefer_socket().unwrap()); + } + + #[test] + fn should_parse_sslaccept() { + let url = + MysqlUrl::new(Url::parse("mysql://root:root@localhost:3307/testdb?sslaccept=strict").unwrap()).unwrap(); + assert!(url.query_params.use_ssl); + assert!(!url.query_params.ssl_opts.skip_domain_validation()); + assert!(!url.query_params.ssl_opts.accept_invalid_certs()); + } + + #[test] + fn should_parse_ipv6_host() { + let url = MysqlUrl::new(Url::parse("mysql://[2001:db8:1234::ffff]:5432/testdb").unwrap()).unwrap(); + assert_eq!("2001:db8:1234::ffff", url.host()); + } + + #[test] + fn should_allow_changing_of_cache_size() { + let url = MysqlUrl::new(Url::parse("mysql:///root:root@localhost:3307/foo?statement_cache_size=420").unwrap()) + .unwrap(); + assert_eq!(420, url.cache().capacity()); + } + + #[test] + fn should_have_default_cache_size() { + let url = MysqlUrl::new(Url::parse("mysql:///root:root@localhost:3307/foo").unwrap()).unwrap(); + assert_eq!(100, url.cache().capacity()); + } + + #[tokio::test] + async fn should_map_nonexisting_database_error() { + let mut url = Url::parse(&CONN_STR).unwrap(); + url.set_username("root").unwrap(); + url.set_path("/this_does_not_exist"); + + let url = url.as_str().to_string(); + let res = Quaint::new(&url).await; + + let err = res.unwrap_err(); + + match err.kind() { + ErrorKind::DatabaseDoesNotExist { db_name } => { + assert_eq!(Some("1049"), err.original_code()); + assert_eq!(Some("Unknown database \'this_does_not_exist\'"), err.original_message()); + assert_eq!(&Name::available("this_does_not_exist"), db_name) + } + e => panic!("Expected `DatabaseDoesNotExist`, got {:?}", e), + } + } + + #[tokio::test] + async fn should_map_wrong_credentials_error() { + let mut url = Url::parse(&CONN_STR).unwrap(); + url.set_username("WRONG").unwrap(); + + let res = Quaint::new(url.as_str()).await; + assert!(res.is_err()); + + let err = res.unwrap_err(); + assert!(matches!(err.kind(), ErrorKind::AuthenticationFailed { user } if user == &Name::available("WRONG"))); + } +} diff --git a/quaint/src/connector/mysql/wasm/common.rs b/quaint/src/connector/mysql/wasm/common.rs new file mode 100644 index 000000000000..fe60fd24cfc1 --- /dev/null +++ b/quaint/src/connector/mysql/wasm/common.rs @@ -0,0 +1,316 @@ +use crate::error::{Error, ErrorKind}; +use percent_encoding::percent_decode; +use std::{ + borrow::Cow, + path::{Path, PathBuf}, + time::Duration, +}; +use url::{Host, Url}; + +/// Wraps a connection url and exposes the parsing logic used by quaint, including default values. +#[derive(Debug, Clone)] +pub struct MysqlUrl { + url: Url, + pub(crate) query_params: MysqlUrlQueryParams, +} + +impl MysqlUrl { + /// Parse `Url` to `MysqlUrl`. Returns error for mistyped connection + /// parameters. + pub fn new(url: Url) -> Result { + let query_params = Self::parse_query_params(&url)?; + + Ok(Self { url, query_params }) + } + + /// The bare `Url` to the database. + pub fn url(&self) -> &Url { + &self.url + } + + /// The percent-decoded database username. + pub fn username(&self) -> Cow { + match percent_decode(self.url.username().as_bytes()).decode_utf8() { + Ok(username) => username, + Err(_) => { + tracing::warn!("Couldn't decode username to UTF-8, using the non-decoded version."); + + self.url.username().into() + } + } + } + + /// The percent-decoded database password. + pub fn password(&self) -> Option> { + match self + .url + .password() + .and_then(|pw| percent_decode(pw.as_bytes()).decode_utf8().ok()) + { + Some(password) => Some(password), + None => self.url.password().map(|s| s.into()), + } + } + + /// Name of the database connected. Defaults to `mysql`. + pub fn dbname(&self) -> &str { + match self.url.path_segments() { + Some(mut segments) => segments.next().unwrap_or("mysql"), + None => "mysql", + } + } + + /// The database host. If `socket` and `host` are not set, defaults to `localhost`. + pub fn host(&self) -> &str { + match (self.url.host(), self.url.host_str()) { + (Some(Host::Ipv6(_)), Some(host)) => { + // The `url` crate may return an IPv6 address in brackets, which must be stripped. + if host.starts_with('[') && host.ends_with(']') { + &host[1..host.len() - 1] + } else { + host + } + } + (_, Some(host)) => host, + _ => "localhost", + } + } + + /// If set, connected to the database through a Unix socket. + pub fn socket(&self) -> &Option { + &self.query_params.socket + } + + /// The database port, defaults to `3306`. + pub fn port(&self) -> u16 { + self.url.port().unwrap_or(3306) + } + + /// The connection timeout. + pub fn connect_timeout(&self) -> Option { + self.query_params.connect_timeout + } + + /// The pool check_out timeout + pub fn pool_timeout(&self) -> Option { + self.query_params.pool_timeout + } + + /// The socket timeout + pub fn socket_timeout(&self) -> Option { + self.query_params.socket_timeout + } + + /// Prefer socket connection + pub fn prefer_socket(&self) -> Option { + self.query_params.prefer_socket + } + + /// The maximum connection lifetime + pub fn max_connection_lifetime(&self) -> Option { + self.query_params.max_connection_lifetime + } + + /// The maximum idle connection lifetime + pub fn max_idle_connection_lifetime(&self) -> Option { + self.query_params.max_idle_connection_lifetime + } + + pub(crate) fn statement_cache_size(&self) -> usize { + self.query_params.statement_cache_size + } + + fn parse_query_params(url: &Url) -> Result { + #[cfg(feature = "mysql-connector")] + let mut ssl_opts = { + let mut ssl_opts = mysql_async::SslOpts::default(); + ssl_opts = ssl_opts.with_danger_accept_invalid_certs(true); + ssl_opts + }; + + let mut connection_limit = None; + let mut use_ssl = false; + let mut socket = None; + let mut socket_timeout = None; + let mut connect_timeout = Some(Duration::from_secs(5)); + let mut pool_timeout = Some(Duration::from_secs(10)); + let mut max_connection_lifetime = None; + let mut max_idle_connection_lifetime = Some(Duration::from_secs(300)); + let mut prefer_socket = None; + let mut statement_cache_size = 100; + let mut identity: Option<(Option, Option)> = None; + + for (k, v) in url.query_pairs() { + match k.as_ref() { + "connection_limit" => { + let as_int: usize = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + + connection_limit = Some(as_int); + } + "statement_cache_size" => { + statement_cache_size = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + } + "sslcert" => { + use_ssl = true; + + #[cfg(feature = "mysql-connector")] + { + ssl_opts = ssl_opts.with_root_cert_path(Some(Path::new(&*v).to_path_buf())); + } + } + "sslidentity" => { + use_ssl = true; + + identity = match identity { + Some((_, pw)) => Some((Some(Path::new(&*v).to_path_buf()), pw)), + None => Some((Some(Path::new(&*v).to_path_buf()), None)), + }; + } + "sslpassword" => { + use_ssl = true; + + identity = match identity { + Some((path, _)) => Some((path, Some(v.to_string()))), + None => Some((None, Some(v.to_string()))), + }; + } + "socket" => { + socket = Some(v.replace(['(', ')'], "")); + } + "socket_timeout" => { + let as_int = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + socket_timeout = Some(Duration::from_secs(as_int)); + } + "prefer_socket" => { + let as_bool = v + .parse::() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + prefer_socket = Some(as_bool) + } + "connect_timeout" => { + let as_int = v + .parse::() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + + connect_timeout = match as_int { + 0 => None, + _ => Some(Duration::from_secs(as_int)), + }; + } + "pool_timeout" => { + let as_int = v + .parse::() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + + pool_timeout = match as_int { + 0 => None, + _ => Some(Duration::from_secs(as_int)), + }; + } + "sslaccept" => { + use_ssl = true; + match v.as_ref() { + "strict" => { + #[cfg(feature = "mysql-connector")] + { + ssl_opts = ssl_opts.with_danger_accept_invalid_certs(false); + } + } + "accept_invalid_certs" => {} + _ => { + tracing::debug!( + message = "Unsupported SSL accept mode, defaulting to `accept_invalid_certs`", + mode = &*v + ); + } + }; + } + "max_connection_lifetime" => { + let as_int = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + + if as_int == 0 { + max_connection_lifetime = None; + } else { + max_connection_lifetime = Some(Duration::from_secs(as_int)); + } + } + "max_idle_connection_lifetime" => { + let as_int = v + .parse() + .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; + + if as_int == 0 { + max_idle_connection_lifetime = None; + } else { + max_idle_connection_lifetime = Some(Duration::from_secs(as_int)); + } + } + _ => { + tracing::trace!(message = "Discarding connection string param", param = &*k); + } + }; + } + + // Wrapping this in a block, as attributes on expressions are still experimental + // See: https://github.com/rust-lang/rust/issues/15701 + #[cfg(feature = "mysql-connector")] + { + ssl_opts = match identity { + Some((Some(path), Some(pw))) => { + let identity = mysql_async::ClientIdentity::new(path).with_password(pw); + ssl_opts.with_client_identity(Some(identity)) + } + Some((Some(path), None)) => { + let identity = mysql_async::ClientIdentity::new(path); + ssl_opts.with_client_identity(Some(identity)) + } + _ => ssl_opts, + }; + } + + Ok(MysqlUrlQueryParams { + #[cfg(feature = "mysql-connector")] + ssl_opts, + connection_limit, + use_ssl, + socket, + socket_timeout, + connect_timeout, + pool_timeout, + max_connection_lifetime, + max_idle_connection_lifetime, + prefer_socket, + statement_cache_size, + }) + } + + #[cfg(feature = "pooled")] + pub(crate) fn connection_limit(&self) -> Option { + self.query_params.connection_limit + } +} + +#[derive(Debug, Clone)] +pub(crate) struct MysqlUrlQueryParams { + pub(crate) connection_limit: Option, + pub(crate) use_ssl: bool, + pub(crate) socket: Option, + pub(crate) socket_timeout: Option, + pub(crate) connect_timeout: Option, + pub(crate) pool_timeout: Option, + pub(crate) max_connection_lifetime: Option, + pub(crate) max_idle_connection_lifetime: Option, + pub(crate) prefer_socket: Option, + pub(crate) statement_cache_size: usize, + + #[cfg(feature = "mysql-connector")] + pub(crate) ssl_opts: mysql_async::SslOpts, +} diff --git a/quaint/src/connector/mysql/error.rs b/quaint/src/connector/mysql/wasm/error.rs similarity index 81% rename from quaint/src/connector/mysql/error.rs rename to quaint/src/connector/mysql/wasm/error.rs index dd7c3d3bfa66..c09ec84d7a7b 100644 --- a/quaint/src/connector/mysql/error.rs +++ b/quaint/src/connector/mysql/wasm/error.rs @@ -1,5 +1,4 @@ use crate::error::{DatabaseConstraint, Error, ErrorKind}; -use mysql_async as my; pub struct MysqlError { pub code: u16, @@ -7,16 +6,6 @@ pub struct MysqlError { pub state: String, } -impl From<&my::ServerError> for MysqlError { - fn from(value: &my::ServerError) -> Self { - MysqlError { - code: value.code, - message: value.message.to_owned(), - state: value.state.to_owned(), - } - } -} - impl From for Error { fn from(error: MysqlError) -> Self { let code = error.code; @@ -230,43 +219,23 @@ impl From for Error { builder.set_original_message(error.message); builder.build() } - _ => { - let kind = ErrorKind::QueryError( - my::Error::Server(my::ServerError { - message: error.message.clone(), - code, - state: error.state.clone(), - }) - .into(), - ); - - let mut builder = Error::builder(kind); - builder.set_original_code(format!("{code}")); - builder.set_original_message(error.message); - - builder.build() - } - } - } -} - -impl From for Error { - fn from(e: my::Error) -> Error { - match e { - my::Error::Io(my::IoError::Tls(err)) => Error::builder(ErrorKind::TlsError { - message: err.to_string(), - }) - .build(), - my::Error::Io(my::IoError::Io(err)) if err.kind() == std::io::ErrorKind::UnexpectedEof => { - Error::builder(ErrorKind::ConnectionClosed).build() - } - my::Error::Io(io_error) => Error::builder(ErrorKind::ConnectionError(io_error.into())).build(), - my::Error::Driver(e) => Error::builder(ErrorKind::QueryError(e.into())).build(), - my::Error::Server(ref server_error) => { - let mysql_error: MysqlError = server_error.into(); - mysql_error.into() - } - e => Error::builder(ErrorKind::QueryError(e.into())).build(), + _ => unimplemented!(), + // _ => { + // let kind = ErrorKind::QueryError( + // my::Error::Server(my::ServerError { + // message: error.message.clone(), + // code, + // state: error.state.clone(), + // }) + // .into(), + // ); + + // let mut builder = Error::builder(kind); + // builder.set_original_code(format!("{code}")); + // builder.set_original_message(error.message); + + // builder.build() + // } } } } diff --git a/quaint/src/connector/mysql/wasm/mod.rs b/quaint/src/connector/mysql/wasm/mod.rs new file mode 100644 index 000000000000..da9a57a53876 --- /dev/null +++ b/quaint/src/connector/mysql/wasm/mod.rs @@ -0,0 +1,6 @@ +///! Wasm-compatible definitions for the MySQL connector. +/// /// This module is only available with the `mysql` feature. +pub(crate) mod common; +pub mod error; + +pub use common::MysqlUrl; diff --git a/quaint/src/connector/postgres.rs b/quaint/src/connector/postgres.rs index 9f4d4d496f2b..0f4da84a7c67 100644 --- a/quaint/src/connector/postgres.rs +++ b/quaint/src/connector/postgres.rs @@ -1,3 +1,4 @@ +pub use wasm::common::PostgresUrl; pub use wasm::error::PostgresError; #[cfg(feature = "postgresql")] diff --git a/quaint/src/connector/postgres/native/error.rs b/quaint/src/connector/postgres/native/error.rs index ec3b18483746..05b792e27900 100644 --- a/quaint/src/connector/postgres/native/error.rs +++ b/quaint/src/connector/postgres/native/error.rs @@ -1,7 +1,7 @@ use tokio_postgres::error::DbError; use crate::{ - connector::error::PostgresError, + connector::postgres::wasm::error::PostgresError, error::{Error, ErrorKind}, }; diff --git a/quaint/src/connector/postgres_wasm.rs b/quaint/src/connector/postgres_wasm.rs deleted file mode 100644 index 4c67b98cfa42..000000000000 --- a/quaint/src/connector/postgres_wasm.rs +++ /dev/null @@ -1,612 +0,0 @@ -use std::{ - borrow::Cow, - fmt::{Debug, Display}, - time::Duration, -}; - -use percent_encoding::percent_decode; -use url::{Host, Url}; - -use crate::error::{Error, ErrorKind}; - -#[cfg(feature = "postgresql-connector")] -use tokio_postgres::config::{ChannelBinding, SslMode}; - -#[derive(Clone)] -pub(crate) struct Hidden(pub(crate) T); - -impl Debug for Hidden { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_str("") - } -} - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum SslAcceptMode { - Strict, - AcceptInvalidCerts, -} - -#[derive(Debug, Clone)] -pub struct SslParams { - pub(super) certificate_file: Option, - pub(super) identity_file: Option, - pub(super) identity_password: Hidden>, - pub(super) ssl_accept_mode: SslAcceptMode, -} - -#[derive(Debug, Clone, Copy)] -pub enum PostgresFlavour { - Postgres, - Cockroach, - Unknown, -} - -impl PostgresFlavour { - /// Returns `true` if the postgres flavour is [`Postgres`]. - /// - /// [`Postgres`]: PostgresFlavour::Postgres - pub(super) fn is_postgres(&self) -> bool { - matches!(self, Self::Postgres) - } - - /// Returns `true` if the postgres flavour is [`Cockroach`]. - /// - /// [`Cockroach`]: PostgresFlavour::Cockroach - pub(super) fn is_cockroach(&self) -> bool { - matches!(self, Self::Cockroach) - } - - /// Returns `true` if the postgres flavour is [`Unknown`]. - /// - /// [`Unknown`]: PostgresFlavour::Unknown - pub(super) fn is_unknown(&self) -> bool { - matches!(self, Self::Unknown) - } -} - -/// Wraps a connection url and exposes the parsing logic used by Quaint, -/// including default values. -#[derive(Debug, Clone)] -pub struct PostgresUrl { - pub(super) url: Url, - pub(super) query_params: PostgresUrlQueryParams, - pub(super) flavour: PostgresFlavour, -} - -pub(crate) const DEFAULT_SCHEMA: &str = "public"; - -impl PostgresUrl { - /// Parse `Url` to `PostgresUrl`. Returns error for mistyped connection - /// parameters. - pub fn new(url: Url) -> Result { - let query_params = Self::parse_query_params(&url)?; - - Ok(Self { - url, - query_params, - flavour: PostgresFlavour::Unknown, - }) - } - - /// The bare `Url` to the database. - pub fn url(&self) -> &Url { - &self.url - } - - /// The percent-decoded database username. - pub fn username(&self) -> Cow { - match percent_decode(self.url.username().as_bytes()).decode_utf8() { - Ok(username) => username, - Err(_) => { - tracing::warn!("Couldn't decode username to UTF-8, using the non-decoded version."); - - self.url.username().into() - } - } - } - - /// The database host. Taken first from the `host` query parameter, then - /// from the `host` part of the URL. For socket connections, the query - /// parameter must be used. - /// - /// If none of them are set, defaults to `localhost`. - pub fn host(&self) -> &str { - match (self.query_params.host.as_ref(), self.url.host_str(), self.url.host()) { - (Some(host), _, _) => host.as_str(), - (None, Some(""), _) => "localhost", - (None, None, _) => "localhost", - (None, Some(host), Some(Host::Ipv6(_))) => { - // The `url` crate may return an IPv6 address in brackets, which must be stripped. - if host.starts_with('[') && host.ends_with(']') { - &host[1..host.len() - 1] - } else { - host - } - } - (None, Some(host), _) => host, - } - } - - /// Name of the database connected. Defaults to `postgres`. - pub fn dbname(&self) -> &str { - match self.url.path_segments() { - Some(mut segments) => segments.next().unwrap_or("postgres"), - None => "postgres", - } - } - - /// The percent-decoded database password. - pub fn password(&self) -> Cow { - match self - .url - .password() - .and_then(|pw| percent_decode(pw.as_bytes()).decode_utf8().ok()) - { - Some(password) => password, - None => self.url.password().unwrap_or("").into(), - } - } - - /// The database port, defaults to `5432`. - pub fn port(&self) -> u16 { - self.url.port().unwrap_or(5432) - } - - /// The database schema, defaults to `public`. - pub fn schema(&self) -> &str { - self.query_params.schema.as_deref().unwrap_or(DEFAULT_SCHEMA) - } - - /// Whether the pgbouncer mode is enabled. - pub fn pg_bouncer(&self) -> bool { - self.query_params.pg_bouncer - } - - /// The connection timeout. - pub fn connect_timeout(&self) -> Option { - self.query_params.connect_timeout - } - - /// Pool check_out timeout - pub fn pool_timeout(&self) -> Option { - self.query_params.pool_timeout - } - - /// The socket timeout - pub fn socket_timeout(&self) -> Option { - self.query_params.socket_timeout - } - - /// The maximum connection lifetime - pub fn max_connection_lifetime(&self) -> Option { - self.query_params.max_connection_lifetime - } - - /// The maximum idle connection lifetime - pub fn max_idle_connection_lifetime(&self) -> Option { - self.query_params.max_idle_connection_lifetime - } - - /// The custom application name - pub fn application_name(&self) -> Option<&str> { - self.query_params.application_name.as_deref() - } - - pub(crate) fn options(&self) -> Option<&str> { - self.query_params.options.as_deref() - } - - /// Sets whether the URL points to a Postgres, Cockroach or Unknown database. - /// This is used to avoid a network roundtrip at connection to set the search path. - /// - /// The different behaviours are: - /// - Postgres: Always avoid a network roundtrip by setting the search path through client connection parameters. - /// - Cockroach: Avoid a network roundtrip if the schema name is deemed "safe" (i.e. no escape quoting required). Otherwise, set the search path through a database query. - /// - Unknown: Always add a network roundtrip by setting the search path through a database query. - pub fn set_flavour(&mut self, flavour: PostgresFlavour) { - self.flavour = flavour; - } - - fn parse_query_params(url: &Url) -> Result { - #[cfg(feature = "postgresql-connector")] - let mut ssl_mode = SslMode::Prefer; - #[cfg(feature = "postgresql-connector")] - let mut channel_binding = ChannelBinding::Prefer; - - let mut connection_limit = None; - let mut schema = None; - let mut certificate_file = None; - let mut identity_file = None; - let mut identity_password = None; - let mut ssl_accept_mode = SslAcceptMode::AcceptInvalidCerts; - let mut host = None; - let mut application_name = None; - let mut socket_timeout = None; - let mut connect_timeout = Some(Duration::from_secs(5)); - let mut pool_timeout = Some(Duration::from_secs(10)); - let mut pg_bouncer = false; - let mut statement_cache_size = 100; - let mut max_connection_lifetime = None; - let mut max_idle_connection_lifetime = Some(Duration::from_secs(300)); - let mut options = None; - - for (k, v) in url.query_pairs() { - match k.as_ref() { - "pgbouncer" => { - pg_bouncer = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - } - #[cfg(feature = "postgresql-connector")] - "sslmode" => { - match v.as_ref() { - "disable" => ssl_mode = SslMode::Disable, - "prefer" => ssl_mode = SslMode::Prefer, - "require" => ssl_mode = SslMode::Require, - _ => { - tracing::debug!(message = "Unsupported SSL mode, defaulting to `prefer`", mode = &*v); - } - }; - } - "sslcert" => { - certificate_file = Some(v.to_string()); - } - "sslidentity" => { - identity_file = Some(v.to_string()); - } - "sslpassword" => { - identity_password = Some(v.to_string()); - } - "statement_cache_size" => { - statement_cache_size = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - } - "sslaccept" => { - match v.as_ref() { - "strict" => { - ssl_accept_mode = SslAcceptMode::Strict; - } - "accept_invalid_certs" => { - ssl_accept_mode = SslAcceptMode::AcceptInvalidCerts; - } - _ => { - tracing::debug!( - message = "Unsupported SSL accept mode, defaulting to `strict`", - mode = &*v - ); - - ssl_accept_mode = SslAcceptMode::Strict; - } - }; - } - "schema" => { - schema = Some(v.to_string()); - } - "connection_limit" => { - let as_int: usize = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - connection_limit = Some(as_int); - } - "host" => { - host = Some(v.to_string()); - } - "socket_timeout" => { - let as_int = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - socket_timeout = Some(Duration::from_secs(as_int)); - } - "connect_timeout" => { - let as_int = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - - if as_int == 0 { - connect_timeout = None; - } else { - connect_timeout = Some(Duration::from_secs(as_int)); - } - } - "pool_timeout" => { - let as_int = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - - if as_int == 0 { - pool_timeout = None; - } else { - pool_timeout = Some(Duration::from_secs(as_int)); - } - } - "max_connection_lifetime" => { - let as_int = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - - if as_int == 0 { - max_connection_lifetime = None; - } else { - max_connection_lifetime = Some(Duration::from_secs(as_int)); - } - } - "max_idle_connection_lifetime" => { - let as_int = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - - if as_int == 0 { - max_idle_connection_lifetime = None; - } else { - max_idle_connection_lifetime = Some(Duration::from_secs(as_int)); - } - } - "application_name" => { - application_name = Some(v.to_string()); - } - #[cfg(feature = "postgresql-connector")] - "channel_binding" => { - match v.as_ref() { - "disable" => channel_binding = ChannelBinding::Disable, - "prefer" => channel_binding = ChannelBinding::Prefer, - "require" => channel_binding = ChannelBinding::Require, - _ => { - tracing::debug!( - message = "Unsupported Channel Binding {channel_binding}, defaulting to `prefer`", - channel_binding = &*v - ); - } - }; - } - "options" => { - options = Some(v.to_string()); - } - _ => { - tracing::trace!(message = "Discarding connection string param", param = &*k); - } - }; - } - - Ok(PostgresUrlQueryParams { - ssl_params: SslParams { - certificate_file, - identity_file, - ssl_accept_mode, - identity_password: Hidden(identity_password), - }, - connection_limit, - schema, - host, - connect_timeout, - pool_timeout, - socket_timeout, - pg_bouncer, - statement_cache_size, - max_connection_lifetime, - max_idle_connection_lifetime, - application_name, - options, - #[cfg(feature = "postgresql-connector")] - channel_binding, - #[cfg(feature = "postgresql-connector")] - ssl_mode, - }) - } - - pub(crate) fn ssl_params(&self) -> &SslParams { - &self.query_params.ssl_params - } - - #[cfg(feature = "pooled")] - pub(crate) fn connection_limit(&self) -> Option { - self.query_params.connection_limit - } - - pub fn flavour(&self) -> PostgresFlavour { - self.flavour - } -} - -#[derive(Debug, Clone)] -pub(crate) struct PostgresUrlQueryParams { - pub(crate) ssl_params: SslParams, - pub(crate) connection_limit: Option, - pub(crate) schema: Option, - pub(crate) pg_bouncer: bool, - pub(crate) host: Option, - pub(crate) socket_timeout: Option, - pub(crate) connect_timeout: Option, - pub(crate) pool_timeout: Option, - pub(crate) statement_cache_size: usize, - pub(crate) max_connection_lifetime: Option, - pub(crate) max_idle_connection_lifetime: Option, - pub(crate) application_name: Option, - pub(crate) options: Option, - - #[cfg(feature = "postgresql-connector")] - pub(crate) channel_binding: ChannelBinding, - - #[cfg(feature = "postgresql-connector")] - pub(crate) ssl_mode: SslMode, -} - -// A SearchPath connection parameter (Display-impl) for connection initialization. -struct CockroachSearchPath<'a>(&'a str); - -impl Display for CockroachSearchPath<'_> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_str(self.0) - } -} - -// A SearchPath connection parameter (Display-impl) for connection initialization. -struct PostgresSearchPath<'a>(&'a str); - -impl Display for PostgresSearchPath<'_> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.write_str("\"")?; - f.write_str(self.0)?; - f.write_str("\"")?; - - Ok(()) - } -} - -// A SetSearchPath statement (Display-impl) for connection initialization. -struct SetSearchPath<'a>(Option<&'a str>); - -impl Display for SetSearchPath<'_> { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - if let Some(schema) = self.0 { - f.write_str("SET search_path = \"")?; - f.write_str(schema)?; - f.write_str("\";\n")?; - } - - Ok(()) - } -} - -/// Sorted list of CockroachDB's reserved keywords. -/// Taken from https://www.cockroachlabs.com/docs/stable/keywords-and-identifiers.html#keywords -const RESERVED_KEYWORDS: [&str; 79] = [ - "all", - "analyse", - "analyze", - "and", - "any", - "array", - "as", - "asc", - "asymmetric", - "both", - "case", - "cast", - "check", - "collate", - "column", - "concurrently", - "constraint", - "create", - "current_catalog", - "current_date", - "current_role", - "current_schema", - "current_time", - "current_timestamp", - "current_user", - "default", - "deferrable", - "desc", - "distinct", - "do", - "else", - "end", - "except", - "false", - "fetch", - "for", - "foreign", - "from", - "grant", - "group", - "having", - "in", - "initially", - "intersect", - "into", - "lateral", - "leading", - "limit", - "localtime", - "localtimestamp", - "not", - "null", - "offset", - "on", - "only", - "or", - "order", - "placing", - "primary", - "references", - "returning", - "select", - "session_user", - "some", - "symmetric", - "table", - "then", - "to", - "trailing", - "true", - "union", - "unique", - "user", - "using", - "variadic", - "when", - "where", - "window", - "with", -]; - -/// Sorted list of CockroachDB's reserved type function names. -/// Taken from https://www.cockroachlabs.com/docs/stable/keywords-and-identifiers.html#keywords -const RESERVED_TYPE_FUNCTION_NAMES: [&str; 18] = [ - "authorization", - "collation", - "cross", - "full", - "ilike", - "inner", - "is", - "isnull", - "join", - "left", - "like", - "natural", - "none", - "notnull", - "outer", - "overlaps", - "right", - "similar", -]; - -/// Returns true if a Postgres identifier is considered "safe". -/// -/// In this context, "safe" means that the value of an identifier would be the same quoted and unquoted or that it's not part of reserved keywords. In other words, that it does _not_ need to be quoted. -/// -/// Spec can be found here: https://www.postgresql.org/docs/current/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS -/// or here: https://www.cockroachlabs.com/docs/stable/keywords-and-identifiers.html#rules-for-identifiers -fn is_safe_identifier(ident: &str) -> bool { - if ident.is_empty() { - return false; - } - - // 1. Not equal any SQL keyword unless the keyword is accepted by the element's syntax. For example, name accepts Unreserved or Column Name keywords. - if RESERVED_KEYWORDS.binary_search(&ident).is_ok() || RESERVED_TYPE_FUNCTION_NAMES.binary_search(&ident).is_ok() { - return false; - } - - let mut chars = ident.chars(); - - let first = chars.next().unwrap(); - - // 2. SQL identifiers must begin with a letter (a-z, but also letters with diacritical marks and non-Latin letters) or an underscore (_). - if (!first.is_alphabetic() || !first.is_lowercase()) && first != '_' { - return false; - } - - for c in chars { - // 3. Subsequent characters in an identifier can be letters, underscores, digits (0-9), or dollar signs ($). - if (!c.is_alphabetic() || !c.is_lowercase()) && c != '_' && !c.is_ascii_digit() && c != '$' { - return false; - } - } - - true -} diff --git a/quaint/src/error.rs b/quaint/src/error.rs index 73bf5c405c66..f8202b030466 100644 --- a/quaint/src/error.rs +++ b/quaint/src/error.rs @@ -6,7 +6,7 @@ use thiserror::Error; #[cfg(feature = "pooled")] use std::time::Duration; -// pub use crate::connector::mysql::MysqlError; +pub use crate::connector::mysql::MysqlError; pub use crate::connector::postgres::PostgresError; // pub use crate::connector::sqlite::SqliteError; From 060486d74e7525d7cd61c51accfb2604d629ad75 Mon Sep 17 00:00:00 2001 From: jkomyno Date: Fri, 10 Nov 2023 16:05:41 +0100 Subject: [PATCH 04/34] feat(quaint): recover wasm error for mysql --- quaint/src/connector/mysql/wasm/error.rs | 43 ++++++++++++++---------- 1 file changed, 26 insertions(+), 17 deletions(-) diff --git a/quaint/src/connector/mysql/wasm/error.rs b/quaint/src/connector/mysql/wasm/error.rs index c09ec84d7a7b..615f0c69dda4 100644 --- a/quaint/src/connector/mysql/wasm/error.rs +++ b/quaint/src/connector/mysql/wasm/error.rs @@ -1,5 +1,15 @@ use crate::error::{DatabaseConstraint, Error, ErrorKind}; +use thiserror::Error; +#[derive(Debug, Error)] +enum MysqlAsyncError { + #[error("Server error: `{}'", _0)] + Server(#[source] MysqlError), +} + +/// This type represents MySql server error. +#[derive(Debug, Error, Clone, Eq, PartialEq)] +#[error("ERROR {} ({}): {}", state, code, message)] pub struct MysqlError { pub code: u16, pub message: String, @@ -219,23 +229,22 @@ impl From for Error { builder.set_original_message(error.message); builder.build() } - _ => unimplemented!(), - // _ => { - // let kind = ErrorKind::QueryError( - // my::Error::Server(my::ServerError { - // message: error.message.clone(), - // code, - // state: error.state.clone(), - // }) - // .into(), - // ); - - // let mut builder = Error::builder(kind); - // builder.set_original_code(format!("{code}")); - // builder.set_original_message(error.message); - - // builder.build() - // } + _ => { + let kind = ErrorKind::QueryError( + MysqlAsyncError::Server(MysqlError { + message: error.message.clone(), + code, + state: error.state.clone(), + }) + .into(), + ); + + let mut builder = Error::builder(kind); + builder.set_original_code(format!("{code}")); + builder.set_original_message(error.message); + + builder.build() + } } } } From 5de1dc0c34513191b1e635a5b4a13c9236fe1399 Mon Sep 17 00:00:00 2001 From: jkomyno Date: Fri, 10 Nov 2023 18:19:23 +0100 Subject: [PATCH 05/34] feat(quaint): split mssql connector into native and wasm submodules --- quaint/src/connector.rs | 23 +- quaint/src/connector/mssql.rs | 256 +------------- .../src/connector/mssql/native/conversion.rs | 87 +++++ quaint/src/connector/mssql/native/error.rs | 247 ++++++++++++++ quaint/src/connector/mssql/native/mod.rs | 253 ++++++++++++++ .../{mssql_wasm.rs => mssql/wasm/common.rs} | 45 ++- quaint/src/connector/mssql/wasm/mod.rs | 5 + quaint/src/connector/mysql_wasm.rs | 318 ------------------ 8 files changed, 634 insertions(+), 600 deletions(-) create mode 100644 quaint/src/connector/mssql/native/conversion.rs create mode 100644 quaint/src/connector/mssql/native/error.rs create mode 100644 quaint/src/connector/mssql/native/mod.rs rename quaint/src/connector/{mssql_wasm.rs => mssql/wasm/common.rs} (91%) create mode 100644 quaint/src/connector/mssql/wasm/mod.rs delete mode 100644 quaint/src/connector/mysql_wasm.rs diff --git a/quaint/src/connector.rs b/quaint/src/connector.rs index d0e4d7e57bdc..32f9e6186890 100644 --- a/quaint/src/connector.rs +++ b/quaint/src/connector.rs @@ -23,10 +23,10 @@ mod timeout; mod transaction; mod type_identifier; -#[cfg(feature = "mssql-connector")] -pub(crate) mod mssql; -#[cfg(feature = "mssql")] -pub(crate) mod mssql_wasm; +// #[cfg(feature = "mssql-connector")] +// pub(crate) mod mssql; +// #[cfg(feature = "mssql")] +// pub(crate) mod mssql_wasm; // #[cfg(feature = "mysql-connector")] // pub(crate) mod mysql; // #[cfg(feature = "mysql")] @@ -48,10 +48,10 @@ pub(crate) mod sqlite_wasm; // pub use self::postgres::*; // #[cfg(feature = "postgresql")] // pub use self::postgres_wasm::*; -#[cfg(feature = "mssql-connector")] -pub use mssql::*; -#[cfg(feature = "mssql")] -pub use mssql_wasm::*; +// #[cfg(feature = "mssql-connector")] +// pub use mssql::*; +// #[cfg(feature = "mssql")] +// pub use mssql_wasm::*; #[cfg(feature = "sqlite-connector")] pub use sqlite::*; #[cfg(feature = "sqlite")] @@ -84,3 +84,10 @@ pub(crate) mod mysql; pub use mysql::native::*; #[cfg(feature = "mysql")] pub use mysql::wasm::common::*; + +#[cfg(feature = "mssql")] +pub(crate) mod mssql; +#[cfg(feature = "mssql-connector")] +pub use mssql::native::*; +#[cfg(feature = "mssql")] +pub use mssql::wasm::common::*; diff --git a/quaint/src/connector/mssql.rs b/quaint/src/connector/mssql.rs index 16c31551768c..ea681bd08d18 100644 --- a/quaint/src/connector/mssql.rs +++ b/quaint/src/connector/mssql.rs @@ -1,253 +1,7 @@ -mod conversion; -mod error; +pub use wasm::common::MssqlUrl; -pub(crate) use super::mssql_wasm::MssqlUrl; -use super::{IsolationLevel, Transaction, TransactionOptions}; -use crate::{ - ast::{Query, Value}, - connector::{metrics, queryable::*, DefaultTransaction, ResultSet}, - visitor::{self, Visitor}, -}; -use async_trait::async_trait; -use futures::lock::Mutex; -use std::{ - convert::TryFrom, - fmt, - future::Future, - sync::atomic::{AtomicBool, Ordering}, - time::Duration, -}; -use tiberius::*; -use tokio::net::TcpStream; -use tokio_util::compat::{Compat, TokioAsyncWriteCompatExt}; +#[cfg(feature = "mssql")] +pub(crate) mod wasm; -/// The underlying SQL Server driver. Only available with the `expose-drivers` Cargo feature. -#[cfg(feature = "expose-drivers")] -pub use tiberius; - -static SQL_SERVER_DEFAULT_ISOLATION: IsolationLevel = IsolationLevel::ReadCommitted; - -#[async_trait] -impl TransactionCapable for Mssql { - async fn start_transaction<'a>( - &'a self, - isolation: Option, - ) -> crate::Result> { - // Isolation levels in SQL Server are set on the connection and live until they're changed. - // Always explicitly setting the isolation level each time a tx is started (either to the given value - // or by using the default/connection string value) prevents transactions started on connections from - // the pool to have unexpected isolation levels set. - let isolation = isolation - .or(self.url.query_params.transaction_isolation_level) - .or(Some(SQL_SERVER_DEFAULT_ISOLATION)); - - let opts = TransactionOptions::new(isolation, self.requires_isolation_first()); - - Ok(Box::new( - DefaultTransaction::new(self, self.begin_statement(), opts).await?, - )) - } -} - -/// A connector interface for the SQL Server database. -#[derive(Debug)] -pub struct Mssql { - client: Mutex>>, - url: MssqlUrl, - socket_timeout: Option, - is_healthy: AtomicBool, -} - -impl Mssql { - /// Creates a new connection to SQL Server. - pub async fn new(url: MssqlUrl) -> crate::Result { - let config = Config::from_jdbc_string(&url.connection_string)?; - let tcp = TcpStream::connect_named(&config).await?; - let socket_timeout = url.socket_timeout(); - - let connecting = async { - match Client::connect(config, tcp.compat_write()).await { - Ok(client) => Ok(client), - Err(tiberius::error::Error::Routing { host, port }) => { - let mut config = Config::from_jdbc_string(&url.connection_string)?; - config.host(host); - config.port(port); - - let tcp = TcpStream::connect_named(&config).await?; - Client::connect(config, tcp.compat_write()).await - } - Err(e) => Err(e), - } - }; - - let client = super::timeout::connect(url.connect_timeout(), connecting).await?; - - let this = Self { - client: Mutex::new(client), - url, - socket_timeout, - is_healthy: AtomicBool::new(true), - }; - - if let Some(isolation) = this.url.transaction_isolation_level() { - this.raw_cmd(&format!("SET TRANSACTION ISOLATION LEVEL {isolation}")) - .await?; - }; - - Ok(this) - } - - /// The underlying Tiberius client. Only available with the `expose-drivers` Cargo feature. - /// This is a lower level API when you need to get into database specific features. - #[cfg(feature = "expose-drivers")] - pub fn client(&self) -> &Mutex>> { - &self.client - } - - async fn perform_io(&self, fut: F) -> crate::Result - where - F: Future>, - { - match super::timeout::socket(self.socket_timeout, fut).await { - Err(e) if e.is_closed() => { - self.is_healthy.store(false, Ordering::SeqCst); - Err(e) - } - res => res, - } - } -} - -#[async_trait] -impl Queryable for Mssql { - async fn query(&self, q: Query<'_>) -> crate::Result { - let (sql, params) = visitor::Mssql::build(q)?; - self.query_raw(&sql, ¶ms[..]).await - } - - async fn query_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - metrics::query("mssql.query_raw", sql, params, move || async move { - let mut client = self.client.lock().await; - - let mut query = tiberius::Query::new(sql); - - for param in params { - query.bind(param); - } - - let mut results = self.perform_io(query.query(&mut client)).await?.into_results().await?; - - match results.pop() { - Some(rows) => { - let mut columns_set = false; - let mut columns = Vec::new(); - let mut result_rows = Vec::with_capacity(rows.len()); - - for row in rows.into_iter() { - if !columns_set { - columns = row.columns().iter().map(|c| c.name().to_string()).collect(); - columns_set = true; - } - - let mut values: Vec> = Vec::with_capacity(row.len()); - - for val in row.into_iter() { - values.push(Value::try_from(val)?); - } - - result_rows.push(values); - } - - Ok(ResultSet::new(columns, result_rows)) - } - None => Ok(ResultSet::new(Vec::new(), Vec::new())), - } - }) - .await - } - - async fn query_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - self.query_raw(sql, params).await - } - - async fn execute(&self, q: Query<'_>) -> crate::Result { - let (sql, params) = visitor::Mssql::build(q)?; - self.execute_raw(&sql, ¶ms[..]).await - } - - async fn execute_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - metrics::query("mssql.execute_raw", sql, params, move || async move { - let mut query = tiberius::Query::new(sql); - - for param in params { - query.bind(param); - } - - let mut client = self.client.lock().await; - let changes = self.perform_io(query.execute(&mut client)).await?.total(); - - Ok(changes) - }) - .await - } - - async fn execute_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - self.execute_raw(sql, params).await - } - - async fn raw_cmd(&self, cmd: &str) -> crate::Result<()> { - metrics::query("mssql.raw_cmd", cmd, &[], move || async move { - let mut client = self.client.lock().await; - self.perform_io(client.simple_query(cmd)).await?.into_results().await?; - Ok(()) - }) - .await - } - - async fn version(&self) -> crate::Result> { - let query = r#"SELECT @@VERSION AS version"#; - let rows = self.query_raw(query, &[]).await?; - - let version_string = rows - .get(0) - .and_then(|row| row.get("version").and_then(|version| version.to_string())); - - Ok(version_string) - } - - fn is_healthy(&self) -> bool { - self.is_healthy.load(Ordering::SeqCst) - } - - async fn set_tx_isolation_level(&self, isolation_level: IsolationLevel) -> crate::Result<()> { - self.raw_cmd(&format!("SET TRANSACTION ISOLATION LEVEL {isolation_level}")) - .await?; - - Ok(()) - } - - fn begin_statement(&self) -> &'static str { - "BEGIN TRAN" - } - - fn requires_isolation_first(&self) -> bool { - true - } -} - -#[cfg(test)] -mod tests { - use crate::tests::test_api::mssql::CONN_STR; - use crate::{error::*, single::Quaint}; - - #[tokio::test] - async fn should_map_wrong_credentials_error() { - let url = CONN_STR.replace("user=SA", "user=WRONG"); - - let res = Quaint::new(url.as_str()).await; - assert!(res.is_err()); - - let err = res.unwrap_err(); - assert!(matches!(err.kind(), ErrorKind::AuthenticationFailed { user } if user == &Name::available("WRONG"))); - } -} +#[cfg(feature = "mssql-connector")] +pub(crate) mod native; diff --git a/quaint/src/connector/mssql/native/conversion.rs b/quaint/src/connector/mssql/native/conversion.rs new file mode 100644 index 000000000000..870654ad5de3 --- /dev/null +++ b/quaint/src/connector/mssql/native/conversion.rs @@ -0,0 +1,87 @@ +use crate::ast::{Value, ValueType}; + +use bigdecimal::BigDecimal; +use std::{borrow::Cow, convert::TryFrom}; + +use tiberius::ToSql; +use tiberius::{ColumnData, FromSql, IntoSql}; + +impl<'a> IntoSql<'a> for &'a Value<'a> { + fn into_sql(self) -> ColumnData<'a> { + match &self.typed { + ValueType::Int32(val) => val.into_sql(), + ValueType::Int64(val) => val.into_sql(), + ValueType::Float(val) => val.into_sql(), + ValueType::Double(val) => val.into_sql(), + ValueType::Text(val) => val.as_deref().into_sql(), + ValueType::Bytes(val) => val.as_deref().into_sql(), + ValueType::Enum(val, _) => val.as_deref().into_sql(), + ValueType::Boolean(val) => val.into_sql(), + ValueType::Char(val) => val.as_ref().map(|val| format!("{val}")).into_sql(), + ValueType::Xml(val) => val.as_deref().into_sql(), + ValueType::Array(_) | ValueType::EnumArray(_, _) => panic!("Arrays are not supported on SQL Server."), + ValueType::Numeric(val) => (*val).to_sql(), + ValueType::Json(val) => val.as_ref().map(|val| serde_json::to_string(&val).unwrap()).into_sql(), + ValueType::Uuid(val) => val.into_sql(), + ValueType::DateTime(val) => val.into_sql(), + ValueType::Date(val) => val.into_sql(), + ValueType::Time(val) => val.into_sql(), + } + } +} + +impl TryFrom> for Value<'static> { + type Error = crate::error::Error; + + fn try_from(cd: ColumnData<'static>) -> crate::Result { + let res = match cd { + ColumnData::U8(num) => ValueType::Int32(num.map(i32::from)), + ColumnData::I16(num) => ValueType::Int32(num.map(i32::from)), + ColumnData::I32(num) => ValueType::Int32(num.map(i32::from)), + ColumnData::I64(num) => ValueType::Int64(num.map(i64::from)), + ColumnData::F32(num) => ValueType::Float(num), + ColumnData::F64(num) => ValueType::Double(num), + ColumnData::Bit(b) => ValueType::Boolean(b), + ColumnData::String(s) => ValueType::Text(s), + ColumnData::Guid(uuid) => ValueType::Uuid(uuid), + ColumnData::Binary(bytes) => ValueType::Bytes(bytes), + numeric @ ColumnData::Numeric(_) => ValueType::Numeric(BigDecimal::from_sql(&numeric)?), + dt @ ColumnData::DateTime(_) => { + use tiberius::time::chrono::{DateTime, NaiveDateTime, Utc}; + + let dt = NaiveDateTime::from_sql(&dt)?.map(|dt| DateTime::::from_utc(dt, Utc)); + ValueType::DateTime(dt) + } + dt @ ColumnData::SmallDateTime(_) => { + use tiberius::time::chrono::{DateTime, NaiveDateTime, Utc}; + + let dt = NaiveDateTime::from_sql(&dt)?.map(|dt| DateTime::::from_utc(dt, Utc)); + ValueType::DateTime(dt) + } + dt @ ColumnData::Time(_) => { + use tiberius::time::chrono::NaiveTime; + + ValueType::Time(NaiveTime::from_sql(&dt)?) + } + dt @ ColumnData::Date(_) => { + use tiberius::time::chrono::NaiveDate; + ValueType::Date(NaiveDate::from_sql(&dt)?) + } + dt @ ColumnData::DateTime2(_) => { + use tiberius::time::chrono::{DateTime, NaiveDateTime, Utc}; + + let dt = NaiveDateTime::from_sql(&dt)?.map(|dt| DateTime::::from_utc(dt, Utc)); + + ValueType::DateTime(dt) + } + dt @ ColumnData::DateTimeOffset(_) => { + use tiberius::time::chrono::{DateTime, Utc}; + + ValueType::DateTime(DateTime::::from_sql(&dt)?) + } + ColumnData::Xml(cow) => ValueType::Xml(cow.map(|xml_data| Cow::Owned(xml_data.into_owned().into_string()))), + }; + + Ok(Value::from(res)) + } +} diff --git a/quaint/src/connector/mssql/native/error.rs b/quaint/src/connector/mssql/native/error.rs new file mode 100644 index 000000000000..f9b6f5e95ab6 --- /dev/null +++ b/quaint/src/connector/mssql/native/error.rs @@ -0,0 +1,247 @@ +use crate::error::{DatabaseConstraint, Error, ErrorKind}; +use tiberius::error::IoErrorKind; + +impl From for Error { + fn from(e: tiberius::error::Error) -> Error { + match e { + tiberius::error::Error::Io { + kind: IoErrorKind::UnexpectedEof, + message, + } => { + let mut builder = Error::builder(ErrorKind::ConnectionClosed); + builder.set_original_message(message); + builder.build() + } + e @ tiberius::error::Error::Io { .. } => Error::builder(ErrorKind::ConnectionError(e.into())).build(), + tiberius::error::Error::Tls(message) => { + let message = format!( + "The TLS settings didn't allow the connection to be established. Please review your connection string. (error: {message})" + ); + + Error::builder(ErrorKind::TlsError { message }).build() + } + tiberius::error::Error::Server(e) if [3902u32, 3903u32, 3971u32].iter().any(|code| e.code() == *code) => { + let kind = ErrorKind::TransactionAlreadyClosed(e.message().to_string()); + + let mut builder = Error::builder(kind); + + builder.set_original_code(format!("{}", e.code())); + builder.set_original_message(e.message().to_string()); + + builder.build() + } + tiberius::error::Error::Server(e) if e.code() == 8169 => { + let kind = ErrorKind::conversion(e.message().to_string()); + + let mut builder = Error::builder(kind); + + builder.set_original_code(format!("{}", e.code())); + builder.set_original_message(e.message().to_string()); + + builder.build() + } + tiberius::error::Error::Server(e) if e.code() == 18456 => { + let user = e.message().split('\'').nth(1).into(); + let kind = ErrorKind::AuthenticationFailed { user }; + + let mut builder = Error::builder(kind); + + builder.set_original_code(format!("{}", e.code())); + builder.set_original_message(e.message().to_string()); + + builder.build() + } + tiberius::error::Error::Server(e) if e.code() == 4060 => { + let db_name = e.message().split('"').nth(1).into(); + let kind = ErrorKind::DatabaseDoesNotExist { db_name }; + let mut builder = Error::builder(kind); + + builder.set_original_code(format!("{}", e.code())); + builder.set_original_message(e.message().to_string()); + + builder.build() + } + tiberius::error::Error::Server(e) if e.code() == 515 => { + let constraint = e + .message() + .split_whitespace() + .nth(7) + .and_then(|s| s.split('\'').nth(1)) + .map(|s| DatabaseConstraint::fields(Some(s))) + .unwrap_or(DatabaseConstraint::CannotParse); + + let kind = ErrorKind::NullConstraintViolation { constraint }; + let mut builder = Error::builder(kind); + + builder.set_original_code(format!("{}", e.code())); + builder.set_original_message(e.message().to_string()); + + builder.build() + } + tiberius::error::Error::Server(e) if e.code() == 1801 => { + let db_name = e.message().split('\'').nth(1).into(); + let kind = ErrorKind::DatabaseAlreadyExists { db_name }; + + let mut builder = Error::builder(kind); + + builder.set_original_code(format!("{}", e.code())); + builder.set_original_message(e.message().to_string()); + + builder.build() + } + tiberius::error::Error::Server(e) if e.code() == 2627 => { + let constraint = e + .message() + .split(". ") + .nth(1) + .and_then(|s| s.split(' ').last()) + .and_then(|s| s.split('\'').nth(1)) + .map(ToString::to_string) + .map(DatabaseConstraint::Index) + .unwrap_or(DatabaseConstraint::CannotParse); + + let kind = ErrorKind::UniqueConstraintViolation { constraint }; + let mut builder = Error::builder(kind); + + builder.set_original_code(format!("{}", e.code())); + builder.set_original_message(e.message().to_string()); + + builder.build() + } + tiberius::error::Error::Server(e) if e.code() == 547 => { + let constraint = e + .message() + .split('.') + .next() + .and_then(|s| s.split_whitespace().last()) + .and_then(|s| s.split('\"').nth(1)) + .map(ToString::to_string) + .map(DatabaseConstraint::Index) + .unwrap_or(DatabaseConstraint::CannotParse); + + let kind = ErrorKind::ForeignKeyConstraintViolation { constraint }; + let mut builder = Error::builder(kind); + + builder.set_original_code(format!("{}", e.code())); + builder.set_original_message(e.message().to_string()); + + builder.build() + } + tiberius::error::Error::Server(e) if e.code() == 1505 => { + let constraint = e + .message() + .split('\'') + .nth(3) + .map(ToString::to_string) + .map(DatabaseConstraint::Index) + .unwrap_or(DatabaseConstraint::CannotParse); + + let kind = ErrorKind::UniqueConstraintViolation { constraint }; + let mut builder = Error::builder(kind); + + builder.set_original_code(format!("{}", e.code())); + builder.set_original_message(e.message().to_string()); + + builder.build() + } + tiberius::error::Error::Server(e) if e.code() == 2601 => { + let constraint = e + .message() + .split_whitespace() + .nth(11) + .and_then(|s| s.split('\'').nth(1)) + .map(ToString::to_string) + .map(DatabaseConstraint::Index) + .unwrap_or(DatabaseConstraint::CannotParse); + + let kind = ErrorKind::UniqueConstraintViolation { constraint }; + let mut builder = Error::builder(kind); + + builder.set_original_code(format!("{}", e.code())); + builder.set_original_message(e.message().to_string()); + + builder.build() + } + tiberius::error::Error::Server(e) if e.code() == 1801 => { + let db_name = e.message().split('\'').nth(1).into(); + let kind = ErrorKind::DatabaseAlreadyExists { db_name }; + + let mut builder = Error::builder(kind); + + builder.set_original_code(format!("{}", e.code())); + builder.set_original_message(e.message().to_string()); + + builder.build() + } + tiberius::error::Error::Server(e) if e.code() == 2628 => { + let column = e.message().split('\'').nth(3).into(); + let kind = ErrorKind::LengthMismatch { column }; + + let mut builder = Error::builder(kind); + + builder.set_original_code(format!("{}", e.code())); + builder.set_original_message(e.message().to_string()); + + builder.build() + } + tiberius::error::Error::Server(e) if e.code() == 208 => { + let table = e + .message() + .split_whitespace() + .nth(3) + .and_then(|s| s.split('\'').nth(1)) + .into(); + + let kind = ErrorKind::TableDoesNotExist { table }; + let mut builder = Error::builder(kind); + + builder.set_original_code(format!("{}", e.code())); + builder.set_original_message(e.message().to_string()); + + builder.build() + } + tiberius::error::Error::Server(e) if e.code() == 207 => { + let column = e + .message() + .split_whitespace() + .nth(3) + .and_then(|s| s.split('\'').nth(1)) + .into(); + + let kind = ErrorKind::ColumnNotFound { column }; + let mut builder = Error::builder(kind); + + builder.set_original_code(format!("{}", e.code())); + builder.set_original_message(e.message().to_string()); + + builder.build() + } + tiberius::error::Error::Server(e) if e.code() == 1205 => { + let mut builder = Error::builder(ErrorKind::TransactionWriteConflict); + + builder.set_original_code(format!("{}", e.code())); + builder.set_original_message(e.message().to_string()); + + builder.build() + } + tiberius::error::Error::Server(e) if e.code() == 3903 => { + let mut builder = Error::builder(ErrorKind::RollbackWithoutBegin); + + builder.set_original_code(format!("{}", e.code())); + builder.set_original_message(e.message().to_string()); + + builder.build() + } + tiberius::error::Error::Server(e) => { + let kind = ErrorKind::QueryError(e.clone().into()); + + let mut builder = Error::builder(kind); + builder.set_original_code(format!("{}", e.code())); + builder.set_original_message(e.message().to_string()); + + builder.build() + } + e => Error::builder(ErrorKind::QueryError(e.into())).build(), + } + } +} diff --git a/quaint/src/connector/mssql/native/mod.rs b/quaint/src/connector/mssql/native/mod.rs new file mode 100644 index 000000000000..a1ea3bd5394d --- /dev/null +++ b/quaint/src/connector/mssql/native/mod.rs @@ -0,0 +1,253 @@ +mod conversion; +mod error; + +pub(crate) use crate::connector::mssql::wasm::common::MssqlUrl; +use crate::connector::{timeout, IsolationLevel, Transaction, TransactionOptions}; + +use crate::{ + ast::{Query, Value}, + connector::{metrics, queryable::*, DefaultTransaction, ResultSet}, + visitor::{self, Visitor}, +}; +use async_trait::async_trait; +use futures::lock::Mutex; +use std::{ + convert::TryFrom, + future::Future, + sync::atomic::{AtomicBool, Ordering}, + time::Duration, +}; +use tiberius::*; +use tokio::net::TcpStream; +use tokio_util::compat::{Compat, TokioAsyncWriteCompatExt}; + +/// The underlying SQL Server driver. Only available with the `expose-drivers` Cargo feature. +#[cfg(feature = "expose-drivers")] +pub use tiberius; + +static SQL_SERVER_DEFAULT_ISOLATION: IsolationLevel = IsolationLevel::ReadCommitted; + +#[async_trait] +impl TransactionCapable for Mssql { + async fn start_transaction<'a>( + &'a self, + isolation: Option, + ) -> crate::Result> { + // Isolation levels in SQL Server are set on the connection and live until they're changed. + // Always explicitly setting the isolation level each time a tx is started (either to the given value + // or by using the default/connection string value) prevents transactions started on connections from + // the pool to have unexpected isolation levels set. + let isolation = isolation + .or(self.url.query_params.transaction_isolation_level) + .or(Some(SQL_SERVER_DEFAULT_ISOLATION)); + + let opts = TransactionOptions::new(isolation, self.requires_isolation_first()); + + Ok(Box::new( + DefaultTransaction::new(self, self.begin_statement(), opts).await?, + )) + } +} + +/// A connector interface for the SQL Server database. +#[derive(Debug)] +pub struct Mssql { + client: Mutex>>, + url: MssqlUrl, + socket_timeout: Option, + is_healthy: AtomicBool, +} + +impl Mssql { + /// Creates a new connection to SQL Server. + pub async fn new(url: MssqlUrl) -> crate::Result { + let config = Config::from_jdbc_string(&url.connection_string)?; + let tcp = TcpStream::connect_named(&config).await?; + let socket_timeout = url.socket_timeout(); + + let connecting = async { + match Client::connect(config, tcp.compat_write()).await { + Ok(client) => Ok(client), + Err(tiberius::error::Error::Routing { host, port }) => { + let mut config = Config::from_jdbc_string(&url.connection_string)?; + config.host(host); + config.port(port); + + let tcp = TcpStream::connect_named(&config).await?; + Client::connect(config, tcp.compat_write()).await + } + Err(e) => Err(e), + } + }; + + let client = timeout::connect(url.connect_timeout(), connecting).await?; + + let this = Self { + client: Mutex::new(client), + url, + socket_timeout, + is_healthy: AtomicBool::new(true), + }; + + if let Some(isolation) = this.url.transaction_isolation_level() { + this.raw_cmd(&format!("SET TRANSACTION ISOLATION LEVEL {isolation}")) + .await?; + }; + + Ok(this) + } + + /// The underlying Tiberius client. Only available with the `expose-drivers` Cargo feature. + /// This is a lower level API when you need to get into database specific features. + #[cfg(feature = "expose-drivers")] + pub fn client(&self) -> &Mutex>> { + &self.client + } + + async fn perform_io(&self, fut: F) -> crate::Result + where + F: Future>, + { + match timeout::socket(self.socket_timeout, fut).await { + Err(e) if e.is_closed() => { + self.is_healthy.store(false, Ordering::SeqCst); + Err(e) + } + res => res, + } + } +} + +#[async_trait] +impl Queryable for Mssql { + async fn query(&self, q: Query<'_>) -> crate::Result { + let (sql, params) = visitor::Mssql::build(q)?; + self.query_raw(&sql, ¶ms[..]).await + } + + async fn query_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { + metrics::query("mssql.query_raw", sql, params, move || async move { + let mut client = self.client.lock().await; + + let mut query = tiberius::Query::new(sql); + + for param in params { + query.bind(param); + } + + let mut results = self.perform_io(query.query(&mut client)).await?.into_results().await?; + + match results.pop() { + Some(rows) => { + let mut columns_set = false; + let mut columns = Vec::new(); + let mut result_rows = Vec::with_capacity(rows.len()); + + for row in rows.into_iter() { + if !columns_set { + columns = row.columns().iter().map(|c| c.name().to_string()).collect(); + columns_set = true; + } + + let mut values: Vec> = Vec::with_capacity(row.len()); + + for val in row.into_iter() { + values.push(Value::try_from(val)?); + } + + result_rows.push(values); + } + + Ok(ResultSet::new(columns, result_rows)) + } + None => Ok(ResultSet::new(Vec::new(), Vec::new())), + } + }) + .await + } + + async fn query_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { + self.query_raw(sql, params).await + } + + async fn execute(&self, q: Query<'_>) -> crate::Result { + let (sql, params) = visitor::Mssql::build(q)?; + self.execute_raw(&sql, ¶ms[..]).await + } + + async fn execute_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { + metrics::query("mssql.execute_raw", sql, params, move || async move { + let mut query = tiberius::Query::new(sql); + + for param in params { + query.bind(param); + } + + let mut client = self.client.lock().await; + let changes = self.perform_io(query.execute(&mut client)).await?.total(); + + Ok(changes) + }) + .await + } + + async fn execute_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { + self.execute_raw(sql, params).await + } + + async fn raw_cmd(&self, cmd: &str) -> crate::Result<()> { + metrics::query("mssql.raw_cmd", cmd, &[], move || async move { + let mut client = self.client.lock().await; + self.perform_io(client.simple_query(cmd)).await?.into_results().await?; + Ok(()) + }) + .await + } + + async fn version(&self) -> crate::Result> { + let query = r#"SELECT @@VERSION AS version"#; + let rows = self.query_raw(query, &[]).await?; + + let version_string = rows + .get(0) + .and_then(|row| row.get("version").and_then(|version| version.to_string())); + + Ok(version_string) + } + + fn is_healthy(&self) -> bool { + self.is_healthy.load(Ordering::SeqCst) + } + + async fn set_tx_isolation_level(&self, isolation_level: IsolationLevel) -> crate::Result<()> { + self.raw_cmd(&format!("SET TRANSACTION ISOLATION LEVEL {isolation_level}")) + .await?; + + Ok(()) + } + + fn begin_statement(&self) -> &'static str { + "BEGIN TRAN" + } + + fn requires_isolation_first(&self) -> bool { + true + } +} + +#[cfg(test)] +mod tests { + use crate::tests::test_api::mssql::CONN_STR; + use crate::{error::*, single::Quaint}; + + #[tokio::test] + async fn should_map_wrong_credentials_error() { + let url = CONN_STR.replace("user=SA", "user=WRONG"); + + let res = Quaint::new(url.as_str()).await; + assert!(res.is_err()); + + let err = res.unwrap_err(); + assert!(matches!(err.kind(), ErrorKind::AuthenticationFailed { user } if user == &Name::available("WRONG"))); + } +} diff --git a/quaint/src/connector/mssql_wasm.rs b/quaint/src/connector/mssql/wasm/common.rs similarity index 91% rename from quaint/src/connector/mssql_wasm.rs rename to quaint/src/connector/mssql/wasm/common.rs index d9f7dc27865b..5b6ee881d3e9 100644 --- a/quaint/src/connector/mssql_wasm.rs +++ b/quaint/src/connector/mssql/wasm/common.rs @@ -1,8 +1,7 @@ -#![cfg_attr(target_arch = "wasm32", allow(dead_code))] - -use super::IsolationLevel; - -use crate::error::{Error, ErrorKind}; +use crate::{ + connector::IsolationLevel, + error::{Error, ErrorKind}, +}; use connection_string::JdbcString; use std::{fmt, str::FromStr, time::Duration}; @@ -10,8 +9,8 @@ use std::{fmt, str::FromStr, time::Duration}; /// including default values. #[derive(Debug, Clone)] pub struct MssqlUrl { - pub(super) connection_string: String, - pub(super) query_params: MssqlQueryParams, + pub(crate) connection_string: String, + pub(crate) query_params: MssqlQueryParams, } /// TLS mode when connecting to SQL Server. @@ -51,22 +50,22 @@ impl FromStr for EncryptMode { #[derive(Debug, Clone)] pub(crate) struct MssqlQueryParams { - pub(super) encrypt: EncryptMode, - pub(super) port: Option, - pub(super) host: Option, - pub(super) user: Option, - pub(super) password: Option, - pub(super) database: String, - pub(super) schema: String, - pub(super) trust_server_certificate: bool, - pub(super) trust_server_certificate_ca: Option, - pub(super) connection_limit: Option, - pub(super) socket_timeout: Option, - pub(super) connect_timeout: Option, - pub(super) pool_timeout: Option, - pub(super) transaction_isolation_level: Option, - pub(super) max_connection_lifetime: Option, - pub(super) max_idle_connection_lifetime: Option, + pub(crate) encrypt: EncryptMode, + pub(crate) port: Option, + pub(crate) host: Option, + pub(crate) user: Option, + pub(crate) password: Option, + pub(crate) database: String, + pub(crate) schema: String, + pub(crate) trust_server_certificate: bool, + pub(crate) trust_server_certificate_ca: Option, + pub(crate) connection_limit: Option, + pub(crate) socket_timeout: Option, + pub(crate) connect_timeout: Option, + pub(crate) pool_timeout: Option, + pub(crate) transaction_isolation_level: Option, + pub(crate) max_connection_lifetime: Option, + pub(crate) max_idle_connection_lifetime: Option, } impl MssqlUrl { diff --git a/quaint/src/connector/mssql/wasm/mod.rs b/quaint/src/connector/mssql/wasm/mod.rs new file mode 100644 index 000000000000..69f1f46f7d21 --- /dev/null +++ b/quaint/src/connector/mssql/wasm/mod.rs @@ -0,0 +1,5 @@ +///! Wasm-compatible definitions for the MSSQL connector. +/// This module is only available with the `mssql` feature. +pub(crate) mod common; + +pub use common::MssqlUrl; diff --git a/quaint/src/connector/mysql_wasm.rs b/quaint/src/connector/mysql_wasm.rs deleted file mode 100644 index 24cd525fea33..000000000000 --- a/quaint/src/connector/mysql_wasm.rs +++ /dev/null @@ -1,318 +0,0 @@ -#![cfg_attr(target_arch = "wasm32", allow(dead_code))] - -use crate::error::{Error, ErrorKind}; -use percent_encoding::percent_decode; -use std::{ - borrow::Cow, - path::{Path, PathBuf}, - time::Duration, -}; -use url::{Host, Url}; - -/// Wraps a connection url and exposes the parsing logic used by quaint, including default values. -#[derive(Debug, Clone)] -pub struct MysqlUrl { - url: Url, - pub(super) query_params: MysqlUrlQueryParams, -} - -impl MysqlUrl { - /// Parse `Url` to `MysqlUrl`. Returns error for mistyped connection - /// parameters. - pub fn new(url: Url) -> Result { - let query_params = Self::parse_query_params(&url)?; - - Ok(Self { url, query_params }) - } - - /// The bare `Url` to the database. - pub fn url(&self) -> &Url { - &self.url - } - - /// The percent-decoded database username. - pub fn username(&self) -> Cow { - match percent_decode(self.url.username().as_bytes()).decode_utf8() { - Ok(username) => username, - Err(_) => { - tracing::warn!("Couldn't decode username to UTF-8, using the non-decoded version."); - - self.url.username().into() - } - } - } - - /// The percent-decoded database password. - pub fn password(&self) -> Option> { - match self - .url - .password() - .and_then(|pw| percent_decode(pw.as_bytes()).decode_utf8().ok()) - { - Some(password) => Some(password), - None => self.url.password().map(|s| s.into()), - } - } - - /// Name of the database connected. Defaults to `mysql`. - pub fn dbname(&self) -> &str { - match self.url.path_segments() { - Some(mut segments) => segments.next().unwrap_or("mysql"), - None => "mysql", - } - } - - /// The database host. If `socket` and `host` are not set, defaults to `localhost`. - pub fn host(&self) -> &str { - match (self.url.host(), self.url.host_str()) { - (Some(Host::Ipv6(_)), Some(host)) => { - // The `url` crate may return an IPv6 address in brackets, which must be stripped. - if host.starts_with('[') && host.ends_with(']') { - &host[1..host.len() - 1] - } else { - host - } - } - (_, Some(host)) => host, - _ => "localhost", - } - } - - /// If set, connected to the database through a Unix socket. - pub fn socket(&self) -> &Option { - &self.query_params.socket - } - - /// The database port, defaults to `3306`. - pub fn port(&self) -> u16 { - self.url.port().unwrap_or(3306) - } - - /// The connection timeout. - pub fn connect_timeout(&self) -> Option { - self.query_params.connect_timeout - } - - /// The pool check_out timeout - pub fn pool_timeout(&self) -> Option { - self.query_params.pool_timeout - } - - /// The socket timeout - pub fn socket_timeout(&self) -> Option { - self.query_params.socket_timeout - } - - /// Prefer socket connection - pub fn prefer_socket(&self) -> Option { - self.query_params.prefer_socket - } - - /// The maximum connection lifetime - pub fn max_connection_lifetime(&self) -> Option { - self.query_params.max_connection_lifetime - } - - /// The maximum idle connection lifetime - pub fn max_idle_connection_lifetime(&self) -> Option { - self.query_params.max_idle_connection_lifetime - } - - pub(super) fn statement_cache_size(&self) -> usize { - self.query_params.statement_cache_size - } - - fn parse_query_params(url: &Url) -> Result { - #[cfg(feature = "mysql-connector")] - let mut ssl_opts = { - let mut ssl_opts = mysql_async::SslOpts::default(); - ssl_opts = ssl_opts.with_danger_accept_invalid_certs(true); - ssl_opts - }; - - let mut connection_limit = None; - let mut use_ssl = false; - let mut socket = None; - let mut socket_timeout = None; - let mut connect_timeout = Some(Duration::from_secs(5)); - let mut pool_timeout = Some(Duration::from_secs(10)); - let mut max_connection_lifetime = None; - let mut max_idle_connection_lifetime = Some(Duration::from_secs(300)); - let mut prefer_socket = None; - let mut statement_cache_size = 100; - let mut identity: Option<(Option, Option)> = None; - - for (k, v) in url.query_pairs() { - match k.as_ref() { - "connection_limit" => { - let as_int: usize = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - - connection_limit = Some(as_int); - } - "statement_cache_size" => { - statement_cache_size = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - } - "sslcert" => { - use_ssl = true; - - #[cfg(feature = "mysql-connector")] - { - ssl_opts = ssl_opts.with_root_cert_path(Some(Path::new(&*v).to_path_buf())); - } - } - "sslidentity" => { - use_ssl = true; - - identity = match identity { - Some((_, pw)) => Some((Some(Path::new(&*v).to_path_buf()), pw)), - None => Some((Some(Path::new(&*v).to_path_buf()), None)), - }; - } - "sslpassword" => { - use_ssl = true; - - identity = match identity { - Some((path, _)) => Some((path, Some(v.to_string()))), - None => Some((None, Some(v.to_string()))), - }; - } - "socket" => { - socket = Some(v.replace(['(', ')'], "")); - } - "socket_timeout" => { - let as_int = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - socket_timeout = Some(Duration::from_secs(as_int)); - } - "prefer_socket" => { - let as_bool = v - .parse::() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - prefer_socket = Some(as_bool) - } - "connect_timeout" => { - let as_int = v - .parse::() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - - connect_timeout = match as_int { - 0 => None, - _ => Some(Duration::from_secs(as_int)), - }; - } - "pool_timeout" => { - let as_int = v - .parse::() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - - pool_timeout = match as_int { - 0 => None, - _ => Some(Duration::from_secs(as_int)), - }; - } - "sslaccept" => { - use_ssl = true; - match v.as_ref() { - "strict" => { - #[cfg(feature = "mysql-connector")] - { - ssl_opts = ssl_opts.with_danger_accept_invalid_certs(false); - } - } - "accept_invalid_certs" => {} - _ => { - tracing::debug!( - message = "Unsupported SSL accept mode, defaulting to `accept_invalid_certs`", - mode = &*v - ); - } - }; - } - "max_connection_lifetime" => { - let as_int = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - - if as_int == 0 { - max_connection_lifetime = None; - } else { - max_connection_lifetime = Some(Duration::from_secs(as_int)); - } - } - "max_idle_connection_lifetime" => { - let as_int = v - .parse() - .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; - - if as_int == 0 { - max_idle_connection_lifetime = None; - } else { - max_idle_connection_lifetime = Some(Duration::from_secs(as_int)); - } - } - _ => { - tracing::trace!(message = "Discarding connection string param", param = &*k); - } - }; - } - - // Wrapping this in a block, as attributes on expressions are still experimental - // See: https://github.com/rust-lang/rust/issues/15701 - #[cfg(feature = "mysql-connector")] - { - ssl_opts = match identity { - Some((Some(path), Some(pw))) => { - let identity = mysql_async::ClientIdentity::new(path).with_password(pw); - ssl_opts.with_client_identity(Some(identity)) - } - Some((Some(path), None)) => { - let identity = mysql_async::ClientIdentity::new(path); - ssl_opts.with_client_identity(Some(identity)) - } - _ => ssl_opts, - }; - } - - Ok(MysqlUrlQueryParams { - #[cfg(feature = "mysql-connector")] - ssl_opts, - connection_limit, - use_ssl, - socket, - socket_timeout, - connect_timeout, - pool_timeout, - max_connection_lifetime, - max_idle_connection_lifetime, - prefer_socket, - statement_cache_size, - }) - } - - #[cfg(feature = "pooled")] - pub(crate) fn connection_limit(&self) -> Option { - self.query_params.connection_limit - } -} - -#[derive(Debug, Clone)] -pub(crate) struct MysqlUrlQueryParams { - pub(crate) connection_limit: Option, - pub(crate) use_ssl: bool, - pub(crate) socket: Option, - pub(crate) socket_timeout: Option, - pub(crate) connect_timeout: Option, - pub(crate) pool_timeout: Option, - pub(crate) max_connection_lifetime: Option, - pub(crate) max_idle_connection_lifetime: Option, - pub(crate) prefer_socket: Option, - pub(crate) statement_cache_size: usize, - - #[cfg(feature = "mysql-connector")] - pub(crate) ssl_opts: mysql_async::SslOpts, -} From 8ecbc5c37d71513eeea080fd3bb1ef08618b080d Mon Sep 17 00:00:00 2001 From: jkomyno Date: Mon, 13 Nov 2023 12:25:20 +0100 Subject: [PATCH 06/34] feat(quaint): split sqlite connector into native and wasm submodules --- quaint/Cargo.toml | 6 +- quaint/src/connector.rs | 23 +- quaint/src/connector/sqlite.rs | 256 +----------------- .../sqlite/{ => native}/conversion.rs | 0 quaint/src/connector/sqlite/native/error.rs | 49 ++++ quaint/src/connector/sqlite/native/mod.rs | 252 +++++++++++++++++ .../{sqlite_wasm.rs => sqlite/wasm/common.rs} | 0 .../src/connector/sqlite/{ => wasm}/error.rs | 62 +---- quaint/src/connector/sqlite/wasm/mod.rs | 4 + quaint/src/error.rs | 2 +- 10 files changed, 336 insertions(+), 318 deletions(-) rename quaint/src/connector/sqlite/{ => native}/conversion.rs (100%) create mode 100644 quaint/src/connector/sqlite/native/error.rs create mode 100644 quaint/src/connector/sqlite/native/mod.rs rename quaint/src/connector/{sqlite_wasm.rs => sqlite/wasm/common.rs} (100%) rename quaint/src/connector/sqlite/{ => wasm}/error.rs (69%) create mode 100644 quaint/src/connector/sqlite/wasm/mod.rs diff --git a/quaint/Cargo.toml b/quaint/Cargo.toml index 2da9ec0929c0..abe9fece9746 100644 --- a/quaint/Cargo.toml +++ b/quaint/Cargo.toml @@ -70,8 +70,8 @@ mysql-connector = ["mysql", "mysql_async", "tokio/time", "lru-cache"] mysql = ["chrono/std"] pooled = ["mobc"] -sqlite-connector = ["sqlite", "rusqlite", "tokio/sync"] -sqlite = [] +sqlite-connector = ["sqlite", "rusqlite/bundled", "tokio/sync"] +sqlite = ["rusqlite"] fmt-sql = ["sqlformat"] @@ -127,7 +127,7 @@ branch = "vendored-openssl" [dependencies.rusqlite] version = "0.29" -features = ["chrono", "bundled", "column_decltype"] +features = ["chrono", "column_decltype"] optional = true [target.'cfg(not(any(target_os = "macos", target_os = "ios")))'.dependencies.tiberius] diff --git a/quaint/src/connector.rs b/quaint/src/connector.rs index 32f9e6186890..b182e60a4387 100644 --- a/quaint/src/connector.rs +++ b/quaint/src/connector.rs @@ -35,10 +35,10 @@ mod type_identifier; // pub(crate) mod postgres; // #[cfg(feature = "postgresql")] // pub(crate) mod postgres_wasm; -#[cfg(feature = "sqlite-connector")] -pub(crate) mod sqlite; -#[cfg(feature = "sqlite")] -pub(crate) mod sqlite_wasm; +// #[cfg(feature = "sqlite-connector")] +// pub(crate) mod sqlite; +// #[cfg(feature = "sqlite")] +// pub(crate) mod sqlite_wasm; // #[cfg(feature = "mysql-connector")] // pub use self::mysql::*; @@ -52,10 +52,10 @@ pub(crate) mod sqlite_wasm; // pub use mssql::*; // #[cfg(feature = "mssql")] // pub use mssql_wasm::*; -#[cfg(feature = "sqlite-connector")] -pub use sqlite::*; -#[cfg(feature = "sqlite")] -pub use sqlite_wasm::*; +// #[cfg(feature = "sqlite-connector")] +// pub use sqlite::*; +// #[cfg(feature = "sqlite")] +// pub use sqlite_wasm::*; pub use self::result_set::*; pub use connection_info::*; @@ -85,6 +85,13 @@ pub use mysql::native::*; #[cfg(feature = "mysql")] pub use mysql::wasm::common::*; +#[cfg(feature = "sqlite")] +pub(crate) mod sqlite; +#[cfg(feature = "sqlite-connector")] +pub use sqlite::native::*; +#[cfg(feature = "sqlite")] +pub use sqlite::wasm::common::*; + #[cfg(feature = "mssql")] pub(crate) mod mssql; #[cfg(feature = "mssql-connector")] diff --git a/quaint/src/connector/sqlite.rs b/quaint/src/connector/sqlite.rs index fc993c1eaf0e..0e699c211878 100644 --- a/quaint/src/connector/sqlite.rs +++ b/quaint/src/connector/sqlite.rs @@ -1,253 +1,7 @@ -mod conversion; -mod error; +pub use wasm::error::SqliteError; -pub(crate) use super::sqlite_wasm::{SqliteParams, DEFAULT_SQLITE_SCHEMA_NAME}; -pub use error::SqliteError; +#[cfg(feature = "sqlite")] +pub(crate) mod wasm; -pub use rusqlite::{params_from_iter, version as sqlite_version}; - -use super::IsolationLevel; -use crate::{ - ast::{Query, Value}, - connector::{metrics, queryable::*, ResultSet}, - error::{Error, ErrorKind}, - visitor::{self, Visitor}, -}; -use async_trait::async_trait; -use std::convert::TryFrom; -use tokio::sync::Mutex; - -/// The underlying sqlite driver. Only available with the `expose-drivers` Cargo feature. -#[cfg(feature = "expose-drivers")] -pub use rusqlite; - -/// A connector interface for the SQLite database -pub struct Sqlite { - pub(crate) client: Mutex, -} - -impl TryFrom<&str> for Sqlite { - type Error = Error; - - fn try_from(path: &str) -> crate::Result { - let params = SqliteParams::try_from(path)?; - let file_path = params.file_path; - - let conn = rusqlite::Connection::open(file_path.as_str())?; - - if let Some(timeout) = params.socket_timeout { - conn.busy_timeout(timeout)?; - }; - - let client = Mutex::new(conn); - - Ok(Sqlite { client }) - } -} - -impl Sqlite { - pub fn new(file_path: &str) -> crate::Result { - Self::try_from(file_path) - } - - /// Open a new SQLite database in memory. - pub fn new_in_memory() -> crate::Result { - let client = rusqlite::Connection::open_in_memory()?; - - Ok(Sqlite { - client: Mutex::new(client), - }) - } - - /// The underlying rusqlite::Connection. Only available with the `expose-drivers` Cargo - /// feature. This is a lower level API when you need to get into database specific features. - #[cfg(feature = "expose-drivers")] - pub fn connection(&self) -> &Mutex { - &self.client - } -} - -impl_default_TransactionCapable!(Sqlite); - -#[async_trait] -impl Queryable for Sqlite { - async fn query(&self, q: Query<'_>) -> crate::Result { - let (sql, params) = visitor::Sqlite::build(q)?; - self.query_raw(&sql, ¶ms).await - } - - async fn query_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - metrics::query("sqlite.query_raw", sql, params, move || async move { - let client = self.client.lock().await; - - let mut stmt = client.prepare_cached(sql)?; - - let mut rows = stmt.query(params_from_iter(params.iter()))?; - let mut result = ResultSet::new(rows.to_column_names(), Vec::new()); - - while let Some(row) = rows.next()? { - result.rows.push(row.get_result_row()?); - } - - result.set_last_insert_id(u64::try_from(client.last_insert_rowid()).unwrap_or(0)); - - Ok(result) - }) - .await - } - - async fn query_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - self.query_raw(sql, params).await - } - - async fn execute(&self, q: Query<'_>) -> crate::Result { - let (sql, params) = visitor::Sqlite::build(q)?; - self.execute_raw(&sql, ¶ms).await - } - - async fn execute_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - metrics::query("sqlite.query_raw", sql, params, move || async move { - let client = self.client.lock().await; - let mut stmt = client.prepare_cached(sql)?; - let res = u64::try_from(stmt.execute(params_from_iter(params.iter()))?)?; - - Ok(res) - }) - .await - } - - async fn execute_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { - self.execute_raw(sql, params).await - } - - async fn raw_cmd(&self, cmd: &str) -> crate::Result<()> { - metrics::query("sqlite.raw_cmd", cmd, &[], move || async move { - let client = self.client.lock().await; - client.execute_batch(cmd)?; - Ok(()) - }) - .await - } - - async fn version(&self) -> crate::Result> { - Ok(Some(rusqlite::version().into())) - } - - fn is_healthy(&self) -> bool { - true - } - - async fn set_tx_isolation_level(&self, isolation_level: IsolationLevel) -> crate::Result<()> { - // SQLite is always "serializable", other modes involve pragmas - // and shared cache mode, which is out of scope for now and should be implemented - // as part of a separate effort. - if !matches!(isolation_level, IsolationLevel::Serializable) { - let kind = ErrorKind::invalid_isolation_level(&isolation_level); - return Err(Error::builder(kind).build()); - } - - Ok(()) - } - - fn requires_isolation_first(&self) -> bool { - false - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::{ - ast::*, - connector::Queryable, - error::{ErrorKind, Name}, - }; - - #[test] - fn sqlite_params_from_str_should_resolve_path_correctly_with_file_scheme() { - let path = "file:dev.db"; - let params = SqliteParams::try_from(path).unwrap(); - assert_eq!(params.file_path, "dev.db"); - } - - #[test] - fn sqlite_params_from_str_should_resolve_path_correctly_with_sqlite_scheme() { - let path = "sqlite:dev.db"; - let params = SqliteParams::try_from(path).unwrap(); - assert_eq!(params.file_path, "dev.db"); - } - - #[test] - fn sqlite_params_from_str_should_resolve_path_correctly_with_no_scheme() { - let path = "dev.db"; - let params = SqliteParams::try_from(path).unwrap(); - assert_eq!(params.file_path, "dev.db"); - } - - #[tokio::test] - async fn unknown_table_should_give_a_good_error() { - let conn = Sqlite::try_from("file:db/test.db").unwrap(); - let select = Select::from_table("not_there"); - - let err = conn.select(select).await.unwrap_err(); - - match err.kind() { - ErrorKind::TableDoesNotExist { table } => { - assert_eq!(&Name::available("not_there"), table); - } - e => panic!("Expected error TableDoesNotExist, got {:?}", e), - } - } - - #[tokio::test] - async fn in_memory_sqlite_works() { - let conn = Sqlite::new_in_memory().unwrap(); - - conn.raw_cmd("CREATE TABLE test (id INTEGER PRIMARY KEY, txt TEXT NOT NULL);") - .await - .unwrap(); - - let insert = Insert::single_into("test").value("txt", "henlo"); - conn.insert(insert.into()).await.unwrap(); - - let select = Select::from_table("test").value(asterisk()); - let result = conn.select(select.clone()).await.unwrap(); - let result = result.into_single().unwrap(); - - assert_eq!(result.get("id").unwrap(), &Value::int32(1)); - assert_eq!(result.get("txt").unwrap(), &Value::text("henlo")); - - // Check that we do get a separate, new database. - let other_conn = Sqlite::new_in_memory().unwrap(); - - let err = other_conn.select(select).await.unwrap_err(); - assert!(matches!(err.kind(), ErrorKind::TableDoesNotExist { .. })); - } - - #[tokio::test] - async fn quoting_in_returning_in_sqlite_works() { - let conn = Sqlite::new_in_memory().unwrap(); - - conn.raw_cmd("CREATE TABLE test (id INTEGER PRIMARY KEY, `txt space` TEXT NOT NULL);") - .await - .unwrap(); - - let insert = Insert::single_into("test").value("txt space", "henlo"); - conn.insert(insert.into()).await.unwrap(); - - let select = Select::from_table("test").value(asterisk()); - let result = conn.select(select.clone()).await.unwrap(); - let result = result.into_single().unwrap(); - - assert_eq!(result.get("id").unwrap(), &Value::int32(1)); - assert_eq!(result.get("txt space").unwrap(), &Value::text("henlo")); - - let insert = Insert::single_into("test").value("txt space", "henlo"); - let insert: Insert = Insert::from(insert).returning(["txt space"]); - - let result = conn.insert(insert).await.unwrap(); - let result = result.into_single().unwrap(); - - assert_eq!(result.get("txt space").unwrap(), &Value::text("henlo")); - } -} +#[cfg(feature = "sqlite-connector")] +pub(crate) mod native; diff --git a/quaint/src/connector/sqlite/conversion.rs b/quaint/src/connector/sqlite/native/conversion.rs similarity index 100% rename from quaint/src/connector/sqlite/conversion.rs rename to quaint/src/connector/sqlite/native/conversion.rs diff --git a/quaint/src/connector/sqlite/native/error.rs b/quaint/src/connector/sqlite/native/error.rs new file mode 100644 index 000000000000..9e2b2e7c3ea1 --- /dev/null +++ b/quaint/src/connector/sqlite/native/error.rs @@ -0,0 +1,49 @@ +use crate::connector::sqlite::wasm::error::SqliteError; + +use crate::error::*; + +impl From for Error { + fn from(e: rusqlite::Error) -> Error { + match e { + rusqlite::Error::ToSqlConversionFailure(error) => match error.downcast::() { + Ok(error) => *error, + Err(error) => { + let mut builder = Error::builder(ErrorKind::QueryError(error)); + + builder.set_original_message("Could not interpret parameters in an SQLite query."); + + builder.build() + } + }, + rusqlite::Error::InvalidQuery => { + let mut builder = Error::builder(ErrorKind::QueryError(e.into())); + + builder.set_original_message( + "Could not interpret the query or its parameters. Check the syntax and parameter types.", + ); + + builder.build() + } + rusqlite::Error::ExecuteReturnedResults => { + let mut builder = Error::builder(ErrorKind::QueryError(e.into())); + builder.set_original_message("Execute returned results, which is not allowed in SQLite."); + + builder.build() + } + + rusqlite::Error::QueryReturnedNoRows => Error::builder(ErrorKind::NotFound).build(), + + rusqlite::Error::SqliteFailure(rusqlite::ffi::Error { code: _, extended_code }, message) => { + SqliteError::new(extended_code, message).into() + } + + rusqlite::Error::SqlInputError { + error: rusqlite::ffi::Error { extended_code, .. }, + msg, + .. + } => SqliteError::new(extended_code, Some(msg)).into(), + + e => Error::builder(ErrorKind::QueryError(e.into())).build(), + } + } +} diff --git a/quaint/src/connector/sqlite/native/mod.rs b/quaint/src/connector/sqlite/native/mod.rs new file mode 100644 index 000000000000..e11f6cd021bc --- /dev/null +++ b/quaint/src/connector/sqlite/native/mod.rs @@ -0,0 +1,252 @@ +mod conversion; +mod error; + +use crate::connector::sqlite::wasm::common::SqliteParams; +use crate::connector::IsolationLevel; + +pub use rusqlite::{params_from_iter, version as sqlite_version}; + +use crate::{ + ast::{Query, Value}, + connector::{metrics, queryable::*, ResultSet}, + error::{Error, ErrorKind}, + visitor::{self, Visitor}, +}; +use async_trait::async_trait; +use std::convert::TryFrom; +use tokio::sync::Mutex; + +/// The underlying sqlite driver. Only available with the `expose-drivers` Cargo feature. +#[cfg(feature = "expose-drivers")] +pub use rusqlite; + +/// A connector interface for the SQLite database +pub struct Sqlite { + pub(crate) client: Mutex, +} + +impl TryFrom<&str> for Sqlite { + type Error = Error; + + fn try_from(path: &str) -> crate::Result { + let params = SqliteParams::try_from(path)?; + let file_path = params.file_path; + + let conn = rusqlite::Connection::open(file_path.as_str())?; + + if let Some(timeout) = params.socket_timeout { + conn.busy_timeout(timeout)?; + }; + + let client = Mutex::new(conn); + + Ok(Sqlite { client }) + } +} + +impl Sqlite { + pub fn new(file_path: &str) -> crate::Result { + Self::try_from(file_path) + } + + /// Open a new SQLite database in memory. + pub fn new_in_memory() -> crate::Result { + let client = rusqlite::Connection::open_in_memory()?; + + Ok(Sqlite { + client: Mutex::new(client), + }) + } + + /// The underlying rusqlite::Connection. Only available with the `expose-drivers` Cargo + /// feature. This is a lower level API when you need to get into database specific features. + #[cfg(feature = "expose-drivers")] + pub fn connection(&self) -> &Mutex { + &self.client + } +} + +impl_default_TransactionCapable!(Sqlite); + +#[async_trait] +impl Queryable for Sqlite { + async fn query(&self, q: Query<'_>) -> crate::Result { + let (sql, params) = visitor::Sqlite::build(q)?; + self.query_raw(&sql, ¶ms).await + } + + async fn query_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { + metrics::query("sqlite.query_raw", sql, params, move || async move { + let client = self.client.lock().await; + + let mut stmt = client.prepare_cached(sql)?; + + let mut rows = stmt.query(params_from_iter(params.iter()))?; + let mut result = ResultSet::new(rows.to_column_names(), Vec::new()); + + while let Some(row) = rows.next()? { + result.rows.push(row.get_result_row()?); + } + + result.set_last_insert_id(u64::try_from(client.last_insert_rowid()).unwrap_or(0)); + + Ok(result) + }) + .await + } + + async fn query_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { + self.query_raw(sql, params).await + } + + async fn execute(&self, q: Query<'_>) -> crate::Result { + let (sql, params) = visitor::Sqlite::build(q)?; + self.execute_raw(&sql, ¶ms).await + } + + async fn execute_raw(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { + metrics::query("sqlite.query_raw", sql, params, move || async move { + let client = self.client.lock().await; + let mut stmt = client.prepare_cached(sql)?; + let res = u64::try_from(stmt.execute(params_from_iter(params.iter()))?)?; + + Ok(res) + }) + .await + } + + async fn execute_raw_typed(&self, sql: &str, params: &[Value<'_>]) -> crate::Result { + self.execute_raw(sql, params).await + } + + async fn raw_cmd(&self, cmd: &str) -> crate::Result<()> { + metrics::query("sqlite.raw_cmd", cmd, &[], move || async move { + let client = self.client.lock().await; + client.execute_batch(cmd)?; + Ok(()) + }) + .await + } + + async fn version(&self) -> crate::Result> { + Ok(Some(rusqlite::version().into())) + } + + fn is_healthy(&self) -> bool { + true + } + + async fn set_tx_isolation_level(&self, isolation_level: IsolationLevel) -> crate::Result<()> { + // SQLite is always "serializable", other modes involve pragmas + // and shared cache mode, which is out of scope for now and should be implemented + // as part of a separate effort. + if !matches!(isolation_level, IsolationLevel::Serializable) { + let kind = ErrorKind::invalid_isolation_level(&isolation_level); + return Err(Error::builder(kind).build()); + } + + Ok(()) + } + + fn requires_isolation_first(&self) -> bool { + false + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + ast::*, + connector::Queryable, + error::{ErrorKind, Name}, + }; + + #[test] + fn sqlite_params_from_str_should_resolve_path_correctly_with_file_scheme() { + let path = "file:dev.db"; + let params = SqliteParams::try_from(path).unwrap(); + assert_eq!(params.file_path, "dev.db"); + } + + #[test] + fn sqlite_params_from_str_should_resolve_path_correctly_with_sqlite_scheme() { + let path = "sqlite:dev.db"; + let params = SqliteParams::try_from(path).unwrap(); + assert_eq!(params.file_path, "dev.db"); + } + + #[test] + fn sqlite_params_from_str_should_resolve_path_correctly_with_no_scheme() { + let path = "dev.db"; + let params = SqliteParams::try_from(path).unwrap(); + assert_eq!(params.file_path, "dev.db"); + } + + #[tokio::test] + async fn unknown_table_should_give_a_good_error() { + let conn = Sqlite::try_from("file:db/test.db").unwrap(); + let select = Select::from_table("not_there"); + + let err = conn.select(select).await.unwrap_err(); + + match err.kind() { + ErrorKind::TableDoesNotExist { table } => { + assert_eq!(&Name::available("not_there"), table); + } + e => panic!("Expected error TableDoesNotExist, got {:?}", e), + } + } + + #[tokio::test] + async fn in_memory_sqlite_works() { + let conn = Sqlite::new_in_memory().unwrap(); + + conn.raw_cmd("CREATE TABLE test (id INTEGER PRIMARY KEY, txt TEXT NOT NULL);") + .await + .unwrap(); + + let insert = Insert::single_into("test").value("txt", "henlo"); + conn.insert(insert.into()).await.unwrap(); + + let select = Select::from_table("test").value(asterisk()); + let result = conn.select(select.clone()).await.unwrap(); + let result = result.into_single().unwrap(); + + assert_eq!(result.get("id").unwrap(), &Value::int32(1)); + assert_eq!(result.get("txt").unwrap(), &Value::text("henlo")); + + // Check that we do get a separate, new database. + let other_conn = Sqlite::new_in_memory().unwrap(); + + let err = other_conn.select(select).await.unwrap_err(); + assert!(matches!(err.kind(), ErrorKind::TableDoesNotExist { .. })); + } + + #[tokio::test] + async fn quoting_in_returning_in_sqlite_works() { + let conn = Sqlite::new_in_memory().unwrap(); + + conn.raw_cmd("CREATE TABLE test (id INTEGER PRIMARY KEY, `txt space` TEXT NOT NULL);") + .await + .unwrap(); + + let insert = Insert::single_into("test").value("txt space", "henlo"); + conn.insert(insert.into()).await.unwrap(); + + let select = Select::from_table("test").value(asterisk()); + let result = conn.select(select.clone()).await.unwrap(); + let result = result.into_single().unwrap(); + + assert_eq!(result.get("id").unwrap(), &Value::int32(1)); + assert_eq!(result.get("txt space").unwrap(), &Value::text("henlo")); + + let insert = Insert::single_into("test").value("txt space", "henlo"); + let insert: Insert = Insert::from(insert).returning(["txt space"]); + + let result = conn.insert(insert).await.unwrap(); + let result = result.into_single().unwrap(); + + assert_eq!(result.get("txt space").unwrap(), &Value::text("henlo")); + } +} diff --git a/quaint/src/connector/sqlite_wasm.rs b/quaint/src/connector/sqlite/wasm/common.rs similarity index 100% rename from quaint/src/connector/sqlite_wasm.rs rename to quaint/src/connector/sqlite/wasm/common.rs diff --git a/quaint/src/connector/sqlite/error.rs b/quaint/src/connector/sqlite/wasm/error.rs similarity index 69% rename from quaint/src/connector/sqlite/error.rs rename to quaint/src/connector/sqlite/wasm/error.rs index c10b335cb3c0..9cd0ef64e8a4 100644 --- a/quaint/src/connector/sqlite/error.rs +++ b/quaint/src/connector/sqlite/wasm/error.rs @@ -1,8 +1,6 @@ use std::fmt; use crate::error::*; -use rusqlite::ffi; -use rusqlite::types::FromSqlError; #[derive(Debug)] pub struct SqliteError { @@ -16,7 +14,7 @@ impl fmt::Display for SqliteError { f, "Error code {}: {}", self.extended_code, - ffi::code_to_str(self.extended_code) + rusqlite::ffi::code_to_str(self.extended_code) ) } } @@ -37,7 +35,7 @@ impl From for Error { fn from(error: SqliteError) -> Self { match error { SqliteError { - extended_code: ffi::SQLITE_CONSTRAINT_UNIQUE | ffi::SQLITE_CONSTRAINT_PRIMARYKEY, + extended_code: rusqlite::ffi::SQLITE_CONSTRAINT_UNIQUE | rusqlite::ffi::SQLITE_CONSTRAINT_PRIMARYKEY, message: Some(description), } => { let constraint = description @@ -58,7 +56,7 @@ impl From for Error { } SqliteError { - extended_code: ffi::SQLITE_CONSTRAINT_NOTNULL, + extended_code: rusqlite::ffi::SQLITE_CONSTRAINT_NOTNULL, message: Some(description), } => { let constraint = description @@ -79,7 +77,7 @@ impl From for Error { } SqliteError { - extended_code: ffi::SQLITE_CONSTRAINT_FOREIGNKEY | ffi::SQLITE_CONSTRAINT_TRIGGER, + extended_code: rusqlite::ffi::SQLITE_CONSTRAINT_FOREIGNKEY | rusqlite::ffi::SQLITE_CONSTRAINT_TRIGGER, message: Some(description), } => { let mut builder = Error::builder(ErrorKind::ForeignKeyConstraintViolation { @@ -92,7 +90,7 @@ impl From for Error { builder.build() } - SqliteError { extended_code, message } if error.primary_code() == ffi::SQLITE_BUSY => { + SqliteError { extended_code, message } if error.primary_code() == rusqlite::ffi::SQLITE_BUSY => { let mut builder = Error::builder(ErrorKind::SocketTimeout); builder.set_original_code(format!("{extended_code}")); @@ -153,54 +151,8 @@ impl From for Error { } } -impl From for Error { - fn from(e: rusqlite::Error) -> Error { - match e { - rusqlite::Error::ToSqlConversionFailure(error) => match error.downcast::() { - Ok(error) => *error, - Err(error) => { - let mut builder = Error::builder(ErrorKind::QueryError(error)); - - builder.set_original_message("Could not interpret parameters in an SQLite query."); - - builder.build() - } - }, - rusqlite::Error::InvalidQuery => { - let mut builder = Error::builder(ErrorKind::QueryError(e.into())); - - builder.set_original_message( - "Could not interpret the query or its parameters. Check the syntax and parameter types.", - ); - - builder.build() - } - rusqlite::Error::ExecuteReturnedResults => { - let mut builder = Error::builder(ErrorKind::QueryError(e.into())); - builder.set_original_message("Execute returned results, which is not allowed in SQLite."); - - builder.build() - } - - rusqlite::Error::QueryReturnedNoRows => Error::builder(ErrorKind::NotFound).build(), - - rusqlite::Error::SqliteFailure(ffi::Error { code: _, extended_code }, message) => { - SqliteError::new(extended_code, message).into() - } - - rusqlite::Error::SqlInputError { - error: ffi::Error { extended_code, .. }, - msg, - .. - } => SqliteError::new(extended_code, Some(msg)).into(), - - e => Error::builder(ErrorKind::QueryError(e.into())).build(), - } - } -} - -impl From for Error { - fn from(e: FromSqlError) -> Error { +impl From for Error { + fn from(e: rusqlite::types::FromSqlError) -> Error { Error::builder(ErrorKind::ColumnReadFailure(e.into())).build() } } diff --git a/quaint/src/connector/sqlite/wasm/mod.rs b/quaint/src/connector/sqlite/wasm/mod.rs new file mode 100644 index 000000000000..0dbbcd76daec --- /dev/null +++ b/quaint/src/connector/sqlite/wasm/mod.rs @@ -0,0 +1,4 @@ +///! Wasm-compatible definitions for the SQLite connector. +/// /// This module is only available with the `sqlite` feature. +pub(crate) mod common; +pub mod error; diff --git a/quaint/src/error.rs b/quaint/src/error.rs index f8202b030466..705bb6b37ee0 100644 --- a/quaint/src/error.rs +++ b/quaint/src/error.rs @@ -8,7 +8,7 @@ use std::time::Duration; pub use crate::connector::mysql::MysqlError; pub use crate::connector::postgres::PostgresError; -// pub use crate::connector::sqlite::SqliteError; +pub use crate::connector::sqlite::SqliteError; #[derive(Debug, PartialEq, Eq)] pub enum DatabaseConstraint { From 45df24fdd0ba18c192ff3bbf4d363d69b8ae4f5e Mon Sep 17 00:00:00 2001 From: jkomyno Date: Mon, 13 Nov 2023 13:32:01 +0100 Subject: [PATCH 07/34] chore(quaint): fix clippy when compiling natively --- quaint/src/connector.rs | 34 --- quaint/src/connector/mssql/conversion.rs | 87 ------- quaint/src/connector/mssql/error.rs | 247 ------------------- quaint/src/connector/mssql/native/mod.rs | 3 + quaint/src/connector/mssql/wasm/mod.rs | 4 +- quaint/src/connector/mysql/native/mod.rs | 18 +- quaint/src/connector/mysql/wasm/mod.rs | 4 +- quaint/src/connector/postgres/native/mod.rs | 9 +- quaint/src/connector/postgres/wasm/common.rs | 142 ----------- quaint/src/connector/postgres/wasm/mod.rs | 4 +- quaint/src/connector/sqlite/native/mod.rs | 3 + quaint/src/connector/sqlite/wasm/mod.rs | 4 +- quaint/src/single.rs | 5 +- 13 files changed, 24 insertions(+), 540 deletions(-) delete mode 100644 quaint/src/connector/mssql/conversion.rs delete mode 100644 quaint/src/connector/mssql/error.rs diff --git a/quaint/src/connector.rs b/quaint/src/connector.rs index b182e60a4387..0aaa19aa463b 100644 --- a/quaint/src/connector.rs +++ b/quaint/src/connector.rs @@ -23,40 +23,6 @@ mod timeout; mod transaction; mod type_identifier; -// #[cfg(feature = "mssql-connector")] -// pub(crate) mod mssql; -// #[cfg(feature = "mssql")] -// pub(crate) mod mssql_wasm; -// #[cfg(feature = "mysql-connector")] -// pub(crate) mod mysql; -// #[cfg(feature = "mysql")] -// pub(crate) mod mysql_wasm; -// #[cfg(feature = "postgresql-connector")] -// pub(crate) mod postgres; -// #[cfg(feature = "postgresql")] -// pub(crate) mod postgres_wasm; -// #[cfg(feature = "sqlite-connector")] -// pub(crate) mod sqlite; -// #[cfg(feature = "sqlite")] -// pub(crate) mod sqlite_wasm; - -// #[cfg(feature = "mysql-connector")] -// pub use self::mysql::*; -// #[cfg(feature = "mysql")] -// pub use self::mysql_wasm::*; -// #[cfg(feature = "postgresql-connector")] -// pub use self::postgres::*; -// #[cfg(feature = "postgresql")] -// pub use self::postgres_wasm::*; -// #[cfg(feature = "mssql-connector")] -// pub use mssql::*; -// #[cfg(feature = "mssql")] -// pub use mssql_wasm::*; -// #[cfg(feature = "sqlite-connector")] -// pub use sqlite::*; -// #[cfg(feature = "sqlite")] -// pub use sqlite_wasm::*; - pub use self::result_set::*; pub use connection_info::*; pub use queryable::*; diff --git a/quaint/src/connector/mssql/conversion.rs b/quaint/src/connector/mssql/conversion.rs deleted file mode 100644 index 870654ad5de3..000000000000 --- a/quaint/src/connector/mssql/conversion.rs +++ /dev/null @@ -1,87 +0,0 @@ -use crate::ast::{Value, ValueType}; - -use bigdecimal::BigDecimal; -use std::{borrow::Cow, convert::TryFrom}; - -use tiberius::ToSql; -use tiberius::{ColumnData, FromSql, IntoSql}; - -impl<'a> IntoSql<'a> for &'a Value<'a> { - fn into_sql(self) -> ColumnData<'a> { - match &self.typed { - ValueType::Int32(val) => val.into_sql(), - ValueType::Int64(val) => val.into_sql(), - ValueType::Float(val) => val.into_sql(), - ValueType::Double(val) => val.into_sql(), - ValueType::Text(val) => val.as_deref().into_sql(), - ValueType::Bytes(val) => val.as_deref().into_sql(), - ValueType::Enum(val, _) => val.as_deref().into_sql(), - ValueType::Boolean(val) => val.into_sql(), - ValueType::Char(val) => val.as_ref().map(|val| format!("{val}")).into_sql(), - ValueType::Xml(val) => val.as_deref().into_sql(), - ValueType::Array(_) | ValueType::EnumArray(_, _) => panic!("Arrays are not supported on SQL Server."), - ValueType::Numeric(val) => (*val).to_sql(), - ValueType::Json(val) => val.as_ref().map(|val| serde_json::to_string(&val).unwrap()).into_sql(), - ValueType::Uuid(val) => val.into_sql(), - ValueType::DateTime(val) => val.into_sql(), - ValueType::Date(val) => val.into_sql(), - ValueType::Time(val) => val.into_sql(), - } - } -} - -impl TryFrom> for Value<'static> { - type Error = crate::error::Error; - - fn try_from(cd: ColumnData<'static>) -> crate::Result { - let res = match cd { - ColumnData::U8(num) => ValueType::Int32(num.map(i32::from)), - ColumnData::I16(num) => ValueType::Int32(num.map(i32::from)), - ColumnData::I32(num) => ValueType::Int32(num.map(i32::from)), - ColumnData::I64(num) => ValueType::Int64(num.map(i64::from)), - ColumnData::F32(num) => ValueType::Float(num), - ColumnData::F64(num) => ValueType::Double(num), - ColumnData::Bit(b) => ValueType::Boolean(b), - ColumnData::String(s) => ValueType::Text(s), - ColumnData::Guid(uuid) => ValueType::Uuid(uuid), - ColumnData::Binary(bytes) => ValueType::Bytes(bytes), - numeric @ ColumnData::Numeric(_) => ValueType::Numeric(BigDecimal::from_sql(&numeric)?), - dt @ ColumnData::DateTime(_) => { - use tiberius::time::chrono::{DateTime, NaiveDateTime, Utc}; - - let dt = NaiveDateTime::from_sql(&dt)?.map(|dt| DateTime::::from_utc(dt, Utc)); - ValueType::DateTime(dt) - } - dt @ ColumnData::SmallDateTime(_) => { - use tiberius::time::chrono::{DateTime, NaiveDateTime, Utc}; - - let dt = NaiveDateTime::from_sql(&dt)?.map(|dt| DateTime::::from_utc(dt, Utc)); - ValueType::DateTime(dt) - } - dt @ ColumnData::Time(_) => { - use tiberius::time::chrono::NaiveTime; - - ValueType::Time(NaiveTime::from_sql(&dt)?) - } - dt @ ColumnData::Date(_) => { - use tiberius::time::chrono::NaiveDate; - ValueType::Date(NaiveDate::from_sql(&dt)?) - } - dt @ ColumnData::DateTime2(_) => { - use tiberius::time::chrono::{DateTime, NaiveDateTime, Utc}; - - let dt = NaiveDateTime::from_sql(&dt)?.map(|dt| DateTime::::from_utc(dt, Utc)); - - ValueType::DateTime(dt) - } - dt @ ColumnData::DateTimeOffset(_) => { - use tiberius::time::chrono::{DateTime, Utc}; - - ValueType::DateTime(DateTime::::from_sql(&dt)?) - } - ColumnData::Xml(cow) => ValueType::Xml(cow.map(|xml_data| Cow::Owned(xml_data.into_owned().into_string()))), - }; - - Ok(Value::from(res)) - } -} diff --git a/quaint/src/connector/mssql/error.rs b/quaint/src/connector/mssql/error.rs deleted file mode 100644 index f9b6f5e95ab6..000000000000 --- a/quaint/src/connector/mssql/error.rs +++ /dev/null @@ -1,247 +0,0 @@ -use crate::error::{DatabaseConstraint, Error, ErrorKind}; -use tiberius::error::IoErrorKind; - -impl From for Error { - fn from(e: tiberius::error::Error) -> Error { - match e { - tiberius::error::Error::Io { - kind: IoErrorKind::UnexpectedEof, - message, - } => { - let mut builder = Error::builder(ErrorKind::ConnectionClosed); - builder.set_original_message(message); - builder.build() - } - e @ tiberius::error::Error::Io { .. } => Error::builder(ErrorKind::ConnectionError(e.into())).build(), - tiberius::error::Error::Tls(message) => { - let message = format!( - "The TLS settings didn't allow the connection to be established. Please review your connection string. (error: {message})" - ); - - Error::builder(ErrorKind::TlsError { message }).build() - } - tiberius::error::Error::Server(e) if [3902u32, 3903u32, 3971u32].iter().any(|code| e.code() == *code) => { - let kind = ErrorKind::TransactionAlreadyClosed(e.message().to_string()); - - let mut builder = Error::builder(kind); - - builder.set_original_code(format!("{}", e.code())); - builder.set_original_message(e.message().to_string()); - - builder.build() - } - tiberius::error::Error::Server(e) if e.code() == 8169 => { - let kind = ErrorKind::conversion(e.message().to_string()); - - let mut builder = Error::builder(kind); - - builder.set_original_code(format!("{}", e.code())); - builder.set_original_message(e.message().to_string()); - - builder.build() - } - tiberius::error::Error::Server(e) if e.code() == 18456 => { - let user = e.message().split('\'').nth(1).into(); - let kind = ErrorKind::AuthenticationFailed { user }; - - let mut builder = Error::builder(kind); - - builder.set_original_code(format!("{}", e.code())); - builder.set_original_message(e.message().to_string()); - - builder.build() - } - tiberius::error::Error::Server(e) if e.code() == 4060 => { - let db_name = e.message().split('"').nth(1).into(); - let kind = ErrorKind::DatabaseDoesNotExist { db_name }; - let mut builder = Error::builder(kind); - - builder.set_original_code(format!("{}", e.code())); - builder.set_original_message(e.message().to_string()); - - builder.build() - } - tiberius::error::Error::Server(e) if e.code() == 515 => { - let constraint = e - .message() - .split_whitespace() - .nth(7) - .and_then(|s| s.split('\'').nth(1)) - .map(|s| DatabaseConstraint::fields(Some(s))) - .unwrap_or(DatabaseConstraint::CannotParse); - - let kind = ErrorKind::NullConstraintViolation { constraint }; - let mut builder = Error::builder(kind); - - builder.set_original_code(format!("{}", e.code())); - builder.set_original_message(e.message().to_string()); - - builder.build() - } - tiberius::error::Error::Server(e) if e.code() == 1801 => { - let db_name = e.message().split('\'').nth(1).into(); - let kind = ErrorKind::DatabaseAlreadyExists { db_name }; - - let mut builder = Error::builder(kind); - - builder.set_original_code(format!("{}", e.code())); - builder.set_original_message(e.message().to_string()); - - builder.build() - } - tiberius::error::Error::Server(e) if e.code() == 2627 => { - let constraint = e - .message() - .split(". ") - .nth(1) - .and_then(|s| s.split(' ').last()) - .and_then(|s| s.split('\'').nth(1)) - .map(ToString::to_string) - .map(DatabaseConstraint::Index) - .unwrap_or(DatabaseConstraint::CannotParse); - - let kind = ErrorKind::UniqueConstraintViolation { constraint }; - let mut builder = Error::builder(kind); - - builder.set_original_code(format!("{}", e.code())); - builder.set_original_message(e.message().to_string()); - - builder.build() - } - tiberius::error::Error::Server(e) if e.code() == 547 => { - let constraint = e - .message() - .split('.') - .next() - .and_then(|s| s.split_whitespace().last()) - .and_then(|s| s.split('\"').nth(1)) - .map(ToString::to_string) - .map(DatabaseConstraint::Index) - .unwrap_or(DatabaseConstraint::CannotParse); - - let kind = ErrorKind::ForeignKeyConstraintViolation { constraint }; - let mut builder = Error::builder(kind); - - builder.set_original_code(format!("{}", e.code())); - builder.set_original_message(e.message().to_string()); - - builder.build() - } - tiberius::error::Error::Server(e) if e.code() == 1505 => { - let constraint = e - .message() - .split('\'') - .nth(3) - .map(ToString::to_string) - .map(DatabaseConstraint::Index) - .unwrap_or(DatabaseConstraint::CannotParse); - - let kind = ErrorKind::UniqueConstraintViolation { constraint }; - let mut builder = Error::builder(kind); - - builder.set_original_code(format!("{}", e.code())); - builder.set_original_message(e.message().to_string()); - - builder.build() - } - tiberius::error::Error::Server(e) if e.code() == 2601 => { - let constraint = e - .message() - .split_whitespace() - .nth(11) - .and_then(|s| s.split('\'').nth(1)) - .map(ToString::to_string) - .map(DatabaseConstraint::Index) - .unwrap_or(DatabaseConstraint::CannotParse); - - let kind = ErrorKind::UniqueConstraintViolation { constraint }; - let mut builder = Error::builder(kind); - - builder.set_original_code(format!("{}", e.code())); - builder.set_original_message(e.message().to_string()); - - builder.build() - } - tiberius::error::Error::Server(e) if e.code() == 1801 => { - let db_name = e.message().split('\'').nth(1).into(); - let kind = ErrorKind::DatabaseAlreadyExists { db_name }; - - let mut builder = Error::builder(kind); - - builder.set_original_code(format!("{}", e.code())); - builder.set_original_message(e.message().to_string()); - - builder.build() - } - tiberius::error::Error::Server(e) if e.code() == 2628 => { - let column = e.message().split('\'').nth(3).into(); - let kind = ErrorKind::LengthMismatch { column }; - - let mut builder = Error::builder(kind); - - builder.set_original_code(format!("{}", e.code())); - builder.set_original_message(e.message().to_string()); - - builder.build() - } - tiberius::error::Error::Server(e) if e.code() == 208 => { - let table = e - .message() - .split_whitespace() - .nth(3) - .and_then(|s| s.split('\'').nth(1)) - .into(); - - let kind = ErrorKind::TableDoesNotExist { table }; - let mut builder = Error::builder(kind); - - builder.set_original_code(format!("{}", e.code())); - builder.set_original_message(e.message().to_string()); - - builder.build() - } - tiberius::error::Error::Server(e) if e.code() == 207 => { - let column = e - .message() - .split_whitespace() - .nth(3) - .and_then(|s| s.split('\'').nth(1)) - .into(); - - let kind = ErrorKind::ColumnNotFound { column }; - let mut builder = Error::builder(kind); - - builder.set_original_code(format!("{}", e.code())); - builder.set_original_message(e.message().to_string()); - - builder.build() - } - tiberius::error::Error::Server(e) if e.code() == 1205 => { - let mut builder = Error::builder(ErrorKind::TransactionWriteConflict); - - builder.set_original_code(format!("{}", e.code())); - builder.set_original_message(e.message().to_string()); - - builder.build() - } - tiberius::error::Error::Server(e) if e.code() == 3903 => { - let mut builder = Error::builder(ErrorKind::RollbackWithoutBegin); - - builder.set_original_code(format!("{}", e.code())); - builder.set_original_message(e.message().to_string()); - - builder.build() - } - tiberius::error::Error::Server(e) => { - let kind = ErrorKind::QueryError(e.clone().into()); - - let mut builder = Error::builder(kind); - builder.set_original_code(format!("{}", e.code())); - builder.set_original_message(e.message().to_string()); - - builder.build() - } - e => Error::builder(ErrorKind::QueryError(e.into())).build(), - } - } -} diff --git a/quaint/src/connector/mssql/native/mod.rs b/quaint/src/connector/mssql/native/mod.rs index a1ea3bd5394d..6a1019c4f594 100644 --- a/quaint/src/connector/mssql/native/mod.rs +++ b/quaint/src/connector/mssql/native/mod.rs @@ -1,3 +1,6 @@ +//! Definitions for the MSSQL connector. +//! This module is not compatible with wasm32-* targets. +//! This module is only available with the `mssql-connector` feature. mod conversion; mod error; diff --git a/quaint/src/connector/mssql/wasm/mod.rs b/quaint/src/connector/mssql/wasm/mod.rs index 69f1f46f7d21..5a25a32836c2 100644 --- a/quaint/src/connector/mssql/wasm/mod.rs +++ b/quaint/src/connector/mssql/wasm/mod.rs @@ -1,5 +1,5 @@ -///! Wasm-compatible definitions for the MSSQL connector. -/// This module is only available with the `mssql` feature. +//! Wasm-compatible definitions for the MSSQL connector. +//! This module is only available with the `mssql` feature. pub(crate) mod common; pub use common::MssqlUrl; diff --git a/quaint/src/connector/mysql/native/mod.rs b/quaint/src/connector/mysql/native/mod.rs index 1a9652b628f8..234f7fb3d74f 100644 --- a/quaint/src/connector/mysql/native/mod.rs +++ b/quaint/src/connector/mysql/native/mod.rs @@ -1,3 +1,6 @@ +//! Definitions for the MySQL connector. +//! This module is not compatible with wasm32-* targets. +//! This module is only available with the `mysql-connector` feature. mod conversion; mod error; @@ -63,21 +66,6 @@ impl MysqlUrl { } } -#[derive(Debug, Clone)] -pub(crate) struct MysqlUrlQueryParams { - ssl_opts: my::SslOpts, - connection_limit: Option, - use_ssl: bool, - socket: Option, - socket_timeout: Option, - connect_timeout: Option, - pool_timeout: Option, - max_connection_lifetime: Option, - max_idle_connection_lifetime: Option, - prefer_socket: Option, - statement_cache_size: usize, -} - /// A connector interface for the MySQL database. #[derive(Debug)] pub struct Mysql { diff --git a/quaint/src/connector/mysql/wasm/mod.rs b/quaint/src/connector/mysql/wasm/mod.rs index da9a57a53876..4f73f82031d5 100644 --- a/quaint/src/connector/mysql/wasm/mod.rs +++ b/quaint/src/connector/mysql/wasm/mod.rs @@ -1,5 +1,5 @@ -///! Wasm-compatible definitions for the MySQL connector. -/// /// This module is only available with the `mysql` feature. +//! Wasm-compatible definitions for the MySQL connector. +//! This module is only available with the `mysql` feature. pub(crate) mod common; pub mod error; diff --git a/quaint/src/connector/postgres/native/mod.rs b/quaint/src/connector/postgres/native/mod.rs index 8f1645ca4123..a6628086aaae 100644 --- a/quaint/src/connector/postgres/native/mod.rs +++ b/quaint/src/connector/postgres/native/mod.rs @@ -1,11 +1,11 @@ -///! Definitions for the Postgres connector. -/// This module is not compatible with wasm32-* targets. -/// This module is only available with the `postgresql-connector` feature. +//! Definitions for the Postgres connector. +//! This module is not compatible with wasm32-* targets. +//! This module is only available with the `postgresql-connector` feature. mod conversion; mod error; +pub(crate) use crate::connector::postgres::wasm::common::PostgresUrl; use crate::connector::postgres::wasm::common::{Hidden, SslAcceptMode, SslParams}; -pub(crate) use crate::connector::postgres::wasm::common::{PostgresFlavour, PostgresUrl}; use crate::connector::{timeout, IsolationLevel, Transaction}; use crate::{ @@ -670,6 +670,7 @@ fn is_safe_identifier(ident: &str) -> bool { #[cfg(test)] mod tests { use super::*; + pub(crate) use crate::connector::postgres::wasm::common::PostgresFlavour; use crate::tests::test_api::postgres::CONN_STR; use crate::tests::test_api::CRDB_CONN_STR; use crate::{connector::Queryable, error::*, single::Quaint}; diff --git a/quaint/src/connector/postgres/wasm/common.rs b/quaint/src/connector/postgres/wasm/common.rs index 46d327c0183d..88145beb40de 100644 --- a/quaint/src/connector/postgres/wasm/common.rs +++ b/quaint/src/connector/postgres/wasm/common.rs @@ -468,145 +468,3 @@ impl Display for SetSearchPath<'_> { Ok(()) } } - -/// Sorted list of CockroachDB's reserved keywords. -/// Taken from https://www.cockroachlabs.com/docs/stable/keywords-and-identifiers.html#keywords -const RESERVED_KEYWORDS: [&str; 79] = [ - "all", - "analyse", - "analyze", - "and", - "any", - "array", - "as", - "asc", - "asymmetric", - "both", - "case", - "cast", - "check", - "collate", - "column", - "concurrently", - "constraint", - "create", - "current_catalog", - "current_date", - "current_role", - "current_schema", - "current_time", - "current_timestamp", - "current_user", - "default", - "deferrable", - "desc", - "distinct", - "do", - "else", - "end", - "except", - "false", - "fetch", - "for", - "foreign", - "from", - "grant", - "group", - "having", - "in", - "initially", - "intersect", - "into", - "lateral", - "leading", - "limit", - "localtime", - "localtimestamp", - "not", - "null", - "offset", - "on", - "only", - "or", - "order", - "placing", - "primary", - "references", - "returning", - "select", - "session_user", - "some", - "symmetric", - "table", - "then", - "to", - "trailing", - "true", - "union", - "unique", - "user", - "using", - "variadic", - "when", - "where", - "window", - "with", -]; - -/// Sorted list of CockroachDB's reserved type function names. -/// Taken from https://www.cockroachlabs.com/docs/stable/keywords-and-identifiers.html#keywords -const RESERVED_TYPE_FUNCTION_NAMES: [&str; 18] = [ - "authorization", - "collation", - "cross", - "full", - "ilike", - "inner", - "is", - "isnull", - "join", - "left", - "like", - "natural", - "none", - "notnull", - "outer", - "overlaps", - "right", - "similar", -]; - -/// Returns true if a Postgres identifier is considered "safe". -/// -/// In this context, "safe" means that the value of an identifier would be the same quoted and unquoted or that it's not part of reserved keywords. In other words, that it does _not_ need to be quoted. -/// -/// Spec can be found here: https://www.postgresql.org/docs/current/sql-syntax-lexical.html#SQL-SYNTAX-IDENTIFIERS -/// or here: https://www.cockroachlabs.com/docs/stable/keywords-and-identifiers.html#rules-for-identifiers -fn is_safe_identifier(ident: &str) -> bool { - if ident.is_empty() { - return false; - } - - // 1. Not equal any SQL keyword unless the keyword is accepted by the element's syntax. For example, name accepts Unreserved or Column Name keywords. - if RESERVED_KEYWORDS.binary_search(&ident).is_ok() || RESERVED_TYPE_FUNCTION_NAMES.binary_search(&ident).is_ok() { - return false; - } - - let mut chars = ident.chars(); - - let first = chars.next().unwrap(); - - // 2. SQL identifiers must begin with a letter (a-z, but also letters with diacritical marks and non-Latin letters) or an underscore (_). - if (!first.is_alphabetic() || !first.is_lowercase()) && first != '_' { - return false; - } - - for c in chars { - // 3. Subsequent characters in an identifier can be letters, underscores, digits (0-9), or dollar signs ($). - if (!c.is_alphabetic() || !c.is_lowercase()) && c != '_' && !c.is_ascii_digit() && c != '$' { - return false; - } - } - - true -} diff --git a/quaint/src/connector/postgres/wasm/mod.rs b/quaint/src/connector/postgres/wasm/mod.rs index 5b330861a199..859de8f6fd3c 100644 --- a/quaint/src/connector/postgres/wasm/mod.rs +++ b/quaint/src/connector/postgres/wasm/mod.rs @@ -1,5 +1,5 @@ -///! Wasm-compatible definitions for the Postgres connector. -/// /// This module is only available with the `postgresql` feature. +//! Wasm-compatible definitions for the Postgres connector. +//! This module is only available with the `postgresql` feature. pub(crate) mod common; pub mod error; diff --git a/quaint/src/connector/sqlite/native/mod.rs b/quaint/src/connector/sqlite/native/mod.rs index e11f6cd021bc..66f0e6d840df 100644 --- a/quaint/src/connector/sqlite/native/mod.rs +++ b/quaint/src/connector/sqlite/native/mod.rs @@ -1,3 +1,6 @@ +//! Definitions for the SQLite connector. +//! This module is not compatible with wasm32-* targets. +//! This module is only available with the `sqlite-connector` feature. mod conversion; mod error; diff --git a/quaint/src/connector/sqlite/wasm/mod.rs b/quaint/src/connector/sqlite/wasm/mod.rs index 0dbbcd76daec..45307cccd0a3 100644 --- a/quaint/src/connector/sqlite/wasm/mod.rs +++ b/quaint/src/connector/sqlite/wasm/mod.rs @@ -1,4 +1,4 @@ -///! Wasm-compatible definitions for the SQLite connector. -/// /// This module is only available with the `sqlite` feature. +//! Wasm-compatible definitions for the SQLite connector. +//! This module is only available with the `sqlite` feature. pub(crate) mod common; pub mod error; diff --git a/quaint/src/single.rs b/quaint/src/single.rs index 2f234e40fd74..12bcf65c460a 100644 --- a/quaint/src/single.rs +++ b/quaint/src/single.rs @@ -1,7 +1,5 @@ //! A single connection abstraction to a SQL database. -#[cfg(feature = "sqlite")] -use crate::connector::DEFAULT_SQLITE_SCHEMA_NAME; use crate::{ ast, connector::{self, impl_default_TransactionCapable, ConnectionInfo, IsolationLevel, Queryable, TransactionCapable}, @@ -9,7 +7,6 @@ use crate::{ use async_trait::async_trait; use std::{fmt, sync::Arc}; -#[cfg(feature = "sqlite")] use std::convert::TryFrom; /// The main entry point and an abstraction over a database connection. @@ -169,6 +166,8 @@ impl Quaint { #[cfg(feature = "sqlite-connector")] /// Open a new SQLite database in memory. pub fn new_in_memory() -> crate::Result { + use crate::connector::DEFAULT_SQLITE_SCHEMA_NAME; + Ok(Quaint { inner: Arc::new(connector::Sqlite::new_in_memory()?), connection_info: Arc::new(ConnectionInfo::InMemorySqlite { From 6a1f733241372c0459797a215c11443d0e130bcf Mon Sep 17 00:00:00 2001 From: jkomyno Date: Mon, 13 Nov 2023 14:35:42 +0100 Subject: [PATCH 08/34] chore(quaint): fix clippy when compiling to wasm32-unknown-unknown --- quaint/src/connector/mssql/wasm/common.rs | 2 ++ quaint/src/connector/mysql/wasm/common.rs | 2 ++ quaint/src/connector/postgres/wasm/common.rs | 2 ++ quaint/src/connector/sqlite/wasm/common.rs | 2 ++ quaint/src/error.rs | 2 +- quaint/src/single.rs | 2 ++ 6 files changed, 11 insertions(+), 1 deletion(-) diff --git a/quaint/src/connector/mssql/wasm/common.rs b/quaint/src/connector/mssql/wasm/common.rs index 5b6ee881d3e9..42cc0868f9bf 100644 --- a/quaint/src/connector/mssql/wasm/common.rs +++ b/quaint/src/connector/mssql/wasm/common.rs @@ -1,3 +1,5 @@ +#![cfg_attr(target_arch = "wasm32", allow(dead_code))] + use crate::{ connector::IsolationLevel, error::{Error, ErrorKind}, diff --git a/quaint/src/connector/mysql/wasm/common.rs b/quaint/src/connector/mysql/wasm/common.rs index fe60fd24cfc1..58598d6509ac 100644 --- a/quaint/src/connector/mysql/wasm/common.rs +++ b/quaint/src/connector/mysql/wasm/common.rs @@ -1,3 +1,5 @@ +#![cfg_attr(target_arch = "wasm32", allow(dead_code))] + use crate::error::{Error, ErrorKind}; use percent_encoding::percent_decode; use std::{ diff --git a/quaint/src/connector/postgres/wasm/common.rs b/quaint/src/connector/postgres/wasm/common.rs index 88145beb40de..c90826c40548 100644 --- a/quaint/src/connector/postgres/wasm/common.rs +++ b/quaint/src/connector/postgres/wasm/common.rs @@ -1,3 +1,5 @@ +#![cfg_attr(target_arch = "wasm32", allow(dead_code))] + use std::{ borrow::Cow, fmt::{Debug, Display}, diff --git a/quaint/src/connector/sqlite/wasm/common.rs b/quaint/src/connector/sqlite/wasm/common.rs index 10c174480785..46fb5c08f669 100644 --- a/quaint/src/connector/sqlite/wasm/common.rs +++ b/quaint/src/connector/sqlite/wasm/common.rs @@ -1,3 +1,5 @@ +#![cfg_attr(target_arch = "wasm32", allow(dead_code))] + use crate::error::{Error, ErrorKind}; use std::{convert::TryFrom, path::Path, time::Duration}; diff --git a/quaint/src/error.rs b/quaint/src/error.rs index 705bb6b37ee0..f6ae3b3ee34a 100644 --- a/quaint/src/error.rs +++ b/quaint/src/error.rs @@ -282,7 +282,7 @@ pub enum ErrorKind { } impl ErrorKind { - #[cfg(feature = "mysql")] + #[cfg(feature = "mysql-connector")] pub(crate) fn value_out_of_range(msg: impl Into) -> Self { Self::ValueOutOfRange { message: msg.into() } } diff --git a/quaint/src/single.rs b/quaint/src/single.rs index 12bcf65c460a..e4e72ab614fa 100644 --- a/quaint/src/single.rs +++ b/quaint/src/single.rs @@ -7,6 +7,7 @@ use crate::{ use async_trait::async_trait; use std::{fmt, sync::Arc}; +#[cfg(feature = "sqlite-connector")] use std::convert::TryFrom; /// The main entry point and an abstraction over a database connection. @@ -124,6 +125,7 @@ impl Quaint { /// - `isolationLevel` the transaction isolation level. Possible values: /// `READ UNCOMMITTED`, `READ COMMITTED`, `REPEATABLE READ`, `SNAPSHOT`, /// `SERIALIZABLE`. + #[cfg_attr(target_arch = "wasm32", allow(unused_variables))] #[allow(unreachable_code)] pub async fn new(url_str: &str) -> crate::Result { let inner = match url_str { From e61bf75be0c36fd603e37441965ab8935c99c487 Mon Sep 17 00:00:00 2001 From: jkomyno Date: Mon, 13 Nov 2023 16:01:12 +0100 Subject: [PATCH 09/34] chore(quaint): update README --- quaint/README.md | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/quaint/README.md b/quaint/README.md index 92033db269b1..3a9b41c65751 100644 --- a/quaint/README.md +++ b/quaint/README.md @@ -16,9 +16,13 @@ Quaint is an abstraction over certain SQL databases. It provides: ### Feature flags - `mysql`: Support for MySQL databases. + - On non-WebAssembly targets, choose `mysql-connector` instead. - `postgresql`: Support for PostgreSQL databases. + - On non-WebAssembly targets, choose `postgresql-connector` instead. - `sqlite`: Support for SQLite databases. + - On non-WebAssembly targets, choose `sqlite-connector` instead. - `mssql`: Support for Microsoft SQL Server databases. + - On non-WebAssembly targets, choose `mssql-connector` instead. - `pooled`: A connection pool in `pooled::Quaint`. - `vendored-openssl`: Statically links against a vendored OpenSSL library on non-Windows or non-Apple platforms. From 257c4c86e10bae7e61a0a32d5b5069c3f84f407f Mon Sep 17 00:00:00 2001 From: jkomyno Date: Tue, 14 Nov 2023 09:45:23 +0100 Subject: [PATCH 10/34] chore(quaint): rename "*-connector" feature flag to "*-native" --- Cargo.toml | 2 +- quaint/Cargo.toml | 20 ++++++------- quaint/README.md | 8 +++--- quaint/src/connector.rs | 20 ++++--------- quaint/src/connector/mssql.rs | 2 +- quaint/src/connector/mssql/native/mod.rs | 2 +- quaint/src/connector/mysql.rs | 2 +- quaint/src/connector/mysql/native/mod.rs | 2 +- quaint/src/connector/mysql/wasm/common.rs | 12 ++++---- quaint/src/connector/postgres.rs | 2 +- quaint/src/connector/postgres/native/mod.rs | 2 +- quaint/src/connector/postgres/wasm/common.rs | 18 ++++++------ quaint/src/connector/sqlite.rs | 2 +- quaint/src/connector/sqlite/native/mod.rs | 2 +- quaint/src/error.rs | 2 +- quaint/src/pooled/manager.rs | 30 ++++++++++---------- quaint/src/single.rs | 12 ++++---- 17 files changed, 66 insertions(+), 74 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 66f4399ff6db..b32a1a85cf18 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -68,7 +68,7 @@ features = [ "pooled", "postgresql", "sqlite", - "connectors", + "native", ] [profile.dev.package.backtrace] diff --git a/quaint/Cargo.toml b/quaint/Cargo.toml index abe9fece9746..7c804add2f5e 100644 --- a/quaint/Cargo.toml +++ b/quaint/Cargo.toml @@ -29,21 +29,21 @@ docs = [] # way to access database-specific methods when you need extra control. expose-drivers = [] -connectors = [ - "postgresql-connector", - "mysql-connector", - "mssql-connector", - "sqlite-connector", +native = [ + "postgresql-native", + "mysql-native", + "mssql-native", + "sqlite-native", ] -all = ["connectors", "pooled"] +all = ["native", "pooled"] vendored-openssl = [ "postgres-native-tls/vendored-openssl", "mysql_async/vendored-openssl", ] -postgresql-connector = [ +postgresql-native = [ "postgresql", "native-tls", "tokio-postgres", @@ -57,7 +57,7 @@ postgresql-connector = [ ] postgresql = [] -mssql-connector = [ +mssql-native = [ "mssql", "tiberius", "tokio-util", @@ -66,11 +66,11 @@ mssql-connector = [ ] mssql = [] -mysql-connector = ["mysql", "mysql_async", "tokio/time", "lru-cache"] +mysql-native = ["mysql", "mysql_async", "tokio/time", "lru-cache"] mysql = ["chrono/std"] pooled = ["mobc"] -sqlite-connector = ["sqlite", "rusqlite/bundled", "tokio/sync"] +sqlite-native = ["sqlite", "rusqlite/bundled", "tokio/sync"] sqlite = ["rusqlite"] fmt-sql = ["sqlformat"] diff --git a/quaint/README.md b/quaint/README.md index 3a9b41c65751..03108d9090d3 100644 --- a/quaint/README.md +++ b/quaint/README.md @@ -16,13 +16,13 @@ Quaint is an abstraction over certain SQL databases. It provides: ### Feature flags - `mysql`: Support for MySQL databases. - - On non-WebAssembly targets, choose `mysql-connector` instead. + - On non-WebAssembly targets, choose `mysql-native` instead. - `postgresql`: Support for PostgreSQL databases. - - On non-WebAssembly targets, choose `postgresql-connector` instead. + - On non-WebAssembly targets, choose `postgresql-native` instead. - `sqlite`: Support for SQLite databases. - - On non-WebAssembly targets, choose `sqlite-connector` instead. + - On non-WebAssembly targets, choose `sqlite-native` instead. - `mssql`: Support for Microsoft SQL Server databases. - - On non-WebAssembly targets, choose `mssql-connector` instead. + - On non-WebAssembly targets, choose `mssql-native` instead. - `pooled`: A connection pool in `pooled::Quaint`. - `vendored-openssl`: Statically links against a vendored OpenSSL library on non-Windows or non-Apple platforms. diff --git a/quaint/src/connector.rs b/quaint/src/connector.rs index 0aaa19aa463b..7903d23931c0 100644 --- a/quaint/src/connector.rs +++ b/quaint/src/connector.rs @@ -14,11 +14,7 @@ mod connection_info; pub mod metrics; mod queryable; mod result_set; -#[cfg(any( - feature = "mssql-connector", - feature = "postgresql-connector", - feature = "mysql-connector" -))] +#[cfg(any(feature = "mssql-native", feature = "postgresql-native", feature = "mysql-native"))] mod timeout; mod transaction; mod type_identifier; @@ -27,11 +23,7 @@ pub use self::result_set::*; pub use connection_info::*; pub use queryable::*; pub use transaction::*; -#[cfg(any( - feature = "mssql-connector", - feature = "postgresql-connector", - feature = "mysql-connector" -))] +#[cfg(any(feature = "mssql-native", feature = "postgresql-native", feature = "mysql-native"))] #[allow(unused_imports)] pub(crate) use type_identifier::*; @@ -39,28 +31,28 @@ pub use self::metrics::query; #[cfg(feature = "postgresql")] pub(crate) mod postgres; -#[cfg(feature = "postgresql-connector")] +#[cfg(feature = "postgresql-native")] pub use postgres::native::*; #[cfg(feature = "postgresql")] pub use postgres::wasm::common::*; #[cfg(feature = "mysql")] pub(crate) mod mysql; -#[cfg(feature = "mysql-connector")] +#[cfg(feature = "mysql-native")] pub use mysql::native::*; #[cfg(feature = "mysql")] pub use mysql::wasm::common::*; #[cfg(feature = "sqlite")] pub(crate) mod sqlite; -#[cfg(feature = "sqlite-connector")] +#[cfg(feature = "sqlite-native")] pub use sqlite::native::*; #[cfg(feature = "sqlite")] pub use sqlite::wasm::common::*; #[cfg(feature = "mssql")] pub(crate) mod mssql; -#[cfg(feature = "mssql-connector")] +#[cfg(feature = "mssql-native")] pub use mssql::native::*; #[cfg(feature = "mssql")] pub use mssql::wasm::common::*; diff --git a/quaint/src/connector/mssql.rs b/quaint/src/connector/mssql.rs index ea681bd08d18..c83b5f1f7266 100644 --- a/quaint/src/connector/mssql.rs +++ b/quaint/src/connector/mssql.rs @@ -3,5 +3,5 @@ pub use wasm::common::MssqlUrl; #[cfg(feature = "mssql")] pub(crate) mod wasm; -#[cfg(feature = "mssql-connector")] +#[cfg(feature = "mssql-native")] pub(crate) mod native; diff --git a/quaint/src/connector/mssql/native/mod.rs b/quaint/src/connector/mssql/native/mod.rs index 6a1019c4f594..8458935814b4 100644 --- a/quaint/src/connector/mssql/native/mod.rs +++ b/quaint/src/connector/mssql/native/mod.rs @@ -1,6 +1,6 @@ //! Definitions for the MSSQL connector. //! This module is not compatible with wasm32-* targets. -//! This module is only available with the `mssql-connector` feature. +//! This module is only available with the `mssql-native` feature. mod conversion; mod error; diff --git a/quaint/src/connector/mysql.rs b/quaint/src/connector/mysql.rs index 1794cc738b1e..1e52af6a83a0 100644 --- a/quaint/src/connector/mysql.rs +++ b/quaint/src/connector/mysql.rs @@ -4,5 +4,5 @@ pub use wasm::error::MysqlError; #[cfg(feature = "mysql")] pub(crate) mod wasm; -#[cfg(feature = "mysql-connector")] +#[cfg(feature = "mysql-native")] pub(crate) mod native; diff --git a/quaint/src/connector/mysql/native/mod.rs b/quaint/src/connector/mysql/native/mod.rs index 234f7fb3d74f..e72a2c47a9a1 100644 --- a/quaint/src/connector/mysql/native/mod.rs +++ b/quaint/src/connector/mysql/native/mod.rs @@ -1,6 +1,6 @@ //! Definitions for the MySQL connector. //! This module is not compatible with wasm32-* targets. -//! This module is only available with the `mysql-connector` feature. +//! This module is only available with the `mysql-native` feature. mod conversion; mod error; diff --git a/quaint/src/connector/mysql/wasm/common.rs b/quaint/src/connector/mysql/wasm/common.rs index 58598d6509ac..c17b2224c0ef 100644 --- a/quaint/src/connector/mysql/wasm/common.rs +++ b/quaint/src/connector/mysql/wasm/common.rs @@ -123,7 +123,7 @@ impl MysqlUrl { } fn parse_query_params(url: &Url) -> Result { - #[cfg(feature = "mysql-connector")] + #[cfg(feature = "mysql-native")] let mut ssl_opts = { let mut ssl_opts = mysql_async::SslOpts::default(); ssl_opts = ssl_opts.with_danger_accept_invalid_certs(true); @@ -159,7 +159,7 @@ impl MysqlUrl { "sslcert" => { use_ssl = true; - #[cfg(feature = "mysql-connector")] + #[cfg(feature = "mysql-native")] { ssl_opts = ssl_opts.with_root_cert_path(Some(Path::new(&*v).to_path_buf())); } @@ -219,7 +219,7 @@ impl MysqlUrl { use_ssl = true; match v.as_ref() { "strict" => { - #[cfg(feature = "mysql-connector")] + #[cfg(feature = "mysql-native")] { ssl_opts = ssl_opts.with_danger_accept_invalid_certs(false); } @@ -263,7 +263,7 @@ impl MysqlUrl { // Wrapping this in a block, as attributes on expressions are still experimental // See: https://github.com/rust-lang/rust/issues/15701 - #[cfg(feature = "mysql-connector")] + #[cfg(feature = "mysql-native")] { ssl_opts = match identity { Some((Some(path), Some(pw))) => { @@ -279,7 +279,7 @@ impl MysqlUrl { } Ok(MysqlUrlQueryParams { - #[cfg(feature = "mysql-connector")] + #[cfg(feature = "mysql-native")] ssl_opts, connection_limit, use_ssl, @@ -313,6 +313,6 @@ pub(crate) struct MysqlUrlQueryParams { pub(crate) prefer_socket: Option, pub(crate) statement_cache_size: usize, - #[cfg(feature = "mysql-connector")] + #[cfg(feature = "mysql-native")] pub(crate) ssl_opts: mysql_async::SslOpts, } diff --git a/quaint/src/connector/postgres.rs b/quaint/src/connector/postgres.rs index 0f4da84a7c67..73a8547b8a65 100644 --- a/quaint/src/connector/postgres.rs +++ b/quaint/src/connector/postgres.rs @@ -4,5 +4,5 @@ pub use wasm::error::PostgresError; #[cfg(feature = "postgresql")] pub(crate) mod wasm; -#[cfg(feature = "postgresql-connector")] +#[cfg(feature = "postgresql-native")] pub(crate) mod native; diff --git a/quaint/src/connector/postgres/native/mod.rs b/quaint/src/connector/postgres/native/mod.rs index a6628086aaae..fbb4760ed19f 100644 --- a/quaint/src/connector/postgres/native/mod.rs +++ b/quaint/src/connector/postgres/native/mod.rs @@ -1,6 +1,6 @@ //! Definitions for the Postgres connector. //! This module is not compatible with wasm32-* targets. -//! This module is only available with the `postgresql-connector` feature. +//! This module is only available with the `postgresql-native` feature. mod conversion; mod error; diff --git a/quaint/src/connector/postgres/wasm/common.rs b/quaint/src/connector/postgres/wasm/common.rs index c90826c40548..7b9b3aafabb4 100644 --- a/quaint/src/connector/postgres/wasm/common.rs +++ b/quaint/src/connector/postgres/wasm/common.rs @@ -11,7 +11,7 @@ use url::{Host, Url}; use crate::error::{Error, ErrorKind}; -#[cfg(feature = "postgresql-connector")] +#[cfg(feature = "postgresql-native")] use tokio_postgres::config::{ChannelBinding, SslMode}; #[derive(Clone)] @@ -211,9 +211,9 @@ impl PostgresUrl { } fn parse_query_params(url: &Url) -> Result { - #[cfg(feature = "postgresql-connector")] + #[cfg(feature = "postgresql-native")] let mut ssl_mode = SslMode::Prefer; - #[cfg(feature = "postgresql-connector")] + #[cfg(feature = "postgresql-native")] let mut channel_binding = ChannelBinding::Prefer; let mut connection_limit = None; @@ -240,7 +240,7 @@ impl PostgresUrl { .parse() .map_err(|_| Error::builder(ErrorKind::InvalidConnectionArguments).build())?; } - #[cfg(feature = "postgresql-connector")] + #[cfg(feature = "postgresql-native")] "sslmode" => { match v.as_ref() { "disable" => ssl_mode = SslMode::Disable, @@ -348,7 +348,7 @@ impl PostgresUrl { "application_name" => { application_name = Some(v.to_string()); } - #[cfg(feature = "postgresql-connector")] + #[cfg(feature = "postgresql-native")] "channel_binding" => { match v.as_ref() { "disable" => channel_binding = ChannelBinding::Disable, @@ -390,9 +390,9 @@ impl PostgresUrl { max_idle_connection_lifetime, application_name, options, - #[cfg(feature = "postgresql-connector")] + #[cfg(feature = "postgresql-native")] channel_binding, - #[cfg(feature = "postgresql-connector")] + #[cfg(feature = "postgresql-native")] ssl_mode, }) } @@ -427,10 +427,10 @@ pub(crate) struct PostgresUrlQueryParams { pub(crate) application_name: Option, pub(crate) options: Option, - #[cfg(feature = "postgresql-connector")] + #[cfg(feature = "postgresql-native")] pub(crate) channel_binding: ChannelBinding, - #[cfg(feature = "postgresql-connector")] + #[cfg(feature = "postgresql-native")] pub(crate) ssl_mode: SslMode, } diff --git a/quaint/src/connector/sqlite.rs b/quaint/src/connector/sqlite.rs index 0e699c211878..45611aab9357 100644 --- a/quaint/src/connector/sqlite.rs +++ b/quaint/src/connector/sqlite.rs @@ -3,5 +3,5 @@ pub use wasm::error::SqliteError; #[cfg(feature = "sqlite")] pub(crate) mod wasm; -#[cfg(feature = "sqlite-connector")] +#[cfg(feature = "sqlite-native")] pub(crate) mod native; diff --git a/quaint/src/connector/sqlite/native/mod.rs b/quaint/src/connector/sqlite/native/mod.rs index 66f0e6d840df..bdf5c473fd4d 100644 --- a/quaint/src/connector/sqlite/native/mod.rs +++ b/quaint/src/connector/sqlite/native/mod.rs @@ -1,6 +1,6 @@ //! Definitions for the SQLite connector. //! This module is not compatible with wasm32-* targets. -//! This module is only available with the `sqlite-connector` feature. +//! This module is only available with the `sqlite-native` feature. mod conversion; mod error; diff --git a/quaint/src/error.rs b/quaint/src/error.rs index f6ae3b3ee34a..a77513876726 100644 --- a/quaint/src/error.rs +++ b/quaint/src/error.rs @@ -282,7 +282,7 @@ pub enum ErrorKind { } impl ErrorKind { - #[cfg(feature = "mysql-connector")] + #[cfg(feature = "mysql-native")] pub(crate) fn value_out_of_range(msg: impl Into) -> Self { Self::ValueOutOfRange { message: msg.into() } } diff --git a/quaint/src/pooled/manager.rs b/quaint/src/pooled/manager.rs index c31fd44fbcae..73441b7609ba 100644 --- a/quaint/src/pooled/manager.rs +++ b/quaint/src/pooled/manager.rs @@ -1,8 +1,8 @@ -#[cfg(feature = "mssql-connector")] +#[cfg(feature = "mssql-native")] use crate::connector::MssqlUrl; -#[cfg(feature = "mysql-connector")] +#[cfg(feature = "mysql-native")] use crate::connector::MysqlUrl; -#[cfg(feature = "postgresql-connector")] +#[cfg(feature = "postgresql-native")] use crate::connector::PostgresUrl; use crate::{ ast, @@ -97,7 +97,7 @@ impl Manager for QuaintManager { async fn connect(&self) -> crate::Result { let conn = match self { - #[cfg(feature = "sqlite-connector")] + #[cfg(feature = "sqlite-native")] QuaintManager::Sqlite { url, .. } => { use crate::connector::Sqlite; @@ -106,19 +106,19 @@ impl Manager for QuaintManager { Ok(Box::new(conn) as Self::Connection) } - #[cfg(feature = "mysql-connector")] + #[cfg(feature = "mysql-native")] QuaintManager::Mysql { url } => { use crate::connector::Mysql; Ok(Box::new(Mysql::new(url.clone()).await?) as Self::Connection) } - #[cfg(feature = "postgresql-connector")] + #[cfg(feature = "postgresql-native")] QuaintManager::Postgres { url } => { use crate::connector::PostgreSql; Ok(Box::new(PostgreSql::new(url.clone()).await?) as Self::Connection) } - #[cfg(feature = "mssql-connector")] + #[cfg(feature = "mssql-native")] QuaintManager::Mssql { url } => { use crate::connector::Mssql; Ok(Box::new(Mssql::new(url.clone()).await?) as Self::Connection) @@ -146,7 +146,7 @@ mod tests { use crate::pooled::Quaint; #[tokio::test] - #[cfg(feature = "mysql-connector")] + #[cfg(feature = "mysql-native")] async fn mysql_default_connection_limit() { let conn_string = std::env::var("TEST_MYSQL").expect("TEST_MYSQL connection string not set."); @@ -156,7 +156,7 @@ mod tests { } #[tokio::test] - #[cfg(feature = "mysql-connector")] + #[cfg(feature = "mysql-native")] async fn mysql_custom_connection_limit() { let conn_string = format!( "{}?connection_limit=10", @@ -169,7 +169,7 @@ mod tests { } #[tokio::test] - #[cfg(feature = "postgresql-connector")] + #[cfg(feature = "postgresql-native")] async fn psql_default_connection_limit() { let conn_string = std::env::var("TEST_PSQL").expect("TEST_PSQL connection string not set."); @@ -179,7 +179,7 @@ mod tests { } #[tokio::test] - #[cfg(feature = "postgresql-connector")] + #[cfg(feature = "postgresql-native")] async fn psql_custom_connection_limit() { let conn_string = format!( "{}?connection_limit=10", @@ -192,7 +192,7 @@ mod tests { } #[tokio::test] - #[cfg(feature = "mssql-connector")] + #[cfg(feature = "mssql-native")] async fn mssql_default_connection_limit() { let conn_string = std::env::var("TEST_MSSQL").expect("TEST_MSSQL connection string not set."); @@ -202,7 +202,7 @@ mod tests { } #[tokio::test] - #[cfg(feature = "mssql-connector")] + #[cfg(feature = "mssql-native")] async fn mssql_custom_connection_limit() { let conn_string = format!( "{};connectionLimit=10", @@ -215,7 +215,7 @@ mod tests { } #[tokio::test] - #[cfg(feature = "sqlite-connector")] + #[cfg(feature = "sqlite-native")] async fn test_default_connection_limit() { let conn_string = "file:db/test.db".to_string(); let pool = Quaint::builder(&conn_string).unwrap().build(); @@ -224,7 +224,7 @@ mod tests { } #[tokio::test] - #[cfg(feature = "sqlite-connector")] + #[cfg(feature = "sqlite-native")] async fn test_custom_connection_limit() { let conn_string = "file:db/test.db?connection_limit=10".to_string(); let pool = Quaint::builder(&conn_string).unwrap().build(); diff --git a/quaint/src/single.rs b/quaint/src/single.rs index e4e72ab614fa..1a4dbdf52a61 100644 --- a/quaint/src/single.rs +++ b/quaint/src/single.rs @@ -7,7 +7,7 @@ use crate::{ use async_trait::async_trait; use std::{fmt, sync::Arc}; -#[cfg(feature = "sqlite-connector")] +#[cfg(feature = "sqlite-native")] use std::convert::TryFrom; /// The main entry point and an abstraction over a database connection. @@ -129,27 +129,27 @@ impl Quaint { #[allow(unreachable_code)] pub async fn new(url_str: &str) -> crate::Result { let inner = match url_str { - #[cfg(feature = "sqlite-connector")] + #[cfg(feature = "sqlite-native")] s if s.starts_with("file") => { let params = connector::SqliteParams::try_from(s)?; let sqlite = connector::Sqlite::new(¶ms.file_path)?; Arc::new(sqlite) as Arc } - #[cfg(feature = "mysql-connector")] + #[cfg(feature = "mysql-native")] s if s.starts_with("mysql") => { let url = connector::MysqlUrl::new(url::Url::parse(s)?)?; let mysql = connector::Mysql::new(url).await?; Arc::new(mysql) as Arc } - #[cfg(feature = "postgresql-connector")] + #[cfg(feature = "postgresql-native")] s if s.starts_with("postgres") || s.starts_with("postgresql") => { let url = connector::PostgresUrl::new(url::Url::parse(s)?)?; let psql = connector::PostgreSql::new(url).await?; Arc::new(psql) as Arc } - #[cfg(feature = "mssql-connector")] + #[cfg(feature = "mssql-native")] s if s.starts_with("jdbc:sqlserver") | s.starts_with("sqlserver") => { let url = connector::MssqlUrl::new(s)?; let psql = connector::Mssql::new(url).await?; @@ -165,7 +165,7 @@ impl Quaint { Ok(Self { inner, connection_info }) } - #[cfg(feature = "sqlite-connector")] + #[cfg(feature = "sqlite-native")] /// Open a new SQLite database in memory. pub fn new_in_memory() -> crate::Result { use crate::connector::DEFAULT_SQLITE_SCHEMA_NAME; From 5ab6d9636220469772b7969d8c6db84701e6a196 Mon Sep 17 00:00:00 2001 From: jkomyno Date: Tue, 14 Nov 2023 12:37:51 +0100 Subject: [PATCH 11/34] feat(quaint): enable pure Wasm SqliteError --- quaint/Cargo.toml | 2 +- quaint/src/connector/sqlite/native/error.rs | 17 +++++++++++++ quaint/src/connector/sqlite/wasm/error.rs | 28 ++++++--------------- quaint/src/connector/sqlite/wasm/ffi.rs | 7 ++++++ quaint/src/connector/sqlite/wasm/mod.rs | 1 + 5 files changed, 34 insertions(+), 21 deletions(-) create mode 100644 quaint/src/connector/sqlite/wasm/ffi.rs diff --git a/quaint/Cargo.toml b/quaint/Cargo.toml index 7c804add2f5e..52a7edf72aca 100644 --- a/quaint/Cargo.toml +++ b/quaint/Cargo.toml @@ -71,7 +71,7 @@ mysql = ["chrono/std"] pooled = ["mobc"] sqlite-native = ["sqlite", "rusqlite/bundled", "tokio/sync"] -sqlite = ["rusqlite"] +sqlite = [] fmt-sql = ["sqlformat"] diff --git a/quaint/src/connector/sqlite/native/error.rs b/quaint/src/connector/sqlite/native/error.rs index 9e2b2e7c3ea1..d09e2959ce28 100644 --- a/quaint/src/connector/sqlite/native/error.rs +++ b/quaint/src/connector/sqlite/native/error.rs @@ -2,6 +2,17 @@ use crate::connector::sqlite::wasm::error::SqliteError; use crate::error::*; +impl std::fmt::Display for SqliteError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!( + f, + "Error code {}: {}", + self.extended_code, + rusqlite::ffi::code_to_str(self.extended_code) + ) + } +} + impl From for Error { fn from(e: rusqlite::Error) -> Error { match e { @@ -47,3 +58,9 @@ impl From for Error { } } } + +impl From for Error { + fn from(e: rusqlite::types::FromSqlError) -> Error { + Error::builder(ErrorKind::ColumnReadFailure(e.into())).build() + } +} diff --git a/quaint/src/connector/sqlite/wasm/error.rs b/quaint/src/connector/sqlite/wasm/error.rs index 9cd0ef64e8a4..2c6ff11350fd 100644 --- a/quaint/src/connector/sqlite/wasm/error.rs +++ b/quaint/src/connector/sqlite/wasm/error.rs @@ -1,5 +1,3 @@ -use std::fmt; - use crate::error::*; #[derive(Debug)] @@ -8,14 +6,10 @@ pub struct SqliteError { pub message: Option, } -impl fmt::Display for SqliteError { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!( - f, - "Error code {}: {}", - self.extended_code, - rusqlite::ffi::code_to_str(self.extended_code) - ) +#[cfg(not(feature = "sqlite-native"))] +impl std::fmt::Display for SqliteError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "Error code {}", self.extended_code) } } @@ -35,7 +29,7 @@ impl From for Error { fn from(error: SqliteError) -> Self { match error { SqliteError { - extended_code: rusqlite::ffi::SQLITE_CONSTRAINT_UNIQUE | rusqlite::ffi::SQLITE_CONSTRAINT_PRIMARYKEY, + extended_code: super::ffi::SQLITE_CONSTRAINT_UNIQUE | super::ffi::SQLITE_CONSTRAINT_PRIMARYKEY, message: Some(description), } => { let constraint = description @@ -56,7 +50,7 @@ impl From for Error { } SqliteError { - extended_code: rusqlite::ffi::SQLITE_CONSTRAINT_NOTNULL, + extended_code: super::ffi::SQLITE_CONSTRAINT_NOTNULL, message: Some(description), } => { let constraint = description @@ -77,7 +71,7 @@ impl From for Error { } SqliteError { - extended_code: rusqlite::ffi::SQLITE_CONSTRAINT_FOREIGNKEY | rusqlite::ffi::SQLITE_CONSTRAINT_TRIGGER, + extended_code: super::ffi::SQLITE_CONSTRAINT_FOREIGNKEY | super::ffi::SQLITE_CONSTRAINT_TRIGGER, message: Some(description), } => { let mut builder = Error::builder(ErrorKind::ForeignKeyConstraintViolation { @@ -90,7 +84,7 @@ impl From for Error { builder.build() } - SqliteError { extended_code, message } if error.primary_code() == rusqlite::ffi::SQLITE_BUSY => { + SqliteError { extended_code, message } if error.primary_code() == super::ffi::SQLITE_BUSY => { let mut builder = Error::builder(ErrorKind::SocketTimeout); builder.set_original_code(format!("{extended_code}")); @@ -150,9 +144,3 @@ impl From for Error { } } } - -impl From for Error { - fn from(e: rusqlite::types::FromSqlError) -> Error { - Error::builder(ErrorKind::ColumnReadFailure(e.into())).build() - } -} diff --git a/quaint/src/connector/sqlite/wasm/ffi.rs b/quaint/src/connector/sqlite/wasm/ffi.rs new file mode 100644 index 000000000000..bddfd4354237 --- /dev/null +++ b/quaint/src/connector/sqlite/wasm/ffi.rs @@ -0,0 +1,7 @@ +//! This is a partial copy of `rusqlite::ffi::*`. +pub const SQLITE_BUSY: i32 = 5; +pub const SQLITE_CONSTRAINT_FOREIGNKEY: i32 = 787; +pub const SQLITE_CONSTRAINT_NOTNULL: i32 = 1299; +pub const SQLITE_CONSTRAINT_PRIMARYKEY: i32 = 1555; +pub const SQLITE_CONSTRAINT_TRIGGER: i32 = 1811; +pub const SQLITE_CONSTRAINT_UNIQUE: i32 = 2067; diff --git a/quaint/src/connector/sqlite/wasm/mod.rs b/quaint/src/connector/sqlite/wasm/mod.rs index 45307cccd0a3..662237af30a1 100644 --- a/quaint/src/connector/sqlite/wasm/mod.rs +++ b/quaint/src/connector/sqlite/wasm/mod.rs @@ -2,3 +2,4 @@ //! This module is only available with the `sqlite` feature. pub(crate) mod common; pub mod error; +mod ffi; From ab65c9539cb5bfef59fd2c3f2187ec83d415e3fd Mon Sep 17 00:00:00 2001 From: jkomyno Date: Tue, 14 Nov 2023 12:38:24 +0100 Subject: [PATCH 12/34] feat(query-connect): allow wasm32-unknown-unknown compilation --- libs/user-facing-errors/Cargo.toml | 2 +- query-engine/connectors/query-connector/Cargo.toml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/libs/user-facing-errors/Cargo.toml b/libs/user-facing-errors/Cargo.toml index 9900892209c6..3049a19712b1 100644 --- a/libs/user-facing-errors/Cargo.toml +++ b/libs/user-facing-errors/Cargo.toml @@ -11,7 +11,7 @@ backtrace = "0.3.40" tracing = "0.1" indoc.workspace = true itertools = "0.10" -quaint = { workspace = true, optional = true } +quaint = { path = "../../quaint", optional = true } [features] default = [] diff --git a/query-engine/connectors/query-connector/Cargo.toml b/query-engine/connectors/query-connector/Cargo.toml index d16771aa3daf..788b8ca65576 100644 --- a/query-engine/connectors/query-connector/Cargo.toml +++ b/query-engine/connectors/query-connector/Cargo.toml @@ -14,6 +14,6 @@ prisma-value = {path = "../../../libs/prisma-value"} serde.workspace = true serde_json.workspace = true thiserror = "1.0" -user-facing-errors = {path = "../../../libs/user-facing-errors"} +user-facing-errors = {path = "../../../libs/user-facing-errors", features = ["sql"]} uuid = "1" indexmap = "1.7" From cfb550743b7c39dadf87672141bf2bc16c9318d4 Mon Sep 17 00:00:00 2001 From: jkomyno Date: Tue, 14 Nov 2023 14:54:54 +0100 Subject: [PATCH 13/34] feat(sql-query-connector): allow wasm32-unknown-unknown compilation --- .../connectors/sql-query-connector/Cargo.toml | 6 +++-- .../sql-query-connector/src/database/mod.rs | 24 ++++++++++++------- .../src/database/{ => native}/mssql.rs | 4 ++-- .../src/database/{ => native}/mysql.rs | 4 ++-- .../src/database/{ => native}/postgresql.rs | 4 ++-- .../src/database/{ => native}/sqlite.rs | 4 ++-- .../src/database/operations/write.rs | 21 +++++++++++++++- .../connectors/sql-query-connector/src/lib.rs | 5 +++- 8 files changed, 52 insertions(+), 20 deletions(-) rename query-engine/connectors/sql-query-connector/src/database/{ => native}/mssql.rs (94%) rename query-engine/connectors/sql-query-connector/src/database/{ => native}/mysql.rs (95%) rename query-engine/connectors/sql-query-connector/src/database/{ => native}/postgresql.rs (95%) rename query-engine/connectors/sql-query-connector/src/database/{ => native}/sqlite.rs (96%) diff --git a/query-engine/connectors/sql-query-connector/Cargo.toml b/query-engine/connectors/sql-query-connector/Cargo.toml index 62d0be640761..fa9c32ef88e1 100644 --- a/query-engine/connectors/sql-query-connector/Cargo.toml +++ b/query-engine/connectors/sql-query-connector/Cargo.toml @@ -5,6 +5,8 @@ version = "0.1.0" [features] vendored-openssl = ["quaint/vendored-openssl"] + +# Enable Driver Adapters driver-adapters = [] [dependencies] @@ -18,13 +20,13 @@ once_cell = "1.3" rand = "0.7" serde_json = {version = "1.0", features = ["float_roundtrip"]} thiserror = "1.0" -tokio.workspace = true +tokio = { version = "1.0", features = ["macros", "time"] } tracing = "0.1" tracing-futures = "0.2" uuid.workspace = true opentelemetry = { version = "0.17", features = ["tokio"] } tracing-opentelemetry = "0.17.3" -quaint.workspace = true +quaint = { path = "../../../quaint" } cuid = { git = "https://github.com/prisma/cuid-rust", branch = "wasm32-support" } [dependencies.connector-interface] diff --git a/query-engine/connectors/sql-query-connector/src/database/mod.rs b/query-engine/connectors/sql-query-connector/src/database/mod.rs index 695db13b6620..7172e0101400 100644 --- a/query-engine/connectors/sql-query-connector/src/database/mod.rs +++ b/query-engine/connectors/sql-query-connector/src/database/mod.rs @@ -1,12 +1,16 @@ mod connection; #[cfg(feature = "driver-adapters")] mod js; -mod mssql; -mod mysql; -mod postgresql; -mod sqlite; mod transaction; +#[cfg(not(target_arch = "wasm32"))] +pub(crate) mod native { + pub(crate) mod mssql; + pub(crate) mod mysql; + pub(crate) mod postgresql; + pub(crate) mod sqlite; +} + pub(crate) mod operations; use async_trait::async_trait; @@ -14,10 +18,14 @@ use connector_interface::{error::ConnectorError, Connector}; #[cfg(feature = "driver-adapters")] pub use js::*; -pub use mssql::*; -pub use mysql::*; -pub use postgresql::*; -pub use sqlite::*; + +#[cfg(not(target_arch = "wasm32"))] +pub use native::{mssql::*, mysql::*, postgresql::*, sqlite::*}; + +// pub use mssql::*; +// pub use mysql::*; +// pub use postgresql::*; +// pub use sqlite::*; #[async_trait] pub trait FromSource { diff --git a/query-engine/connectors/sql-query-connector/src/database/mssql.rs b/query-engine/connectors/sql-query-connector/src/database/native/mssql.rs similarity index 94% rename from query-engine/connectors/sql-query-connector/src/database/mssql.rs rename to query-engine/connectors/sql-query-connector/src/database/native/mssql.rs index 9655d205e4ca..bdb6e2ee103c 100644 --- a/query-engine/connectors/sql-query-connector/src/database/mssql.rs +++ b/query-engine/connectors/sql-query-connector/src/database/native/mssql.rs @@ -1,4 +1,4 @@ -use super::connection::SqlConnection; +use super::super::connection::SqlConnection; use crate::{FromSource, SqlError}; use async_trait::async_trait; use connector_interface::{ @@ -60,7 +60,7 @@ impl FromSource for Mssql { #[async_trait] impl Connector for Mssql { async fn get_connection<'a>(&'a self) -> connector::Result> { - super::catch(self.connection_info.clone(), async move { + super::super::catch(self.connection_info.clone(), async move { let conn = self.pool.check_out().await.map_err(SqlError::from)?; let conn = SqlConnection::new(conn, &self.connection_info, self.features); diff --git a/query-engine/connectors/sql-query-connector/src/database/mysql.rs b/query-engine/connectors/sql-query-connector/src/database/native/mysql.rs similarity index 95% rename from query-engine/connectors/sql-query-connector/src/database/mysql.rs rename to query-engine/connectors/sql-query-connector/src/database/native/mysql.rs index deb3e6a4f35f..a1cd585c0005 100644 --- a/query-engine/connectors/sql-query-connector/src/database/mysql.rs +++ b/query-engine/connectors/sql-query-connector/src/database/native/mysql.rs @@ -1,4 +1,4 @@ -use super::connection::SqlConnection; +use super::super::connection::SqlConnection; use crate::{FromSource, SqlError}; use async_trait::async_trait; use connector_interface::{ @@ -65,7 +65,7 @@ impl FromSource for Mysql { #[async_trait] impl Connector for Mysql { async fn get_connection<'a>(&'a self) -> connector::Result> { - super::catch(self.connection_info.clone(), async move { + super::super::catch(self.connection_info.clone(), async move { let runtime_conn = self.pool.check_out().await?; // Note: `runtime_conn` must be `Sized`, as that's required by `TransactionCapable` diff --git a/query-engine/connectors/sql-query-connector/src/database/postgresql.rs b/query-engine/connectors/sql-query-connector/src/database/native/postgresql.rs similarity index 95% rename from query-engine/connectors/sql-query-connector/src/database/postgresql.rs rename to query-engine/connectors/sql-query-connector/src/database/native/postgresql.rs index 242b2b63090e..80025add046f 100644 --- a/query-engine/connectors/sql-query-connector/src/database/postgresql.rs +++ b/query-engine/connectors/sql-query-connector/src/database/native/postgresql.rs @@ -1,4 +1,4 @@ -use super::connection::SqlConnection; +use super::super::connection::SqlConnection; use crate::{FromSource, SqlError}; use async_trait::async_trait; use connector_interface::{ @@ -67,7 +67,7 @@ impl FromSource for PostgreSql { #[async_trait] impl Connector for PostgreSql { async fn get_connection<'a>(&'a self) -> connector_interface::Result> { - super::catch(self.connection_info.clone(), async move { + super::super::catch(self.connection_info.clone(), async move { let conn = self.pool.check_out().await.map_err(SqlError::from)?; let conn = SqlConnection::new(conn, &self.connection_info, self.features); Ok(Box::new(conn) as Box) diff --git a/query-engine/connectors/sql-query-connector/src/database/sqlite.rs b/query-engine/connectors/sql-query-connector/src/database/native/sqlite.rs similarity index 96% rename from query-engine/connectors/sql-query-connector/src/database/sqlite.rs rename to query-engine/connectors/sql-query-connector/src/database/native/sqlite.rs index 6be9faeac54d..b1250b18b2be 100644 --- a/query-engine/connectors/sql-query-connector/src/database/sqlite.rs +++ b/query-engine/connectors/sql-query-connector/src/database/native/sqlite.rs @@ -1,4 +1,4 @@ -use super::connection::SqlConnection; +use super::super::connection::SqlConnection; use crate::{FromSource, SqlError}; use async_trait::async_trait; use connector_interface::{ @@ -80,7 +80,7 @@ fn invalid_file_path_error(file_path: &str, connection_info: &ConnectionInfo) -> #[async_trait] impl Connector for Sqlite { async fn get_connection<'a>(&'a self) -> connector::Result> { - super::catch(self.connection_info().clone(), async move { + super::super::catch(self.connection_info().clone(), async move { let conn = self.pool.check_out().await.map_err(SqlError::from)?; let conn = SqlConnection::new(conn, self.connection_info(), self.features); diff --git a/query-engine/connectors/sql-query-connector/src/database/operations/write.rs b/query-engine/connectors/sql-query-connector/src/database/operations/write.rs index 425f4ac1d4b3..611557c4f3ba 100644 --- a/query-engine/connectors/sql-query-connector/src/database/operations/write.rs +++ b/query-engine/connectors/sql-query-connector/src/database/operations/write.rs @@ -18,9 +18,28 @@ use std::{ ops::Deref, usize, }; -use tracing::log::trace; use user_facing_errors::query_engine::DatabaseConstraint; +#[cfg(target_arch = "wasm32")] +macro_rules! trace { + (target: $target:expr, $($arg:tt)+) => {{ + // No-op in WebAssembly + }}; + ($($arg:tt)+) => {{ + // No-op in WebAssembly + }}; +} + +#[cfg(not(target_arch = "wasm32"))] +macro_rules! trace { + (target: $target:expr, $($arg:tt)+) => { + tracing::log::trace!(target: $target, $($arg)+); + }; + ($($arg:tt)+) => { + tracing::log::trace!($($arg)+); + }; +} + async fn generate_id( conn: &dyn Queryable, id_field: &FieldSelection, diff --git a/query-engine/connectors/sql-query-connector/src/lib.rs b/query-engine/connectors/sql-query-connector/src/lib.rs index ed1528ded6b5..74c0a4aab5d3 100644 --- a/query-engine/connectors/sql-query-connector/src/lib.rs +++ b/query-engine/connectors/sql-query-connector/src/lib.rs @@ -22,9 +22,12 @@ mod value_ext; use self::{column_metadata::*, context::Context, query_ext::QueryExt, row::*}; use quaint::prelude::Queryable; +pub use database::FromSource; #[cfg(feature = "driver-adapters")] pub use database::{activate_driver_adapter, Js}; -pub use database::{FromSource, Mssql, Mysql, PostgreSql, Sqlite}; pub use error::SqlError; +#[cfg(not(target_arch = "wasm32"))] +pub use database::{Mssql, Mysql, PostgreSql, Sqlite}; + type Result = std::result::Result; From e7df5a3d0c219d7c59efa0822a3352ed30965e0f Mon Sep 17 00:00:00 2001 From: jkomyno Date: Tue, 14 Nov 2023 14:55:37 +0100 Subject: [PATCH 14/34] chore(query-engine-wasm): add currently unused local crates to test wasm32-unknown-unknown compilation --- query-engine/query-engine-wasm/Cargo.toml | 3 +++ 1 file changed, 3 insertions(+) diff --git a/query-engine/query-engine-wasm/Cargo.toml b/query-engine/query-engine-wasm/Cargo.toml index a8bc393aee3f..f65f31c2d63b 100644 --- a/query-engine/query-engine-wasm/Cargo.toml +++ b/query-engine/query-engine-wasm/Cargo.toml @@ -14,6 +14,9 @@ async-trait = "0.1" user-facing-errors = { path = "../../libs/user-facing-errors" } psl.workspace = true prisma-models = { path = "../prisma-models" } +quaint = { path = "../../quaint" } +connector = { path = "../connectors/query-connector", package = "query-connector" } +sql-query-connector = { path = "../connectors/sql-query-connector" } thiserror = "1" connection-string.workspace = true From 8c5d3dc999167815c0dbc3f4f5fe31557a6086e0 Mon Sep 17 00:00:00 2001 From: jkomyno Date: Tue, 14 Nov 2023 14:55:48 +0100 Subject: [PATCH 15/34] chore: update Cargo.lock --- Cargo.lock | 3 +++ 1 file changed, 3 insertions(+) diff --git a/Cargo.lock b/Cargo.lock index 4c59bfea573b..b88de804c816 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3822,9 +3822,12 @@ dependencies = [ "log", "prisma-models", "psl", + "quaint", + "query-connector", "serde", "serde-wasm-bindgen", "serde_json", + "sql-query-connector", "thiserror", "tokio", "tracing", From 6648a882b4e2d0d8aa5449b2e94875cb807f2949 Mon Sep 17 00:00:00 2001 From: Alberto Schiabel Date: Tue, 14 Nov 2023 14:58:04 +0100 Subject: [PATCH 16/34] chore: remove leftover comments --- .../connectors/sql-query-connector/src/database/mod.rs | 5 ----- 1 file changed, 5 deletions(-) diff --git a/query-engine/connectors/sql-query-connector/src/database/mod.rs b/query-engine/connectors/sql-query-connector/src/database/mod.rs index 7172e0101400..e693769373b0 100644 --- a/query-engine/connectors/sql-query-connector/src/database/mod.rs +++ b/query-engine/connectors/sql-query-connector/src/database/mod.rs @@ -22,11 +22,6 @@ pub use js::*; #[cfg(not(target_arch = "wasm32"))] pub use native::{mssql::*, mysql::*, postgresql::*, sqlite::*}; -// pub use mssql::*; -// pub use mysql::*; -// pub use postgresql::*; -// pub use sqlite::*; - #[async_trait] pub trait FromSource { /// Instantiate a query connector from a Datasource. From 754746ecdaf0dadae3e44532bc268715ab3ce813 Mon Sep 17 00:00:00 2001 From: jkomyno Date: Tue, 14 Nov 2023 16:38:30 +0100 Subject: [PATCH 17/34] feat(query-core): allow wasm32-unknown-unknown compilation --- Cargo.lock | 3 ++ .../query-tests-setup/Cargo.toml | 2 +- query-engine/core-tests/Cargo.toml | 2 +- query-engine/core/Cargo.toml | 11 +++- .../core/src/executor/execute_operation.rs | 11 ++++ query-engine/core/src/executor/mod.rs | 51 +++++++++++++++++++ .../interactive_transactions/actor_manager.rs | 2 +- .../src/interactive_transactions/actors.rs | 15 ++++-- query-engine/core/src/lib.rs | 7 ++- query-engine/query-engine-node-api/Cargo.toml | 2 +- query-engine/query-engine-wasm/Cargo.toml | 1 + query-engine/query-engine/Cargo.toml | 2 +- query-engine/request-handlers/Cargo.toml | 2 +- 13 files changed, 97 insertions(+), 14 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index b88de804c816..50df863820fd 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3680,6 +3680,7 @@ dependencies = [ "once_cell", "opentelemetry", "petgraph 0.4.13", + "pin-project", "prisma-models", "psl", "query-connector", @@ -3695,6 +3696,7 @@ dependencies = [ "tracing-subscriber", "user-facing-errors", "uuid", + "wasm-bindgen-futures", ] [[package]] @@ -3824,6 +3826,7 @@ dependencies = [ "psl", "quaint", "query-connector", + "query-core", "serde", "serde-wasm-bindgen", "serde_json", diff --git a/query-engine/connector-test-kit-rs/query-tests-setup/Cargo.toml b/query-engine/connector-test-kit-rs/query-tests-setup/Cargo.toml index 088a0d4b2d34..f257d9e52162 100644 --- a/query-engine/connector-test-kit-rs/query-tests-setup/Cargo.toml +++ b/query-engine/connector-test-kit-rs/query-tests-setup/Cargo.toml @@ -10,7 +10,7 @@ once_cell = "1" qe-setup = { path = "../qe-setup" } request-handlers = { path = "../../request-handlers" } tokio.workspace = true -query-core = { path = "../../core" } +query-core = { path = "../../core", features = ["metrics"] } sql-query-connector = { path = "../../connectors/sql-query-connector" } query-engine = { path = "../../query-engine"} psl.workspace = true diff --git a/query-engine/core-tests/Cargo.toml b/query-engine/core-tests/Cargo.toml index 9a2c3f5686eb..bac9219c3522 100644 --- a/query-engine/core-tests/Cargo.toml +++ b/query-engine/core-tests/Cargo.toml @@ -9,7 +9,7 @@ edition = "2021" dissimilar = "1.0.4" user-facing-errors = { path = "../../libs/user-facing-errors" } request-handlers = { path = "../request-handlers" } -query-core = { path = "../core" } +query-core = { path = "../core", features = ["metrics"] } schema = { path = "../schema" } psl.workspace = true serde_json.workspace = true diff --git a/query-engine/core/Cargo.toml b/query-engine/core/Cargo.toml index caadf6cdba00..6441abf8ca3a 100644 --- a/query-engine/core/Cargo.toml +++ b/query-engine/core/Cargo.toml @@ -3,6 +3,10 @@ edition = "2021" name = "query-core" version = "0.1.0" +[features] +# default = ["metrics"] +metrics = ["query-engine-metrics"] + [dependencies] async-trait = "0.1" bigdecimal = "0.3" @@ -18,11 +22,11 @@ once_cell = "1" petgraph = "0.4" prisma-models = { path = "../prisma-models", features = ["default_generators"] } opentelemetry = { version = "0.17.0", features = ["rt-tokio", "serialize"] } -query-engine-metrics = {path = "../metrics"} +query-engine-metrics = { path = "../metrics", optional = true } serde.workspace = true serde_json.workspace = true thiserror = "1.0" -tokio.workspace = true +tokio = { version = "1.0", features = ["macros", "time"] } tracing = { version = "0.1", features = ["attributes"] } tracing-futures = "0.2" tracing-subscriber = { version = "0.3", features = ["env-filter"] } @@ -34,3 +38,6 @@ schema = { path = "../schema" } lru = "0.7.7" enumflags2 = "0.7" +[target.'cfg(target_arch = "wasm32")'.dependencies] +pin-project = "1" +wasm-bindgen-futures = "0.4" diff --git a/query-engine/core/src/executor/execute_operation.rs b/query-engine/core/src/executor/execute_operation.rs index 06452fcdd865..6ba21d37f9ff 100644 --- a/query-engine/core/src/executor/execute_operation.rs +++ b/query-engine/core/src/executor/execute_operation.rs @@ -1,3 +1,5 @@ +#![cfg_attr(target_arch = "wasm32", allow(unused_variables))] + use super::pipeline::QueryPipeline; use crate::{ executor::request_context, protocol::EngineProtocol, CoreError, IrSerializer, Operation, QueryGraph, @@ -5,9 +7,12 @@ use crate::{ }; use connector::{Connection, ConnectionLike, Connector}; use futures::future; + +#[cfg(feature = "metrics")] use query_engine_metrics::{ histogram, increment_counter, metrics, PRISMA_CLIENT_QUERIES_DURATION_HISTOGRAM_MS, PRISMA_CLIENT_QUERIES_TOTAL, }; + use schema::{QuerySchema, QuerySchemaRef}; use std::time::{Duration, Instant}; use tracing::Instrument; @@ -24,6 +29,7 @@ pub async fn execute_single_operation( let (graph, serializer) = build_graph(&query_schema, operation.clone())?; let result = execute_on(conn, graph, serializer, query_schema.as_ref(), trace_id).await; + #[cfg(feature = "metrics")] histogram!(PRISMA_CLIENT_QUERIES_DURATION_HISTOGRAM_MS, operation_timer.elapsed()); result @@ -45,6 +51,8 @@ pub async fn execute_many_operations( for (i, (graph, serializer)) in queries.into_iter().enumerate() { let operation_timer = Instant::now(); let result = execute_on(conn, graph, serializer, query_schema.as_ref(), trace_id.clone()).await; + + #[cfg(feature = "metrics")] histogram!(PRISMA_CLIENT_QUERIES_DURATION_HISTOGRAM_MS, operation_timer.elapsed()); match result { @@ -98,6 +106,7 @@ pub async fn execute_many_self_contained( let dispatcher = crate::get_current_dispatcher(); for op in operations { + #[cfg(feature = "metrics")] increment_counter!(PRISMA_CLIENT_QUERIES_TOTAL); let conn_span = info_span!( @@ -158,6 +167,7 @@ async fn execute_self_contained( execute_self_contained_without_retry(conn, graph, serializer, force_transactions, &query_schema, trace_id).await }; + #[cfg(feature = "metrics")] histogram!(PRISMA_CLIENT_QUERIES_DURATION_HISTOGRAM_MS, operation_timer.elapsed()); result @@ -259,6 +269,7 @@ async fn execute_on<'a>( query_schema: &'a QuerySchema, trace_id: Option, ) -> crate::Result { + #[cfg(feature = "metrics")] increment_counter!(PRISMA_CLIENT_QUERIES_TOTAL); let interpreter = QueryInterpreter::new(conn); diff --git a/query-engine/core/src/executor/mod.rs b/query-engine/core/src/executor/mod.rs index ddbb7dfc8429..5ff9830013d6 100644 --- a/query-engine/core/src/executor/mod.rs +++ b/query-engine/core/src/executor/mod.rs @@ -12,6 +12,7 @@ mod pipeline; mod request_context; pub use self::{execute_operation::*, interpreting_executor::InterpretingExecutor}; +use futures::Future; pub(crate) use request_context::*; @@ -131,3 +132,53 @@ pub trait TransactionManager { pub fn get_current_dispatcher() -> Dispatch { tracing::dispatcher::get_default(|current| current.clone()) } + +#[cfg(not(target_arch = "wasm32"))] +pub(crate) mod task { + use super::*; + + pub type JoinHandle = tokio::task::JoinHandle; + + pub fn spawn(future: T) -> JoinHandle + where + T: Future + Send + 'static, + T::Output: Send + 'static, + { + tokio::spawn(future) + } +} + +#[cfg(target_arch = "wasm32")] +pub(crate) mod task { + use super::*; + + #[pin_project::pin_project] + pub struct JoinHandle(#[pin] tokio::sync::oneshot::Receiver); + + impl Future for JoinHandle { + type Output = Result; + fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll { + let this = self.project(); + this.0.poll(cx) + } + } + + impl JoinHandle { + pub fn abort(&mut self) { + // abort is noop for WASM builds + } + } + + pub fn spawn(future: T) -> JoinHandle + where + T: Future + Send + 'static, + T::Output: Send + 'static, + { + let (tx, rx) = tokio::sync::oneshot::channel(); + wasm_bindgen_futures::spawn_local(async move { + let result = future.await; + tx.send(result).ok(); + }); + JoinHandle(rx) + } +} diff --git a/query-engine/core/src/interactive_transactions/actor_manager.rs b/query-engine/core/src/interactive_transactions/actor_manager.rs index 98208343d28a..105733be4166 100644 --- a/query-engine/core/src/interactive_transactions/actor_manager.rs +++ b/query-engine/core/src/interactive_transactions/actor_manager.rs @@ -1,3 +1,4 @@ +use crate::executor::task::JoinHandle; use crate::{protocol::EngineProtocol, ClosedTx, Operation, ResponseData}; use connector::Connection; use lru::LruCache; @@ -9,7 +10,6 @@ use tokio::{ mpsc::{channel, Sender}, RwLock, }, - task::JoinHandle, time::Duration, }; diff --git a/query-engine/core/src/interactive_transactions/actors.rs b/query-engine/core/src/interactive_transactions/actors.rs index 88402d86fedd..104ffc26812f 100644 --- a/query-engine/core/src/interactive_transactions/actors.rs +++ b/query-engine/core/src/interactive_transactions/actors.rs @@ -1,7 +1,8 @@ use super::{CachedTx, TransactionError, TxOpRequest, TxOpRequestMsg, TxOpResponse}; +use crate::executor::task::{spawn, JoinHandle}; use crate::{ - execute_many_operations, execute_single_operation, protocol::EngineProtocol, - telemetry::helpers::set_span_link_from_traceparent, ClosedTx, Operation, ResponseData, TxId, + execute_many_operations, execute_single_operation, protocol::EngineProtocol, ClosedTx, Operation, ResponseData, + TxId, }; use connector::Connection; use schema::QuerySchemaRef; @@ -11,13 +12,15 @@ use tokio::{ mpsc::{channel, Receiver, Sender}, oneshot, RwLock, }, - task::JoinHandle, time::{self, Duration, Instant}, }; use tracing::Span; use tracing_futures::Instrument; use tracing_futures::WithSubscriber; +#[cfg(feature = "metrics")] +use crate::telemetry::helpers::set_span_link_from_traceparent; + #[derive(PartialEq)] enum RunState { Continue, @@ -81,6 +84,8 @@ impl<'a> ITXServer<'a> { traceparent: Option, ) -> crate::Result { let span = info_span!("prisma:engine:itx_query_builder", user_facing = true); + + #[cfg(feature = "metrics")] set_span_link_from_traceparent(&span, traceparent.clone()); let conn = self.cached_tx.as_open()?; @@ -267,7 +272,7 @@ pub(crate) async fn spawn_itx_actor( }; let (open_transaction_send, open_transaction_rcv) = oneshot::channel(); - tokio::task::spawn( + spawn( crate::executor::with_request_context(engine_protocol, async move { // We match on the result in order to send the error to the parent task and abort this // task, on error. This is a separate task (actor), not a function where we can just bubble up the @@ -380,7 +385,7 @@ pub(crate) fn spawn_client_list_clear_actor( closed_txs: Arc>>>, mut rx: Receiver<(TxId, Option)>, ) -> JoinHandle<()> { - tokio::task::spawn(async move { + spawn(async move { loop { if let Some((id, closed_tx)) = rx.recv().await { trace!("removing {} from client list", id); diff --git a/query-engine/core/src/lib.rs b/query-engine/core/src/lib.rs index 7970c96139b7..38f39e9fb5d9 100644 --- a/query-engine/core/src/lib.rs +++ b/query-engine/core/src/lib.rs @@ -9,6 +9,8 @@ pub mod protocol; pub mod query_document; pub mod query_graph_builder; pub mod response_ir; + +#[cfg(feature = "metrics")] pub mod telemetry; pub use self::{ @@ -16,8 +18,11 @@ pub use self::{ executor::{QueryExecutor, TransactionOptions}, interactive_transactions::{ExtendedTransactionUserFacingError, TransactionError, TxId}, query_document::*, - telemetry::*, }; + +#[cfg(feature = "metrics")] +pub use self::telemetry::*; + pub use connector::{ error::{ConnectorError, ErrorKind as ConnectorErrorKind}, Connector, diff --git a/query-engine/query-engine-node-api/Cargo.toml b/query-engine/query-engine-node-api/Cargo.toml index 74f9686189fc..0eaed9eff7ce 100644 --- a/query-engine/query-engine-node-api/Cargo.toml +++ b/query-engine/query-engine-node-api/Cargo.toml @@ -16,7 +16,7 @@ driver-adapters = ["request-handlers/driver-adapters", "sql-connector/driver-ada [dependencies] anyhow = "1" async-trait = "0.1" -query-core = { path = "../core" } +query-core = { path = "../core", features = ["metrics"] } request-handlers = { path = "../request-handlers" } query-connector = { path = "../connectors/query-connector" } user-facing-errors = { path = "../../libs/user-facing-errors" } diff --git a/query-engine/query-engine-wasm/Cargo.toml b/query-engine/query-engine-wasm/Cargo.toml index f65f31c2d63b..c8bc6e2b5178 100644 --- a/query-engine/query-engine-wasm/Cargo.toml +++ b/query-engine/query-engine-wasm/Cargo.toml @@ -17,6 +17,7 @@ prisma-models = { path = "../prisma-models" } quaint = { path = "../../quaint" } connector = { path = "../connectors/query-connector", package = "query-connector" } sql-query-connector = { path = "../connectors/sql-query-connector" } +query-core = { path = "../core" } thiserror = "1" connection-string.workspace = true diff --git a/query-engine/query-engine/Cargo.toml b/query-engine/query-engine/Cargo.toml index be36e4f842dc..c70d8590d0ff 100644 --- a/query-engine/query-engine/Cargo.toml +++ b/query-engine/query-engine/Cargo.toml @@ -20,7 +20,7 @@ enumflags2 = { version = "0.7"} psl.workspace = true graphql-parser = { git = "https://github.com/prisma/graphql-parser" } mongodb-connector = { path = "../connectors/mongodb-query-connector", optional = true, package = "mongodb-query-connector" } -query-core = { path = "../core" } +query-core = { path = "../core", features = ["metrics"] } request-handlers = { path = "../request-handlers" } serde.workspace = true serde_json.workspace = true diff --git a/query-engine/request-handlers/Cargo.toml b/query-engine/request-handlers/Cargo.toml index f5fb433b13ba..e6545eda2234 100644 --- a/query-engine/request-handlers/Cargo.toml +++ b/query-engine/request-handlers/Cargo.toml @@ -5,7 +5,7 @@ edition = "2021" [dependencies] prisma-models = { path = "../prisma-models" } -query-core = { path = "../core" } +query-core = { path = "../core", features = ["metrics"] } user-facing-errors = { path = "../../libs/user-facing-errors" } psl.workspace = true dmmf_crate = { path = "../dmmf", package = "dmmf" } From fe2fb8bd412dfa7273e9cb140f515f30fc6c7072 Mon Sep 17 00:00:00 2001 From: jkomyno Date: Tue, 14 Nov 2023 16:42:24 +0100 Subject: [PATCH 18/34] chore(sql-query-connector): fix clipppy on wasm32 --- .../connectors/sql-query-connector/src/database/connection.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/query-engine/connectors/sql-query-connector/src/database/connection.rs b/query-engine/connectors/sql-query-connector/src/database/connection.rs index 0247e8c4b601..7895e838399a 100644 --- a/query-engine/connectors/sql-query-connector/src/database/connection.rs +++ b/query-engine/connectors/sql-query-connector/src/database/connection.rs @@ -1,3 +1,5 @@ +#![cfg_attr(target_arch = "wasm32", allow(dead_code))] + use super::{catch, transaction::SqlConnectorTransaction}; use crate::{database::operations::*, Context, SqlError}; use async_trait::async_trait; From 9c41dc1fba3c560819740d5506fb075ac0310099 Mon Sep 17 00:00:00 2001 From: Alberto Schiabel Date: Tue, 14 Nov 2023 16:51:08 +0100 Subject: [PATCH 19/34] chore: remove leftover comment --- query-engine/core/Cargo.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/query-engine/core/Cargo.toml b/query-engine/core/Cargo.toml index 6441abf8ca3a..7ccf1a293411 100644 --- a/query-engine/core/Cargo.toml +++ b/query-engine/core/Cargo.toml @@ -4,7 +4,6 @@ name = "query-core" version = "0.1.0" [features] -# default = ["metrics"] metrics = ["query-engine-metrics"] [dependencies] From b69bb840f0f58731a3ff9f4663dcd812a3937c81 Mon Sep 17 00:00:00 2001 From: jkomyno Date: Wed, 15 Nov 2023 11:00:17 +0100 Subject: [PATCH 20/34] feat(driver-adapters): enable Wasm on request-handlers --- query-engine/request-handlers/Cargo.toml | 9 +- .../request-handlers/src/connector_mode.rs | 1 + .../request-handlers/src/load_executor.rs | 162 +++++++++--------- 3 files changed, 90 insertions(+), 82 deletions(-) diff --git a/query-engine/request-handlers/Cargo.toml b/query-engine/request-handlers/Cargo.toml index e6545eda2234..f04d742c448e 100644 --- a/query-engine/request-handlers/Cargo.toml +++ b/query-engine/request-handlers/Cargo.toml @@ -5,8 +5,9 @@ edition = "2021" [dependencies] prisma-models = { path = "../prisma-models" } -query-core = { path = "../core", features = ["metrics"] } +query-core = { path = "../core" } user-facing-errors = { path = "../../libs/user-facing-errors" } +quaint = { path = "../../quaint" } psl.workspace = true dmmf_crate = { path = "../dmmf", package = "dmmf" } itertools = "0.10" @@ -20,7 +21,6 @@ thiserror = "1" tracing = "0.1" url = "2" connection-string.workspace = true -quaint.workspace = true once_cell = "1.15" mongodb-query-connector = { path = "../connectors/mongodb-query-connector", optional = true } @@ -32,10 +32,11 @@ schema = { path = "../schema" } codspeed-criterion-compat = "1.1.0" [features] -default = ["mongodb", "sql"] +default = ["sql", "mongodb", "native"] mongodb = ["mongodb-query-connector"] sql = ["sql-query-connector"] -driver-adapters = ["sql-query-connector"] +driver-adapters = ["sql-query-connector/driver-adapters"] +native = ["mongodb", "sql-query-connector", "quaint/native", "query-core/metrics"] [[bench]] name = "query_planning_bench" diff --git a/query-engine/request-handlers/src/connector_mode.rs b/query-engine/request-handlers/src/connector_mode.rs index 00e0515a596e..be03fbab5820 100644 --- a/query-engine/request-handlers/src/connector_mode.rs +++ b/query-engine/request-handlers/src/connector_mode.rs @@ -1,6 +1,7 @@ #[derive(Copy, Clone, PartialEq, Eq)] pub enum ConnectorMode { /// Indicates that Rust drivers are used in Query Engine. + #[cfg(feature = "native")] Rust, /// Indicates that JS drivers are used in Query Engine. diff --git a/query-engine/request-handlers/src/load_executor.rs b/query-engine/request-handlers/src/load_executor.rs index 652ad3108f0d..26728605f92a 100644 --- a/query-engine/request-handlers/src/load_executor.rs +++ b/query-engine/request-handlers/src/load_executor.rs @@ -1,14 +1,12 @@ +#![allow(unused_imports)] + use psl::{builtin_connectors::*, Datasource, PreviewFeatures}; use query_core::{executor::InterpretingExecutor, Connector, QueryExecutor}; use sql_query_connector::*; use std::collections::HashMap; use std::env; -use tracing::trace; use url::Url; -#[cfg(feature = "mongodb")] -use mongodb_query_connector::MongoDb; - use super::ConnectorMode; /// Loads a query executor based on the parsed Prisma schema (datasource). @@ -27,6 +25,7 @@ pub async fn load( driver_adapter(source, url, features).await } + #[cfg(feature = "native")] ConnectorMode::Rust => { if let Ok(value) = env::var("PRISMA_DISABLE_QUAINT_EXECUTORS") { let disable = value.to_uppercase(); @@ -36,14 +35,14 @@ pub async fn load( } match source.active_provider { - p if SQLITE.is_provider(p) => sqlite(source, url, features).await, - p if MYSQL.is_provider(p) => mysql(source, url, features).await, - p if POSTGRES.is_provider(p) => postgres(source, url, features).await, - p if MSSQL.is_provider(p) => mssql(source, url, features).await, - p if COCKROACH.is_provider(p) => postgres(source, url, features).await, + p if SQLITE.is_provider(p) => native::sqlite(source, url, features).await, + p if MYSQL.is_provider(p) => native::mysql(source, url, features).await, + p if POSTGRES.is_provider(p) => native::postgres(source, url, features).await, + p if MSSQL.is_provider(p) => native::mssql(source, url, features).await, + p if COCKROACH.is_provider(p) => native::postgres(source, url, features).await, #[cfg(feature = "mongodb")] - p if MONGODB.is_provider(p) => mongodb(source, url, features).await, + p if MONGODB.is_provider(p) => native::mongodb(source, url, features).await, x => Err(query_core::CoreError::ConfigurationError(format!( "Unsupported connector type: {x}" @@ -53,57 +52,88 @@ pub async fn load( } } -async fn sqlite( +#[cfg(feature = "driver-adapters")] +async fn driver_adapter( source: &Datasource, url: &str, features: PreviewFeatures, -) -> query_core::Result> { - trace!("Loading SQLite query connector..."); - let sqlite = Sqlite::from_source(source, url, features).await?; - trace!("Loaded SQLite query connector."); - Ok(executor_for(sqlite, false)) +) -> Result, query_core::CoreError> { + let js = Js::from_source(source, url, features).await?; + Ok(executor_for(js, false)) } -async fn postgres( - source: &Datasource, - url: &str, - features: PreviewFeatures, -) -> query_core::Result> { - trace!("Loading Postgres query connector..."); - let database_str = url; - let psql = PostgreSql::from_source(source, url, features).await?; - - let url = Url::parse(database_str) - .map_err(|err| query_core::CoreError::ConfigurationError(format!("Error parsing connection string: {err}")))?; - let params: HashMap = url.query_pairs().into_owned().collect(); - - let force_transactions = params - .get("pgbouncer") - .and_then(|flag| flag.parse().ok()) - .unwrap_or(false); - trace!("Loaded Postgres query connector."); - Ok(executor_for(psql, force_transactions)) -} +#[cfg(feature = "native")] +mod native { + use super::*; + use tracing::trace; + + pub(crate) async fn sqlite( + source: &Datasource, + url: &str, + features: PreviewFeatures, + ) -> query_core::Result> { + trace!("Loading SQLite query connector..."); + let sqlite = Sqlite::from_source(source, url, features).await?; + trace!("Loaded SQLite query connector."); + Ok(executor_for(sqlite, false)) + } -async fn mysql( - source: &Datasource, - url: &str, - features: PreviewFeatures, -) -> query_core::Result> { - let mysql = Mysql::from_source(source, url, features).await?; - trace!("Loaded MySQL query connector."); - Ok(executor_for(mysql, false)) -} + pub(crate) async fn postgres( + source: &Datasource, + url: &str, + features: PreviewFeatures, + ) -> query_core::Result> { + trace!("Loading Postgres query connector..."); + let database_str = url; + let psql = PostgreSql::from_source(source, url, features).await?; + + let url = Url::parse(database_str).map_err(|err| { + query_core::CoreError::ConfigurationError(format!("Error parsing connection string: {err}")) + })?; + let params: HashMap = url.query_pairs().into_owned().collect(); + + let force_transactions = params + .get("pgbouncer") + .and_then(|flag| flag.parse().ok()) + .unwrap_or(false); + trace!("Loaded Postgres query connector."); + Ok(executor_for(psql, force_transactions)) + } -async fn mssql( - source: &Datasource, - url: &str, - features: PreviewFeatures, -) -> query_core::Result> { - trace!("Loading SQL Server query connector..."); - let mssql = Mssql::from_source(source, url, features).await?; - trace!("Loaded SQL Server query connector."); - Ok(executor_for(mssql, false)) + pub(crate) async fn mysql( + source: &Datasource, + url: &str, + features: PreviewFeatures, + ) -> query_core::Result> { + let mysql = Mysql::from_source(source, url, features).await?; + trace!("Loaded MySQL query connector."); + Ok(executor_for(mysql, false)) + } + + pub(crate) async fn mssql( + source: &Datasource, + url: &str, + features: PreviewFeatures, + ) -> query_core::Result> { + trace!("Loading SQL Server query connector..."); + let mssql = Mssql::from_source(source, url, features).await?; + trace!("Loaded SQL Server query connector."); + Ok(executor_for(mssql, false)) + } + + #[cfg(feature = "mongodb")] + pub(crate) async fn mongodb( + source: &Datasource, + url: &str, + _features: PreviewFeatures, + ) -> query_core::Result> { + use mongodb_query_connector::MongoDb; + + trace!("Loading MongoDB query connector..."); + let mongo = MongoDb::new(source, url).await?; + trace!("Loaded MongoDB query connector."); + Ok(executor_for(mongo, false)) + } } fn executor_for(connector: T, force_transactions: bool) -> Box @@ -112,27 +142,3 @@ where { Box::new(InterpretingExecutor::new(connector, force_transactions)) } - -#[cfg(feature = "mongodb")] -async fn mongodb( - source: &Datasource, - url: &str, - _features: PreviewFeatures, -) -> query_core::Result> { - trace!("Loading MongoDB query connector..."); - let mongo = MongoDb::new(source, url).await?; - trace!("Loaded MongoDB query connector."); - Ok(executor_for(mongo, false)) -} - -#[cfg(feature = "driver-adapters")] -async fn driver_adapter( - source: &Datasource, - url: &str, - features: PreviewFeatures, -) -> Result, query_core::CoreError> { - trace!("Loading driver adapter..."); - let js = Js::from_source(source, url, features).await?; - trace!("Loaded driver adapter..."); - Ok(executor_for(js, false)) -} From c987dceb3895fa57f0e16fa84d72d217fd186673 Mon Sep 17 00:00:00 2001 From: Miguel Fernandez Date: Wed, 15 Nov 2023 12:51:00 +0100 Subject: [PATCH 21/34] WIP: refactor mysql module to flatten its structure --- quaint/src/connector.rs | 4 ++-- quaint/src/connector/mysql.rs | 11 +++++++---- quaint/src/connector/mysql/{wasm => }/error.rs | 0 quaint/src/connector/mysql/native/error.rs | 2 +- quaint/src/connector/mysql/native/mod.rs | 2 +- quaint/src/connector/mysql/{wasm/common.rs => url.rs} | 0 quaint/src/connector/mysql/wasm/mod.rs | 6 ------ 7 files changed, 11 insertions(+), 14 deletions(-) rename quaint/src/connector/mysql/{wasm => }/error.rs (100%) rename quaint/src/connector/mysql/{wasm/common.rs => url.rs} (100%) delete mode 100644 quaint/src/connector/mysql/wasm/mod.rs diff --git a/quaint/src/connector.rs b/quaint/src/connector.rs index 7903d23931c0..a2ee455fee22 100644 --- a/quaint/src/connector.rs +++ b/quaint/src/connector.rs @@ -38,10 +38,10 @@ pub use postgres::wasm::common::*; #[cfg(feature = "mysql")] pub(crate) mod mysql; +#[cfg(feature = "mysql")] +pub use mysql::*; #[cfg(feature = "mysql-native")] pub use mysql::native::*; -#[cfg(feature = "mysql")] -pub use mysql::wasm::common::*; #[cfg(feature = "sqlite")] pub(crate) mod sqlite; diff --git a/quaint/src/connector/mysql.rs b/quaint/src/connector/mysql.rs index 1e52af6a83a0..0834be88949e 100644 --- a/quaint/src/connector/mysql.rs +++ b/quaint/src/connector/mysql.rs @@ -1,8 +1,11 @@ -pub use wasm::common::MysqlUrl; -pub use wasm::error::MysqlError; +//! Wasm-compatible definitions for the MySQL connector. +//! This module is only available with the `mysql` feature. +pub mod error; +pub(crate) mod url; -#[cfg(feature = "mysql")] -pub(crate) mod wasm; +pub use error::MysqlError; +pub use url::MysqlUrl; #[cfg(feature = "mysql-native")] pub(crate) mod native; + diff --git a/quaint/src/connector/mysql/wasm/error.rs b/quaint/src/connector/mysql/error.rs similarity index 100% rename from quaint/src/connector/mysql/wasm/error.rs rename to quaint/src/connector/mysql/error.rs diff --git a/quaint/src/connector/mysql/native/error.rs b/quaint/src/connector/mysql/native/error.rs index e00ff1e0aa74..89c21fb706f6 100644 --- a/quaint/src/connector/mysql/native/error.rs +++ b/quaint/src/connector/mysql/native/error.rs @@ -1,5 +1,5 @@ use crate::{ - connector::mysql::wasm::error::MysqlError, + connector::mysql::error::MysqlError, error::{Error, ErrorKind}, }; use mysql_async as my; diff --git a/quaint/src/connector/mysql/native/mod.rs b/quaint/src/connector/mysql/native/mod.rs index e72a2c47a9a1..7a95ee47b614 100644 --- a/quaint/src/connector/mysql/native/mod.rs +++ b/quaint/src/connector/mysql/native/mod.rs @@ -4,7 +4,7 @@ mod conversion; mod error; -pub(crate) use crate::connector::mysql::wasm::common::MysqlUrl; +pub(crate) use crate::connector::mysql::MysqlUrl; use crate::connector::{timeout, IsolationLevel}; use crate::{ diff --git a/quaint/src/connector/mysql/wasm/common.rs b/quaint/src/connector/mysql/url.rs similarity index 100% rename from quaint/src/connector/mysql/wasm/common.rs rename to quaint/src/connector/mysql/url.rs diff --git a/quaint/src/connector/mysql/wasm/mod.rs b/quaint/src/connector/mysql/wasm/mod.rs deleted file mode 100644 index 4f73f82031d5..000000000000 --- a/quaint/src/connector/mysql/wasm/mod.rs +++ /dev/null @@ -1,6 +0,0 @@ -//! Wasm-compatible definitions for the MySQL connector. -//! This module is only available with the `mysql` feature. -pub(crate) mod common; -pub mod error; - -pub use common::MysqlUrl; From 626bc1ef904d0e46fec7046a64cb3927889b6452 Mon Sep 17 00:00:00 2001 From: jkomyno Date: Wed, 15 Nov 2023 13:17:23 +0100 Subject: [PATCH 22/34] feat(quaint): flatten mssql connector module --- quaint/src/connector.rs | 6 +++--- quaint/src/connector/mssql.rs | 7 ++++--- quaint/src/connector/mssql/native/mod.rs | 2 +- quaint/src/connector/mssql/{wasm/common.rs => url.rs} | 0 quaint/src/connector/mssql/wasm/mod.rs | 5 ----- quaint/src/connector/mysql.rs | 1 - 6 files changed, 8 insertions(+), 13 deletions(-) rename quaint/src/connector/mssql/{wasm/common.rs => url.rs} (100%) delete mode 100644 quaint/src/connector/mssql/wasm/mod.rs diff --git a/quaint/src/connector.rs b/quaint/src/connector.rs index a2ee455fee22..97643978228b 100644 --- a/quaint/src/connector.rs +++ b/quaint/src/connector.rs @@ -38,10 +38,10 @@ pub use postgres::wasm::common::*; #[cfg(feature = "mysql")] pub(crate) mod mysql; -#[cfg(feature = "mysql")] -pub use mysql::*; #[cfg(feature = "mysql-native")] pub use mysql::native::*; +#[cfg(feature = "mysql")] +pub use mysql::*; #[cfg(feature = "sqlite")] pub(crate) mod sqlite; @@ -55,4 +55,4 @@ pub(crate) mod mssql; #[cfg(feature = "mssql-native")] pub use mssql::native::*; #[cfg(feature = "mssql")] -pub use mssql::wasm::common::*; +pub use mssql::*; diff --git a/quaint/src/connector/mssql.rs b/quaint/src/connector/mssql.rs index c83b5f1f7266..09f589192676 100644 --- a/quaint/src/connector/mssql.rs +++ b/quaint/src/connector/mssql.rs @@ -1,7 +1,8 @@ -pub use wasm::common::MssqlUrl; +//! Wasm-compatible definitions for the MSSQL connector. +//! This module is only available with the `mssql` feature. +pub(crate) mod url; -#[cfg(feature = "mssql")] -pub(crate) mod wasm; +pub use url::MssqlUrl; #[cfg(feature = "mssql-native")] pub(crate) mod native; diff --git a/quaint/src/connector/mssql/native/mod.rs b/quaint/src/connector/mssql/native/mod.rs index 8458935814b4..d7052d5e5180 100644 --- a/quaint/src/connector/mssql/native/mod.rs +++ b/quaint/src/connector/mssql/native/mod.rs @@ -4,7 +4,7 @@ mod conversion; mod error; -pub(crate) use crate::connector::mssql::wasm::common::MssqlUrl; +pub(crate) use crate::connector::mssql::MssqlUrl; use crate::connector::{timeout, IsolationLevel, Transaction, TransactionOptions}; use crate::{ diff --git a/quaint/src/connector/mssql/wasm/common.rs b/quaint/src/connector/mssql/url.rs similarity index 100% rename from quaint/src/connector/mssql/wasm/common.rs rename to quaint/src/connector/mssql/url.rs diff --git a/quaint/src/connector/mssql/wasm/mod.rs b/quaint/src/connector/mssql/wasm/mod.rs deleted file mode 100644 index 5a25a32836c2..000000000000 --- a/quaint/src/connector/mssql/wasm/mod.rs +++ /dev/null @@ -1,5 +0,0 @@ -//! Wasm-compatible definitions for the MSSQL connector. -//! This module is only available with the `mssql` feature. -pub(crate) mod common; - -pub use common::MssqlUrl; diff --git a/quaint/src/connector/mysql.rs b/quaint/src/connector/mysql.rs index 0834be88949e..5ca2c3551f29 100644 --- a/quaint/src/connector/mysql.rs +++ b/quaint/src/connector/mysql.rs @@ -8,4 +8,3 @@ pub use url::MysqlUrl; #[cfg(feature = "mysql-native")] pub(crate) mod native; - From a9f8ba841de6f1715b6e1002f67d22a3b60c5c6d Mon Sep 17 00:00:00 2001 From: jkomyno Date: Wed, 15 Nov 2023 13:23:19 +0100 Subject: [PATCH 23/34] feat(quaint): flatten postgres connector module --- quaint/src/connector.rs | 2 +- quaint/src/connector/mysql.rs | 2 +- quaint/src/connector/postgres.rs | 10 ++++++---- quaint/src/connector/postgres/{wasm => }/error.rs | 0 quaint/src/connector/postgres/native/error.rs | 2 +- quaint/src/connector/postgres/native/mod.rs | 6 +++--- .../src/connector/postgres/{wasm/common.rs => url.rs} | 0 quaint/src/connector/postgres/wasm/mod.rs | 6 ------ 8 files changed, 12 insertions(+), 16 deletions(-) rename quaint/src/connector/postgres/{wasm => }/error.rs (100%) rename quaint/src/connector/postgres/{wasm/common.rs => url.rs} (100%) delete mode 100644 quaint/src/connector/postgres/wasm/mod.rs diff --git a/quaint/src/connector.rs b/quaint/src/connector.rs index 97643978228b..82b1437b6c03 100644 --- a/quaint/src/connector.rs +++ b/quaint/src/connector.rs @@ -34,7 +34,7 @@ pub(crate) mod postgres; #[cfg(feature = "postgresql-native")] pub use postgres::native::*; #[cfg(feature = "postgresql")] -pub use postgres::wasm::common::*; +pub use postgres::*; #[cfg(feature = "mysql")] pub(crate) mod mysql; diff --git a/quaint/src/connector/mysql.rs b/quaint/src/connector/mysql.rs index 5ca2c3551f29..23fed3c70bd3 100644 --- a/quaint/src/connector/mysql.rs +++ b/quaint/src/connector/mysql.rs @@ -1,6 +1,6 @@ //! Wasm-compatible definitions for the MySQL connector. //! This module is only available with the `mysql` feature. -pub mod error; +pub(crate) mod error; pub(crate) mod url; pub use error::MysqlError; diff --git a/quaint/src/connector/postgres.rs b/quaint/src/connector/postgres.rs index 73a8547b8a65..71d40e71ba0f 100644 --- a/quaint/src/connector/postgres.rs +++ b/quaint/src/connector/postgres.rs @@ -1,8 +1,10 @@ -pub use wasm::common::PostgresUrl; -pub use wasm::error::PostgresError; +//! Wasm-compatible definitions for the PostgreSQL connector. +//! This module is only available with the `postgresql` feature. +pub(crate) mod error; +pub(crate) mod url; -#[cfg(feature = "postgresql")] -pub(crate) mod wasm; +pub use error::PostgresError; +pub use url::{PostgresFlavour, PostgresUrl}; #[cfg(feature = "postgresql-native")] pub(crate) mod native; diff --git a/quaint/src/connector/postgres/wasm/error.rs b/quaint/src/connector/postgres/error.rs similarity index 100% rename from quaint/src/connector/postgres/wasm/error.rs rename to quaint/src/connector/postgres/error.rs diff --git a/quaint/src/connector/postgres/native/error.rs b/quaint/src/connector/postgres/native/error.rs index 05b792e27900..c353e397705c 100644 --- a/quaint/src/connector/postgres/native/error.rs +++ b/quaint/src/connector/postgres/native/error.rs @@ -1,7 +1,7 @@ use tokio_postgres::error::DbError; use crate::{ - connector::postgres::wasm::error::PostgresError, + connector::postgres::error::PostgresError, error::{Error, ErrorKind}, }; diff --git a/quaint/src/connector/postgres/native/mod.rs b/quaint/src/connector/postgres/native/mod.rs index fbb4760ed19f..5dbf67a91cdf 100644 --- a/quaint/src/connector/postgres/native/mod.rs +++ b/quaint/src/connector/postgres/native/mod.rs @@ -4,8 +4,8 @@ mod conversion; mod error; -pub(crate) use crate::connector::postgres::wasm::common::PostgresUrl; -use crate::connector::postgres::wasm::common::{Hidden, SslAcceptMode, SslParams}; +pub(crate) use crate::connector::postgres::url::PostgresUrl; +use crate::connector::postgres::url::{Hidden, SslAcceptMode, SslParams}; use crate::connector::{timeout, IsolationLevel, Transaction}; use crate::{ @@ -670,7 +670,7 @@ fn is_safe_identifier(ident: &str) -> bool { #[cfg(test)] mod tests { use super::*; - pub(crate) use crate::connector::postgres::wasm::common::PostgresFlavour; + pub(crate) use crate::connector::postgres::url::PostgresFlavour; use crate::tests::test_api::postgres::CONN_STR; use crate::tests::test_api::CRDB_CONN_STR; use crate::{connector::Queryable, error::*, single::Quaint}; diff --git a/quaint/src/connector/postgres/wasm/common.rs b/quaint/src/connector/postgres/url.rs similarity index 100% rename from quaint/src/connector/postgres/wasm/common.rs rename to quaint/src/connector/postgres/url.rs diff --git a/quaint/src/connector/postgres/wasm/mod.rs b/quaint/src/connector/postgres/wasm/mod.rs deleted file mode 100644 index 859de8f6fd3c..000000000000 --- a/quaint/src/connector/postgres/wasm/mod.rs +++ /dev/null @@ -1,6 +0,0 @@ -//! Wasm-compatible definitions for the Postgres connector. -//! This module is only available with the `postgresql` feature. -pub(crate) mod common; -pub mod error; - -pub use common::PostgresUrl; From 3c1a1008c915f1baae29c39dbc82f55fb6b0945a Mon Sep 17 00:00:00 2001 From: jkomyno Date: Wed, 15 Nov 2023 13:28:58 +0100 Subject: [PATCH 24/34] feat(quaint): flatten sqlite connector module --- quaint/src/connector.rs | 2 +- quaint/src/connector/sqlite.rs | 10 +++++++--- quaint/src/connector/sqlite/{wasm => }/error.rs | 0 quaint/src/connector/sqlite/{wasm => }/ffi.rs | 0 quaint/src/connector/sqlite/native/error.rs | 2 +- quaint/src/connector/sqlite/native/mod.rs | 2 +- .../src/connector/sqlite/{wasm/common.rs => params.rs} | 0 quaint/src/connector/sqlite/wasm/mod.rs | 5 ----- 8 files changed, 10 insertions(+), 11 deletions(-) rename quaint/src/connector/sqlite/{wasm => }/error.rs (100%) rename quaint/src/connector/sqlite/{wasm => }/ffi.rs (100%) rename quaint/src/connector/sqlite/{wasm/common.rs => params.rs} (100%) delete mode 100644 quaint/src/connector/sqlite/wasm/mod.rs diff --git a/quaint/src/connector.rs b/quaint/src/connector.rs index 82b1437b6c03..dddb3c953ad7 100644 --- a/quaint/src/connector.rs +++ b/quaint/src/connector.rs @@ -48,7 +48,7 @@ pub(crate) mod sqlite; #[cfg(feature = "sqlite-native")] pub use sqlite::native::*; #[cfg(feature = "sqlite")] -pub use sqlite::wasm::common::*; +pub use sqlite::*; #[cfg(feature = "mssql")] pub(crate) mod mssql; diff --git a/quaint/src/connector/sqlite.rs b/quaint/src/connector/sqlite.rs index 45611aab9357..c59c947b8dc1 100644 --- a/quaint/src/connector/sqlite.rs +++ b/quaint/src/connector/sqlite.rs @@ -1,7 +1,11 @@ -pub use wasm::error::SqliteError; +//! Wasm-compatible definitions for the SQLite connector. +//! This module is only available with the `sqlite` feature. +pub(crate) mod error; +mod ffi; +pub(crate) mod params; -#[cfg(feature = "sqlite")] -pub(crate) mod wasm; +pub use error::SqliteError; +pub use params::*; #[cfg(feature = "sqlite-native")] pub(crate) mod native; diff --git a/quaint/src/connector/sqlite/wasm/error.rs b/quaint/src/connector/sqlite/error.rs similarity index 100% rename from quaint/src/connector/sqlite/wasm/error.rs rename to quaint/src/connector/sqlite/error.rs diff --git a/quaint/src/connector/sqlite/wasm/ffi.rs b/quaint/src/connector/sqlite/ffi.rs similarity index 100% rename from quaint/src/connector/sqlite/wasm/ffi.rs rename to quaint/src/connector/sqlite/ffi.rs diff --git a/quaint/src/connector/sqlite/native/error.rs b/quaint/src/connector/sqlite/native/error.rs index d09e2959ce28..51b2417ed821 100644 --- a/quaint/src/connector/sqlite/native/error.rs +++ b/quaint/src/connector/sqlite/native/error.rs @@ -1,4 +1,4 @@ -use crate::connector::sqlite::wasm::error::SqliteError; +use crate::connector::sqlite::error::SqliteError; use crate::error::*; diff --git a/quaint/src/connector/sqlite/native/mod.rs b/quaint/src/connector/sqlite/native/mod.rs index bdf5c473fd4d..4b686f5968d6 100644 --- a/quaint/src/connector/sqlite/native/mod.rs +++ b/quaint/src/connector/sqlite/native/mod.rs @@ -4,7 +4,7 @@ mod conversion; mod error; -use crate::connector::sqlite::wasm::common::SqliteParams; +use crate::connector::sqlite::params::SqliteParams; use crate::connector::IsolationLevel; pub use rusqlite::{params_from_iter, version as sqlite_version}; diff --git a/quaint/src/connector/sqlite/wasm/common.rs b/quaint/src/connector/sqlite/params.rs similarity index 100% rename from quaint/src/connector/sqlite/wasm/common.rs rename to quaint/src/connector/sqlite/params.rs diff --git a/quaint/src/connector/sqlite/wasm/mod.rs b/quaint/src/connector/sqlite/wasm/mod.rs deleted file mode 100644 index 662237af30a1..000000000000 --- a/quaint/src/connector/sqlite/wasm/mod.rs +++ /dev/null @@ -1,5 +0,0 @@ -//! Wasm-compatible definitions for the SQLite connector. -//! This module is only available with the `sqlite` feature. -pub(crate) mod common; -pub mod error; -mod ffi; From 7f4c8f943142d45340dbd2c4c621093998130a72 Mon Sep 17 00:00:00 2001 From: jkomyno Date: Wed, 15 Nov 2023 13:30:47 +0100 Subject: [PATCH 25/34] chore(quaint): export all public definitions in connector "url" modules --- quaint/src/connector/mssql.rs | 2 +- quaint/src/connector/mysql.rs | 2 +- quaint/src/connector/postgres.rs | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/quaint/src/connector/mssql.rs b/quaint/src/connector/mssql.rs index 09f589192676..e18b68fb2ce1 100644 --- a/quaint/src/connector/mssql.rs +++ b/quaint/src/connector/mssql.rs @@ -2,7 +2,7 @@ //! This module is only available with the `mssql` feature. pub(crate) mod url; -pub use url::MssqlUrl; +pub use url::*; #[cfg(feature = "mssql-native")] pub(crate) mod native; diff --git a/quaint/src/connector/mysql.rs b/quaint/src/connector/mysql.rs index 23fed3c70bd3..0dc504dd2d11 100644 --- a/quaint/src/connector/mysql.rs +++ b/quaint/src/connector/mysql.rs @@ -4,7 +4,7 @@ pub(crate) mod error; pub(crate) mod url; pub use error::MysqlError; -pub use url::MysqlUrl; +pub use url::*; #[cfg(feature = "mysql-native")] pub(crate) mod native; diff --git a/quaint/src/connector/postgres.rs b/quaint/src/connector/postgres.rs index 71d40e71ba0f..d1694108a1b7 100644 --- a/quaint/src/connector/postgres.rs +++ b/quaint/src/connector/postgres.rs @@ -4,7 +4,7 @@ pub(crate) mod error; pub(crate) mod url; pub use error::PostgresError; -pub use url::{PostgresFlavour, PostgresUrl}; +pub use url::*; #[cfg(feature = "postgresql-native")] pub(crate) mod native; From 95a4e28c89a1cda3c7bdf6f15b7ae543ae5da780 Mon Sep 17 00:00:00 2001 From: jkomyno Date: Wed, 15 Nov 2023 17:06:22 +0100 Subject: [PATCH 26/34] chore(quaint): refactor tests for connectors, addressing feedback --- quaint/src/connector/mssql/native/mod.rs | 17 -- quaint/src/connector/mysql/native/mod.rs | 83 -------- quaint/src/connector/mysql/url.rs | 83 ++++++++ quaint/src/connector/postgres/native/mod.rs | 215 +------------------ quaint/src/connector/postgres/url.rs | 223 ++++++++++++++++++++ quaint/src/connector/sqlite/native/mod.rs | 21 -- quaint/src/connector/sqlite/params.rs | 26 +++ 7 files changed, 333 insertions(+), 335 deletions(-) diff --git a/quaint/src/connector/mssql/native/mod.rs b/quaint/src/connector/mssql/native/mod.rs index d7052d5e5180..d22aa7a15dd6 100644 --- a/quaint/src/connector/mssql/native/mod.rs +++ b/quaint/src/connector/mssql/native/mod.rs @@ -237,20 +237,3 @@ impl Queryable for Mssql { true } } - -#[cfg(test)] -mod tests { - use crate::tests::test_api::mssql::CONN_STR; - use crate::{error::*, single::Quaint}; - - #[tokio::test] - async fn should_map_wrong_credentials_error() { - let url = CONN_STR.replace("user=SA", "user=WRONG"); - - let res = Quaint::new(url.as_str()).await; - assert!(res.is_err()); - - let err = res.unwrap_err(); - assert!(matches!(err.kind(), ErrorKind::AuthenticationFailed { user } if user == &Name::available("WRONG"))); - } -} diff --git a/quaint/src/connector/mysql/native/mod.rs b/quaint/src/connector/mysql/native/mod.rs index 7a95ee47b614..fdcc3a6276d1 100644 --- a/quaint/src/connector/mysql/native/mod.rs +++ b/quaint/src/connector/mysql/native/mod.rs @@ -295,86 +295,3 @@ impl Queryable for Mysql { true } } - -#[cfg(test)] -mod tests { - use super::MysqlUrl; - use crate::tests::test_api::mysql::CONN_STR; - use crate::{error::*, single::Quaint}; - use url::Url; - - #[test] - fn should_parse_socket_url() { - let url = MysqlUrl::new(Url::parse("mysql://root@localhost/dbname?socket=(/tmp/mysql.sock)").unwrap()).unwrap(); - assert_eq!("dbname", url.dbname()); - assert_eq!(&Some(String::from("/tmp/mysql.sock")), url.socket()); - } - - #[test] - fn should_parse_prefer_socket() { - let url = - MysqlUrl::new(Url::parse("mysql://root:root@localhost:3307/testdb?prefer_socket=false").unwrap()).unwrap(); - assert!(!url.prefer_socket().unwrap()); - } - - #[test] - fn should_parse_sslaccept() { - let url = - MysqlUrl::new(Url::parse("mysql://root:root@localhost:3307/testdb?sslaccept=strict").unwrap()).unwrap(); - assert!(url.query_params.use_ssl); - assert!(!url.query_params.ssl_opts.skip_domain_validation()); - assert!(!url.query_params.ssl_opts.accept_invalid_certs()); - } - - #[test] - fn should_parse_ipv6_host() { - let url = MysqlUrl::new(Url::parse("mysql://[2001:db8:1234::ffff]:5432/testdb").unwrap()).unwrap(); - assert_eq!("2001:db8:1234::ffff", url.host()); - } - - #[test] - fn should_allow_changing_of_cache_size() { - let url = MysqlUrl::new(Url::parse("mysql:///root:root@localhost:3307/foo?statement_cache_size=420").unwrap()) - .unwrap(); - assert_eq!(420, url.cache().capacity()); - } - - #[test] - fn should_have_default_cache_size() { - let url = MysqlUrl::new(Url::parse("mysql:///root:root@localhost:3307/foo").unwrap()).unwrap(); - assert_eq!(100, url.cache().capacity()); - } - - #[tokio::test] - async fn should_map_nonexisting_database_error() { - let mut url = Url::parse(&CONN_STR).unwrap(); - url.set_username("root").unwrap(); - url.set_path("/this_does_not_exist"); - - let url = url.as_str().to_string(); - let res = Quaint::new(&url).await; - - let err = res.unwrap_err(); - - match err.kind() { - ErrorKind::DatabaseDoesNotExist { db_name } => { - assert_eq!(Some("1049"), err.original_code()); - assert_eq!(Some("Unknown database \'this_does_not_exist\'"), err.original_message()); - assert_eq!(&Name::available("this_does_not_exist"), db_name) - } - e => panic!("Expected `DatabaseDoesNotExist`, got {:?}", e), - } - } - - #[tokio::test] - async fn should_map_wrong_credentials_error() { - let mut url = Url::parse(&CONN_STR).unwrap(); - url.set_username("WRONG").unwrap(); - - let res = Quaint::new(url.as_str()).await; - assert!(res.is_err()); - - let err = res.unwrap_err(); - assert!(matches!(err.kind(), ErrorKind::AuthenticationFailed { user } if user == &Name::available("WRONG"))); - } -} diff --git a/quaint/src/connector/mysql/url.rs b/quaint/src/connector/mysql/url.rs index c17b2224c0ef..f0756fa95833 100644 --- a/quaint/src/connector/mysql/url.rs +++ b/quaint/src/connector/mysql/url.rs @@ -316,3 +316,86 @@ pub(crate) struct MysqlUrlQueryParams { #[cfg(feature = "mysql-native")] pub(crate) ssl_opts: mysql_async::SslOpts, } + +#[cfg(test)] +mod tests { + use super::MysqlUrl; + use crate::tests::test_api::mysql::CONN_STR; + use crate::{error::*, single::Quaint}; + use url::Url; + + #[test] + fn should_parse_socket_url() { + let url = MysqlUrl::new(Url::parse("mysql://root@localhost/dbname?socket=(/tmp/mysql.sock)").unwrap()).unwrap(); + assert_eq!("dbname", url.dbname()); + assert_eq!(&Some(String::from("/tmp/mysql.sock")), url.socket()); + } + + #[test] + fn should_parse_prefer_socket() { + let url = + MysqlUrl::new(Url::parse("mysql://root:root@localhost:3307/testdb?prefer_socket=false").unwrap()).unwrap(); + assert!(!url.prefer_socket().unwrap()); + } + + #[test] + fn should_parse_sslaccept() { + let url = + MysqlUrl::new(Url::parse("mysql://root:root@localhost:3307/testdb?sslaccept=strict").unwrap()).unwrap(); + assert!(url.query_params.use_ssl); + assert!(!url.query_params.ssl_opts.skip_domain_validation()); + assert!(!url.query_params.ssl_opts.accept_invalid_certs()); + } + + #[test] + fn should_parse_ipv6_host() { + let url = MysqlUrl::new(Url::parse("mysql://[2001:db8:1234::ffff]:5432/testdb").unwrap()).unwrap(); + assert_eq!("2001:db8:1234::ffff", url.host()); + } + + #[test] + fn should_allow_changing_of_cache_size() { + let url = MysqlUrl::new(Url::parse("mysql:///root:root@localhost:3307/foo?statement_cache_size=420").unwrap()) + .unwrap(); + assert_eq!(420, url.cache().capacity()); + } + + #[test] + fn should_have_default_cache_size() { + let url = MysqlUrl::new(Url::parse("mysql:///root:root@localhost:3307/foo").unwrap()).unwrap(); + assert_eq!(100, url.cache().capacity()); + } + + #[tokio::test] + async fn should_map_nonexisting_database_error() { + let mut url = Url::parse(&CONN_STR).unwrap(); + url.set_username("root").unwrap(); + url.set_path("/this_does_not_exist"); + + let url = url.as_str().to_string(); + let res = Quaint::new(&url).await; + + let err = res.unwrap_err(); + + match err.kind() { + ErrorKind::DatabaseDoesNotExist { db_name } => { + assert_eq!(Some("1049"), err.original_code()); + assert_eq!(Some("Unknown database \'this_does_not_exist\'"), err.original_message()); + assert_eq!(&Name::available("this_does_not_exist"), db_name) + } + e => panic!("Expected `DatabaseDoesNotExist`, got {:?}", e), + } + } + + #[tokio::test] + async fn should_map_wrong_credentials_error() { + let mut url = Url::parse(&CONN_STR).unwrap(); + url.set_username("WRONG").unwrap(); + + let res = Quaint::new(url.as_str()).await; + assert!(res.is_err()); + + let err = res.unwrap_err(); + assert!(matches!(err.kind(), ErrorKind::AuthenticationFailed { user } if user == &Name::available("WRONG"))); + } +} diff --git a/quaint/src/connector/postgres/native/mod.rs b/quaint/src/connector/postgres/native/mod.rs index 5dbf67a91cdf..30f34e7002be 100644 --- a/quaint/src/connector/postgres/native/mod.rs +++ b/quaint/src/connector/postgres/native/mod.rs @@ -671,89 +671,11 @@ fn is_safe_identifier(ident: &str) -> bool { mod tests { use super::*; pub(crate) use crate::connector::postgres::url::PostgresFlavour; + use crate::connector::Queryable; use crate::tests::test_api::postgres::CONN_STR; use crate::tests::test_api::CRDB_CONN_STR; - use crate::{connector::Queryable, error::*, single::Quaint}; use url::Url; - #[test] - fn should_parse_socket_url() { - let url = PostgresUrl::new(Url::parse("postgresql:///dbname?host=/var/run/psql.sock").unwrap()).unwrap(); - assert_eq!("dbname", url.dbname()); - assert_eq!("/var/run/psql.sock", url.host()); - } - - #[test] - fn should_parse_escaped_url() { - let url = PostgresUrl::new(Url::parse("postgresql:///dbname?host=%2Fvar%2Frun%2Fpostgresql").unwrap()).unwrap(); - assert_eq!("dbname", url.dbname()); - assert_eq!("/var/run/postgresql", url.host()); - } - - #[test] - fn should_allow_changing_of_cache_size() { - let url = - PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo?statement_cache_size=420").unwrap()).unwrap(); - assert_eq!(420, url.cache().capacity()); - } - - #[test] - fn should_have_default_cache_size() { - let url = PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo").unwrap()).unwrap(); - assert_eq!(100, url.cache().capacity()); - } - - #[test] - fn should_have_application_name() { - let url = - PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo?application_name=test").unwrap()).unwrap(); - assert_eq!(Some("test"), url.application_name()); - } - - #[test] - fn should_have_channel_binding() { - let url = - PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo?channel_binding=require").unwrap()).unwrap(); - assert_eq!(ChannelBinding::Require, url.channel_binding()); - } - - #[test] - fn should_have_default_channel_binding() { - let url = - PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo?channel_binding=invalid").unwrap()).unwrap(); - assert_eq!(ChannelBinding::Prefer, url.channel_binding()); - - let url = PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo").unwrap()).unwrap(); - assert_eq!(ChannelBinding::Prefer, url.channel_binding()); - } - - #[test] - fn should_not_enable_caching_with_pgbouncer() { - let url = PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo?pgbouncer=true").unwrap()).unwrap(); - assert_eq!(0, url.cache().capacity()); - } - - #[test] - fn should_parse_default_host() { - let url = PostgresUrl::new(Url::parse("postgresql:///dbname").unwrap()).unwrap(); - assert_eq!("dbname", url.dbname()); - assert_eq!("localhost", url.host()); - } - - #[test] - fn should_parse_ipv6_host() { - let url = PostgresUrl::new(Url::parse("postgresql://[2001:db8:1234::ffff]:5432/dbname").unwrap()).unwrap(); - assert_eq!("2001:db8:1234::ffff", url.host()); - } - - #[test] - fn should_handle_options_field() { - let url = PostgresUrl::new(Url::parse("postgresql:///localhost:5432?options=--cluster%3Dmy_cluster").unwrap()) - .unwrap(); - - assert_eq!("--cluster=my_cluster", url.options().unwrap()); - } - #[tokio::test] async fn test_custom_search_path_pg() { async fn test_path(schema_name: &str) -> Option { @@ -1010,82 +932,6 @@ mod tests { } } - #[tokio::test] - async fn should_map_nonexisting_database_error() { - let mut url = Url::parse(&CONN_STR).unwrap(); - url.set_path("/this_does_not_exist"); - - let res = Quaint::new(url.as_str()).await; - - assert!(res.is_err()); - - match res { - Ok(_) => unreachable!(), - Err(e) => match e.kind() { - ErrorKind::DatabaseDoesNotExist { db_name } => { - assert_eq!(Some("3D000"), e.original_code()); - assert_eq!( - Some("database \"this_does_not_exist\" does not exist"), - e.original_message() - ); - assert_eq!(&Name::available("this_does_not_exist"), db_name) - } - kind => panic!("Expected `DatabaseDoesNotExist`, got {:?}", kind), - }, - } - } - - #[tokio::test] - async fn should_map_wrong_credentials_error() { - let mut url = Url::parse(&CONN_STR).unwrap(); - url.set_username("WRONG").unwrap(); - - let res = Quaint::new(url.as_str()).await; - assert!(res.is_err()); - - let err = res.unwrap_err(); - assert!(matches!(err.kind(), ErrorKind::AuthenticationFailed { user } if user == &Name::available("WRONG"))); - } - - #[tokio::test] - async fn should_map_tls_errors() { - let mut url = Url::parse(&CONN_STR).expect("parsing url"); - url.set_query(Some("sslmode=require&sslaccept=strict")); - - let res = Quaint::new(url.as_str()).await; - - assert!(res.is_err()); - - match res { - Ok(_) => unreachable!(), - Err(e) => match e.kind() { - ErrorKind::TlsError { .. } => (), - other => panic!("{:#?}", other), - }, - } - } - - #[tokio::test] - async fn should_map_incorrect_parameters_error() { - let url = Url::parse(&CONN_STR).unwrap(); - let conn = Quaint::new(url.as_str()).await.unwrap(); - - let res = conn.query_raw("SELECT $1", &[Value::int32(1), Value::int32(2)]).await; - - assert!(res.is_err()); - - match res { - Ok(_) => unreachable!(), - Err(e) => match e.kind() { - ErrorKind::IncorrectNumberOfParameters { expected, actual } => { - assert_eq!(1, *expected); - assert_eq!(2, *actual); - } - other => panic!("{:#?}", other), - }, - } - } - #[test] fn test_safe_ident() { // Safe @@ -1123,63 +969,4 @@ mod tests { assert!(!is_safe_identifier(ident)); } } - - #[test] - fn search_path_pgbouncer_should_be_set_with_query() { - let mut url = Url::parse(&CONN_STR).unwrap(); - url.query_pairs_mut().append_pair("schema", "hello"); - url.query_pairs_mut().append_pair("pgbouncer", "true"); - - let mut pg_url = PostgresUrl::new(url).unwrap(); - pg_url.set_flavour(PostgresFlavour::Postgres); - - let config = pg_url.to_config(); - - // PGBouncer does not support the `search_path` connection parameter. - // When `pgbouncer=true`, config.search_path should be None, - // And the `search_path` should be set via a db query after connection. - assert_eq!(config.get_search_path(), None); - } - - #[test] - fn search_path_pg_should_be_set_with_param() { - let mut url = Url::parse(&CONN_STR).unwrap(); - url.query_pairs_mut().append_pair("schema", "hello"); - - let mut pg_url = PostgresUrl::new(url).unwrap(); - pg_url.set_flavour(PostgresFlavour::Postgres); - - let config = pg_url.to_config(); - - // Postgres supports setting the search_path via a connection parameter. - assert_eq!(config.get_search_path(), Some(&"\"hello\"".to_owned())); - } - - #[test] - fn search_path_crdb_safe_ident_should_be_set_with_param() { - let mut url = Url::parse(&CONN_STR).unwrap(); - url.query_pairs_mut().append_pair("schema", "hello"); - - let mut pg_url = PostgresUrl::new(url).unwrap(); - pg_url.set_flavour(PostgresFlavour::Cockroach); - - let config = pg_url.to_config(); - - // CRDB supports setting the search_path via a connection parameter if the identifier is safe. - assert_eq!(config.get_search_path(), Some(&"hello".to_owned())); - } - - #[test] - fn search_path_crdb_unsafe_ident_should_be_set_with_query() { - let mut url = Url::parse(&CONN_STR).unwrap(); - url.query_pairs_mut().append_pair("schema", "HeLLo"); - - let mut pg_url = PostgresUrl::new(url).unwrap(); - pg_url.set_flavour(PostgresFlavour::Cockroach); - - let config = pg_url.to_config(); - - // CRDB does NOT support setting the search_path via a connection parameter if the identifier is unsafe. - assert_eq!(config.get_search_path(), None); - } } diff --git a/quaint/src/connector/postgres/url.rs b/quaint/src/connector/postgres/url.rs index 7b9b3aafabb4..f0b60d88a848 100644 --- a/quaint/src/connector/postgres/url.rs +++ b/quaint/src/connector/postgres/url.rs @@ -470,3 +470,226 @@ impl Display for SetSearchPath<'_> { Ok(()) } } + +#[cfg(test)] +mod tests { + use super::*; + use crate::ast::Value; + pub(crate) use crate::connector::postgres::url::PostgresFlavour; + use crate::tests::test_api::postgres::CONN_STR; + use crate::{connector::Queryable, error::*, single::Quaint}; + use url::Url; + + #[test] + fn should_parse_socket_url() { + let url = PostgresUrl::new(Url::parse("postgresql:///dbname?host=/var/run/psql.sock").unwrap()).unwrap(); + assert_eq!("dbname", url.dbname()); + assert_eq!("/var/run/psql.sock", url.host()); + } + + #[test] + fn should_parse_escaped_url() { + let url = PostgresUrl::new(Url::parse("postgresql:///dbname?host=%2Fvar%2Frun%2Fpostgresql").unwrap()).unwrap(); + assert_eq!("dbname", url.dbname()); + assert_eq!("/var/run/postgresql", url.host()); + } + + #[test] + fn should_allow_changing_of_cache_size() { + let url = + PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo?statement_cache_size=420").unwrap()).unwrap(); + assert_eq!(420, url.cache().capacity()); + } + + #[test] + fn should_have_default_cache_size() { + let url = PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo").unwrap()).unwrap(); + assert_eq!(100, url.cache().capacity()); + } + + #[test] + fn should_have_application_name() { + let url = + PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo?application_name=test").unwrap()).unwrap(); + assert_eq!(Some("test"), url.application_name()); + } + + #[test] + fn should_have_channel_binding() { + let url = + PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo?channel_binding=require").unwrap()).unwrap(); + assert_eq!(ChannelBinding::Require, url.channel_binding()); + } + + #[test] + fn should_have_default_channel_binding() { + let url = + PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo?channel_binding=invalid").unwrap()).unwrap(); + assert_eq!(ChannelBinding::Prefer, url.channel_binding()); + + let url = PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo").unwrap()).unwrap(); + assert_eq!(ChannelBinding::Prefer, url.channel_binding()); + } + + #[test] + fn should_not_enable_caching_with_pgbouncer() { + let url = PostgresUrl::new(Url::parse("postgresql:///localhost:5432/foo?pgbouncer=true").unwrap()).unwrap(); + assert_eq!(0, url.cache().capacity()); + } + + #[test] + fn should_parse_default_host() { + let url = PostgresUrl::new(Url::parse("postgresql:///dbname").unwrap()).unwrap(); + assert_eq!("dbname", url.dbname()); + assert_eq!("localhost", url.host()); + } + + #[test] + fn should_parse_ipv6_host() { + let url = PostgresUrl::new(Url::parse("postgresql://[2001:db8:1234::ffff]:5432/dbname").unwrap()).unwrap(); + assert_eq!("2001:db8:1234::ffff", url.host()); + } + + #[test] + fn should_handle_options_field() { + let url = PostgresUrl::new(Url::parse("postgresql:///localhost:5432?options=--cluster%3Dmy_cluster").unwrap()) + .unwrap(); + + assert_eq!("--cluster=my_cluster", url.options().unwrap()); + } + + #[tokio::test] + async fn should_map_nonexisting_database_error() { + let mut url = Url::parse(&CONN_STR).unwrap(); + url.set_path("/this_does_not_exist"); + + let res = Quaint::new(url.as_str()).await; + + assert!(res.is_err()); + + match res { + Ok(_) => unreachable!(), + Err(e) => match e.kind() { + ErrorKind::DatabaseDoesNotExist { db_name } => { + assert_eq!(Some("3D000"), e.original_code()); + assert_eq!( + Some("database \"this_does_not_exist\" does not exist"), + e.original_message() + ); + assert_eq!(&Name::available("this_does_not_exist"), db_name) + } + kind => panic!("Expected `DatabaseDoesNotExist`, got {:?}", kind), + }, + } + } + + #[tokio::test] + async fn should_map_wrong_credentials_error() { + let mut url = Url::parse(&CONN_STR).unwrap(); + url.set_username("WRONG").unwrap(); + + let res = Quaint::new(url.as_str()).await; + assert!(res.is_err()); + + let err = res.unwrap_err(); + assert!(matches!(err.kind(), ErrorKind::AuthenticationFailed { user } if user == &Name::available("WRONG"))); + } + + #[tokio::test] + async fn should_map_tls_errors() { + let mut url = Url::parse(&CONN_STR).expect("parsing url"); + url.set_query(Some("sslmode=require&sslaccept=strict")); + + let res = Quaint::new(url.as_str()).await; + + assert!(res.is_err()); + + match res { + Ok(_) => unreachable!(), + Err(e) => match e.kind() { + ErrorKind::TlsError { .. } => (), + other => panic!("{:#?}", other), + }, + } + } + + #[tokio::test] + async fn should_map_incorrect_parameters_error() { + let url = Url::parse(&CONN_STR).unwrap(); + let conn = Quaint::new(url.as_str()).await.unwrap(); + + let res = conn.query_raw("SELECT $1", &[Value::int32(1), Value::int32(2)]).await; + + assert!(res.is_err()); + + match res { + Ok(_) => unreachable!(), + Err(e) => match e.kind() { + ErrorKind::IncorrectNumberOfParameters { expected, actual } => { + assert_eq!(1, *expected); + assert_eq!(2, *actual); + } + other => panic!("{:#?}", other), + }, + } + } + + #[test] + fn search_path_pgbouncer_should_be_set_with_query() { + let mut url = Url::parse(&CONN_STR).unwrap(); + url.query_pairs_mut().append_pair("schema", "hello"); + url.query_pairs_mut().append_pair("pgbouncer", "true"); + + let mut pg_url = PostgresUrl::new(url).unwrap(); + pg_url.set_flavour(PostgresFlavour::Postgres); + + let config = pg_url.to_config(); + + // PGBouncer does not support the `search_path` connection parameter. + // When `pgbouncer=true`, config.search_path should be None, + // And the `search_path` should be set via a db query after connection. + assert_eq!(config.get_search_path(), None); + } + + #[test] + fn search_path_pg_should_be_set_with_param() { + let mut url = Url::parse(&CONN_STR).unwrap(); + url.query_pairs_mut().append_pair("schema", "hello"); + + let mut pg_url = PostgresUrl::new(url).unwrap(); + pg_url.set_flavour(PostgresFlavour::Postgres); + + let config = pg_url.to_config(); + + // Postgres supports setting the search_path via a connection parameter. + assert_eq!(config.get_search_path(), Some(&"\"hello\"".to_owned())); + } + + #[test] + fn search_path_crdb_safe_ident_should_be_set_with_param() { + let mut url = Url::parse(&CONN_STR).unwrap(); + url.query_pairs_mut().append_pair("schema", "hello"); + + let mut pg_url = PostgresUrl::new(url).unwrap(); + pg_url.set_flavour(PostgresFlavour::Cockroach); + + let config = pg_url.to_config(); + + // CRDB supports setting the search_path via a connection parameter if the identifier is safe. + assert_eq!(config.get_search_path(), Some(&"hello".to_owned())); + } + + #[test] + fn search_path_crdb_unsafe_ident_should_be_set_with_query() { + let mut url = Url::parse(&CONN_STR).unwrap(); + url.query_pairs_mut().append_pair("schema", "HeLLo"); + + let mut pg_url = PostgresUrl::new(url).unwrap(); + pg_url.set_flavour(PostgresFlavour::Cockroach); + + let config = pg_url.to_config(); + + // CRDB does NOT support setting the search_path via a connection parameter if the identifier is unsafe. + assert_eq!(config.get_search_path(), None); + } +} diff --git a/quaint/src/connector/sqlite/native/mod.rs b/quaint/src/connector/sqlite/native/mod.rs index 4b686f5968d6..3bf0c46a7db5 100644 --- a/quaint/src/connector/sqlite/native/mod.rs +++ b/quaint/src/connector/sqlite/native/mod.rs @@ -165,27 +165,6 @@ mod tests { error::{ErrorKind, Name}, }; - #[test] - fn sqlite_params_from_str_should_resolve_path_correctly_with_file_scheme() { - let path = "file:dev.db"; - let params = SqliteParams::try_from(path).unwrap(); - assert_eq!(params.file_path, "dev.db"); - } - - #[test] - fn sqlite_params_from_str_should_resolve_path_correctly_with_sqlite_scheme() { - let path = "sqlite:dev.db"; - let params = SqliteParams::try_from(path).unwrap(); - assert_eq!(params.file_path, "dev.db"); - } - - #[test] - fn sqlite_params_from_str_should_resolve_path_correctly_with_no_scheme() { - let path = "dev.db"; - let params = SqliteParams::try_from(path).unwrap(); - assert_eq!(params.file_path, "dev.db"); - } - #[tokio::test] async fn unknown_table_should_give_a_good_error() { let conn = Sqlite::try_from("file:db/test.db").unwrap(); diff --git a/quaint/src/connector/sqlite/params.rs b/quaint/src/connector/sqlite/params.rs index 46fb5c08f669..f024aa97a694 100644 --- a/quaint/src/connector/sqlite/params.rs +++ b/quaint/src/connector/sqlite/params.rs @@ -103,3 +103,29 @@ impl TryFrom<&str> for SqliteParams { } } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn sqlite_params_from_str_should_resolve_path_correctly_with_file_scheme() { + let path = "file:dev.db"; + let params = SqliteParams::try_from(path).unwrap(); + assert_eq!(params.file_path, "dev.db"); + } + + #[test] + fn sqlite_params_from_str_should_resolve_path_correctly_with_sqlite_scheme() { + let path = "sqlite:dev.db"; + let params = SqliteParams::try_from(path).unwrap(); + assert_eq!(params.file_path, "dev.db"); + } + + #[test] + fn sqlite_params_from_str_should_resolve_path_correctly_with_no_scheme() { + let path = "dev.db"; + let params = SqliteParams::try_from(path).unwrap(); + assert_eq!(params.file_path, "dev.db"); + } +} From bacb635367bb994939c9cdcf530033d334a3224b Mon Sep 17 00:00:00 2001 From: jkomyno Date: Wed, 15 Nov 2023 17:19:27 +0100 Subject: [PATCH 27/34] chore: add comment on MysqlAsyncError --- quaint/src/connector/mysql/error.rs | 2 ++ 1 file changed, 2 insertions(+) diff --git a/quaint/src/connector/mysql/error.rs b/quaint/src/connector/mysql/error.rs index 615f0c69dda4..7b4813bf0223 100644 --- a/quaint/src/connector/mysql/error.rs +++ b/quaint/src/connector/mysql/error.rs @@ -1,6 +1,8 @@ use crate::error::{DatabaseConstraint, Error, ErrorKind}; use thiserror::Error; +// This is a partial copy of the `mysql_async::Error` using only the enum variant used by Prisma. +// This avoids pulling in `mysql_async`, which would break Wasm compilation. #[derive(Debug, Error)] enum MysqlAsyncError { #[error("Server error: `{}'", _0)] From 263bab0c84a396cf5867dd65911c2af79cad7824 Mon Sep 17 00:00:00 2001 From: jkomyno Date: Wed, 15 Nov 2023 17:22:52 +0100 Subject: [PATCH 28/34] chore: add comment on ffi.rs for sqlite --- quaint/src/connector/sqlite/ffi.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/quaint/src/connector/sqlite/ffi.rs b/quaint/src/connector/sqlite/ffi.rs index bddfd4354237..c510a459be81 100644 --- a/quaint/src/connector/sqlite/ffi.rs +++ b/quaint/src/connector/sqlite/ffi.rs @@ -1,4 +1,5 @@ -//! This is a partial copy of `rusqlite::ffi::*`. +//! Here, we export only the constants we need to avoid pulling in `rusqlite::ffi::*`, in the sibling `error.rs` file, +//! which would break Wasm compilation. pub const SQLITE_BUSY: i32 = 5; pub const SQLITE_CONSTRAINT_FOREIGNKEY: i32 = 787; pub const SQLITE_CONSTRAINT_NOTNULL: i32 = 1299; From 76816fdc30780b0b026d6db1307ff8e0ade77510 Mon Sep 17 00:00:00 2001 From: jkomyno Date: Wed, 15 Nov 2023 17:52:05 +0100 Subject: [PATCH 29/34] chore: replace awkward "super::super::" with "crate::..." --- .../sql-query-connector/src/database/native/mssql.rs | 4 ++-- .../sql-query-connector/src/database/native/mysql.rs | 4 ++-- .../sql-query-connector/src/database/native/postgresql.rs | 4 ++-- .../sql-query-connector/src/database/native/sqlite.rs | 4 ++-- .../sql_schema_calculator_flavour/mssql.rs | 2 +- .../sql_schema_calculator_flavour/postgres.rs | 2 +- 6 files changed, 10 insertions(+), 10 deletions(-) diff --git a/query-engine/connectors/sql-query-connector/src/database/native/mssql.rs b/query-engine/connectors/sql-query-connector/src/database/native/mssql.rs index bdb6e2ee103c..19d3580bba9f 100644 --- a/query-engine/connectors/sql-query-connector/src/database/native/mssql.rs +++ b/query-engine/connectors/sql-query-connector/src/database/native/mssql.rs @@ -1,4 +1,4 @@ -use super::super::connection::SqlConnection; +use crate::database::{catch, connection::SqlConnection}; use crate::{FromSource, SqlError}; use async_trait::async_trait; use connector_interface::{ @@ -60,7 +60,7 @@ impl FromSource for Mssql { #[async_trait] impl Connector for Mssql { async fn get_connection<'a>(&'a self) -> connector::Result> { - super::super::catch(self.connection_info.clone(), async move { + catch(self.connection_info.clone(), async move { let conn = self.pool.check_out().await.map_err(SqlError::from)?; let conn = SqlConnection::new(conn, &self.connection_info, self.features); diff --git a/query-engine/connectors/sql-query-connector/src/database/native/mysql.rs b/query-engine/connectors/sql-query-connector/src/database/native/mysql.rs index a1cd585c0005..477d687b995b 100644 --- a/query-engine/connectors/sql-query-connector/src/database/native/mysql.rs +++ b/query-engine/connectors/sql-query-connector/src/database/native/mysql.rs @@ -1,4 +1,4 @@ -use super::super::connection::SqlConnection; +use crate::database::{catch, connection::SqlConnection}; use crate::{FromSource, SqlError}; use async_trait::async_trait; use connector_interface::{ @@ -65,7 +65,7 @@ impl FromSource for Mysql { #[async_trait] impl Connector for Mysql { async fn get_connection<'a>(&'a self) -> connector::Result> { - super::super::catch(self.connection_info.clone(), async move { + catch(self.connection_info.clone(), async move { let runtime_conn = self.pool.check_out().await?; // Note: `runtime_conn` must be `Sized`, as that's required by `TransactionCapable` diff --git a/query-engine/connectors/sql-query-connector/src/database/native/postgresql.rs b/query-engine/connectors/sql-query-connector/src/database/native/postgresql.rs index 80025add046f..0e49a1de8bbd 100644 --- a/query-engine/connectors/sql-query-connector/src/database/native/postgresql.rs +++ b/query-engine/connectors/sql-query-connector/src/database/native/postgresql.rs @@ -1,4 +1,4 @@ -use super::super::connection::SqlConnection; +use crate::database::{catch, connection::SqlConnection}; use crate::{FromSource, SqlError}; use async_trait::async_trait; use connector_interface::{ @@ -67,7 +67,7 @@ impl FromSource for PostgreSql { #[async_trait] impl Connector for PostgreSql { async fn get_connection<'a>(&'a self) -> connector_interface::Result> { - super::super::catch(self.connection_info.clone(), async move { + catch(self.connection_info.clone(), async move { let conn = self.pool.check_out().await.map_err(SqlError::from)?; let conn = SqlConnection::new(conn, &self.connection_info, self.features); Ok(Box::new(conn) as Box) diff --git a/query-engine/connectors/sql-query-connector/src/database/native/sqlite.rs b/query-engine/connectors/sql-query-connector/src/database/native/sqlite.rs index b1250b18b2be..e38bccb861f4 100644 --- a/query-engine/connectors/sql-query-connector/src/database/native/sqlite.rs +++ b/query-engine/connectors/sql-query-connector/src/database/native/sqlite.rs @@ -1,4 +1,4 @@ -use super::super::connection::SqlConnection; +use crate::database::{catch, connection::SqlConnection}; use crate::{FromSource, SqlError}; use async_trait::async_trait; use connector_interface::{ @@ -80,7 +80,7 @@ fn invalid_file_path_error(file_path: &str, connection_info: &ConnectionInfo) -> #[async_trait] impl Connector for Sqlite { async fn get_connection<'a>(&'a self) -> connector::Result> { - super::super::catch(self.connection_info().clone(), async move { + catch(self.connection_info().clone(), async move { let conn = self.pool.check_out().await.map_err(SqlError::from)?; let conn = SqlConnection::new(conn, self.connection_info(), self.features); diff --git a/schema-engine/connectors/sql-schema-connector/src/sql_schema_calculator/sql_schema_calculator_flavour/mssql.rs b/schema-engine/connectors/sql-schema-connector/src/sql_schema_calculator/sql_schema_calculator_flavour/mssql.rs index 18a0b8e94b3c..51a8f5ef54be 100644 --- a/schema-engine/connectors/sql-schema-connector/src/sql_schema_calculator/sql_schema_calculator_flavour/mssql.rs +++ b/schema-engine/connectors/sql-schema-connector/src/sql_schema_calculator/sql_schema_calculator_flavour/mssql.rs @@ -23,7 +23,7 @@ impl SqlSchemaCalculatorFlavour for MssqlFlavour { } } - fn push_connector_data(&self, context: &mut super::super::Context<'_>) { + fn push_connector_data(&self, context: &mut crate::sql_schema_calculator::Context<'_>) { let mut data = MssqlSchemaExt::default(); for model in context.datamodel.db.walk_models() { diff --git a/schema-engine/connectors/sql-schema-connector/src/sql_schema_calculator/sql_schema_calculator_flavour/postgres.rs b/schema-engine/connectors/sql-schema-connector/src/sql_schema_calculator/sql_schema_calculator_flavour/postgres.rs index 40577d68a35d..656fe432a970 100644 --- a/schema-engine/connectors/sql-schema-connector/src/sql_schema_calculator/sql_schema_calculator_flavour/postgres.rs +++ b/schema-engine/connectors/sql-schema-connector/src/sql_schema_calculator/sql_schema_calculator_flavour/postgres.rs @@ -37,7 +37,7 @@ impl SqlSchemaCalculatorFlavour for PostgresFlavour { } } - fn push_connector_data(&self, context: &mut super::super::Context<'_>) { + fn push_connector_data(&self, context: &mut crate::sql_schema_calculator::Context<'_>) { let mut postgres_ext = PostgresSchemaExt::default(); let db = &context.datamodel.db; From 5126a75cbe880cfe799b4706a8272304cb0473b2 Mon Sep 17 00:00:00 2001 From: jkomyno Date: Wed, 15 Nov 2023 20:53:54 +0100 Subject: [PATCH 30/34] chore: add comments around "query_core::executor::task" --- query-engine/core/Cargo.toml | 3 + query-engine/core/src/executor/mod.rs | 94 +++++++++++++++------------ 2 files changed, 56 insertions(+), 41 deletions(-) diff --git a/query-engine/core/Cargo.toml b/query-engine/core/Cargo.toml index 7ccf1a293411..9e0f03517cb5 100644 --- a/query-engine/core/Cargo.toml +++ b/query-engine/core/Cargo.toml @@ -37,6 +37,9 @@ schema = { path = "../schema" } lru = "0.7.7" enumflags2 = "0.7" +pin-project = "1" +wasm-bindgen-futures = "0.4" + [target.'cfg(target_arch = "wasm32")'.dependencies] pin-project = "1" wasm-bindgen-futures = "0.4" diff --git a/query-engine/core/src/executor/mod.rs b/query-engine/core/src/executor/mod.rs index 5ff9830013d6..43df839e9635 100644 --- a/query-engine/core/src/executor/mod.rs +++ b/query-engine/core/src/executor/mod.rs @@ -12,7 +12,6 @@ mod pipeline; mod request_context; pub use self::{execute_operation::*, interpreting_executor::InterpretingExecutor}; -use futures::Future; pub(crate) use request_context::*; @@ -133,52 +132,65 @@ pub fn get_current_dispatcher() -> Dispatch { tracing::dispatcher::get_default(|current| current.clone()) } -#[cfg(not(target_arch = "wasm32"))] +// The `task` module provides a unified interface for spawning asynchronous tasks, regardless of the target platform. pub(crate) mod task { - use super::*; - - pub type JoinHandle = tokio::task::JoinHandle; - - pub fn spawn(future: T) -> JoinHandle - where - T: Future + Send + 'static, - T::Output: Send + 'static, - { - tokio::spawn(future) + pub use arch::{spawn, JoinHandle}; + use futures::Future; + + // On native targets, `tokio::spawn` spawns a new asynchronous task. + #[cfg(not(target_arch = "wasm32"))] + mod arch { + use super::*; + + pub type JoinHandle = tokio::task::JoinHandle; + + pub fn spawn(future: T) -> JoinHandle + where + T: Future + Send + 'static, + T::Output: Send + 'static, + { + tokio::spawn(future) + } } -} - -#[cfg(target_arch = "wasm32")] -pub(crate) mod task { - use super::*; - - #[pin_project::pin_project] - pub struct JoinHandle(#[pin] tokio::sync::oneshot::Receiver); - impl Future for JoinHandle { - type Output = Result; - fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll { - let this = self.project(); - this.0.poll(cx) + // On Wasm targets, `wasm_bindgen_futures::spawn_local` spawns a new asynchronous task. + #[cfg(target_arch = "wasm32")] + mod arch { + use super::*; + use tokio::sync::oneshot::{self}; + + // Wasm-compatible alternative to `tokio::task::JoinHandle`. + // `pin_project` enables pin-projection and a `Pin`-compatible implementation of the `Future` trait. + #[pin_project::pin_project] + pub struct JoinHandle(#[pin] oneshot::Receiver); + + impl Future for JoinHandle { + type Output = Result; + + fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll { + // the `self.project()` method is provided by the `pin_project` macro + let receiver: std::pin::Pin<&mut oneshot::Receiver> = self.project().0; + receiver.poll(cx) + } } - } - impl JoinHandle { - pub fn abort(&mut self) { - // abort is noop for WASM builds + impl JoinHandle { + pub fn abort(&mut self) { + // abort is noop on Wasm targets + } } - } - pub fn spawn(future: T) -> JoinHandle - where - T: Future + Send + 'static, - T::Output: Send + 'static, - { - let (tx, rx) = tokio::sync::oneshot::channel(); - wasm_bindgen_futures::spawn_local(async move { - let result = future.await; - tx.send(result).ok(); - }); - JoinHandle(rx) + pub fn spawn(future: T) -> JoinHandle + where + T: Future + Send + 'static, + T::Output: Send + 'static, + { + let (sender, receiver) = oneshot::channel(); + wasm_bindgen_futures::spawn_local(async move { + let result = future.await; + sender.send(result).ok(); + }); + JoinHandle(receiver) + } } } From de39d9e6a46ba8e7612d39989e009d635512bb3c Mon Sep 17 00:00:00 2001 From: jkomyno Date: Wed, 15 Nov 2023 21:02:12 +0100 Subject: [PATCH 31/34] chore: add "request-handlers" to "query-engine-wasm" --- Cargo.lock | 1 + query-engine/query-engine-wasm/Cargo.toml | 3 ++- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/Cargo.lock b/Cargo.lock index 50df863820fd..8544b8ae8134 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3827,6 +3827,7 @@ dependencies = [ "quaint", "query-connector", "query-core", + "request-handlers", "serde", "serde-wasm-bindgen", "serde_json", diff --git a/query-engine/query-engine-wasm/Cargo.toml b/query-engine/query-engine-wasm/Cargo.toml index c8bc6e2b5178..07757fde5d0a 100644 --- a/query-engine/query-engine-wasm/Cargo.toml +++ b/query-engine/query-engine-wasm/Cargo.toml @@ -15,9 +15,10 @@ user-facing-errors = { path = "../../libs/user-facing-errors" } psl.workspace = true prisma-models = { path = "../prisma-models" } quaint = { path = "../../quaint" } -connector = { path = "../connectors/query-connector", package = "query-connector" } +query-connector = { path = "../connectors/query-connector" } sql-query-connector = { path = "../connectors/sql-query-connector" } query-core = { path = "../core" } +request-handlers = { path = "../request-handlers", default-features = false, features = ["sql", "driver-adapters"] } thiserror = "1" connection-string.workspace = true From 2339b313786c7401ef46b55c174b53ecbff13332 Mon Sep 17 00:00:00 2001 From: jkomyno Date: Fri, 17 Nov 2023 16:19:23 +0100 Subject: [PATCH 32/34] chore: move "task" module into its own file --- query-engine/core/src/executor/mod.rs | 64 +------------------------- query-engine/core/src/executor/task.rs | 59 ++++++++++++++++++++++++ 2 files changed, 60 insertions(+), 63 deletions(-) create mode 100644 query-engine/core/src/executor/task.rs diff --git a/query-engine/core/src/executor/mod.rs b/query-engine/core/src/executor/mod.rs index 43df839e9635..ba2784d3c71a 100644 --- a/query-engine/core/src/executor/mod.rs +++ b/query-engine/core/src/executor/mod.rs @@ -10,6 +10,7 @@ mod execute_operation; mod interpreting_executor; mod pipeline; mod request_context; +pub(crate) mod task; pub use self::{execute_operation::*, interpreting_executor::InterpretingExecutor}; @@ -131,66 +132,3 @@ pub trait TransactionManager { pub fn get_current_dispatcher() -> Dispatch { tracing::dispatcher::get_default(|current| current.clone()) } - -// The `task` module provides a unified interface for spawning asynchronous tasks, regardless of the target platform. -pub(crate) mod task { - pub use arch::{spawn, JoinHandle}; - use futures::Future; - - // On native targets, `tokio::spawn` spawns a new asynchronous task. - #[cfg(not(target_arch = "wasm32"))] - mod arch { - use super::*; - - pub type JoinHandle = tokio::task::JoinHandle; - - pub fn spawn(future: T) -> JoinHandle - where - T: Future + Send + 'static, - T::Output: Send + 'static, - { - tokio::spawn(future) - } - } - - // On Wasm targets, `wasm_bindgen_futures::spawn_local` spawns a new asynchronous task. - #[cfg(target_arch = "wasm32")] - mod arch { - use super::*; - use tokio::sync::oneshot::{self}; - - // Wasm-compatible alternative to `tokio::task::JoinHandle`. - // `pin_project` enables pin-projection and a `Pin`-compatible implementation of the `Future` trait. - #[pin_project::pin_project] - pub struct JoinHandle(#[pin] oneshot::Receiver); - - impl Future for JoinHandle { - type Output = Result; - - fn poll(self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll { - // the `self.project()` method is provided by the `pin_project` macro - let receiver: std::pin::Pin<&mut oneshot::Receiver> = self.project().0; - receiver.poll(cx) - } - } - - impl JoinHandle { - pub fn abort(&mut self) { - // abort is noop on Wasm targets - } - } - - pub fn spawn(future: T) -> JoinHandle - where - T: Future + Send + 'static, - T::Output: Send + 'static, - { - let (sender, receiver) = oneshot::channel(); - wasm_bindgen_futures::spawn_local(async move { - let result = future.await; - sender.send(result).ok(); - }); - JoinHandle(receiver) - } - } -} diff --git a/query-engine/core/src/executor/task.rs b/query-engine/core/src/executor/task.rs new file mode 100644 index 000000000000..8d1c39bbcd06 --- /dev/null +++ b/query-engine/core/src/executor/task.rs @@ -0,0 +1,59 @@ +//! This module provides a unified interface for spawning asynchronous tasks, regardless of the target platform. + +pub use arch::{spawn, JoinHandle}; +use futures::Future; + +// On native targets, `tokio::spawn` spawns a new asynchronous task. +#[cfg(not(target_arch = "wasm32"))] +mod arch { + use super::*; + + pub type JoinHandle = tokio::task::JoinHandle; + + pub fn spawn(future: T) -> JoinHandle + where + T: Future + Send + 'static, + T::Output: Send + 'static, + { + tokio::spawn(future) + } +} + +// On Wasm targets, `wasm_bindgen_futures::spawn_local` spawns a new asynchronous task. +#[cfg(target_arch = "wasm32")] +mod arch { + use super::*; + use tokio::sync::oneshot::{self}; + + // Wasm-compatible alternative to `tokio::task::JoinHandle`. + // `pin_project` enables pin-projection and a `Pin`-compatible implementation of the `Future` trait. + pub struct JoinHandle(oneshot::Receiver); + + impl Future for JoinHandle { + type Output = Result; + + fn poll(mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll { + // the `self.project()` method is provided by the `pin_project` macro + core::pin::Pin::new(&mut self.0).poll(cx) + } + } + + impl JoinHandle { + pub fn abort(&mut self) { + // abort is noop on Wasm targets + } + } + + pub fn spawn(future: T) -> JoinHandle + where + T: Future + Send + 'static, + T::Output: Send + 'static, + { + let (sender, receiver) = oneshot::channel(); + wasm_bindgen_futures::spawn_local(async move { + let result = future.await; + sender.send(result).ok(); + }); + JoinHandle(receiver) + } +} From 96cd8ca800176eaeb9cf803c2b7e340973eaddc0 Mon Sep 17 00:00:00 2001 From: jkomyno Date: Fri, 17 Nov 2023 16:30:14 +0100 Subject: [PATCH 33/34] fix(driver-adapters): ci for "request-handlers" --- query-engine/connectors/sql-query-connector/Cargo.toml | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/query-engine/connectors/sql-query-connector/Cargo.toml b/query-engine/connectors/sql-query-connector/Cargo.toml index fa9c32ef88e1..9ed0b4070056 100644 --- a/query-engine/connectors/sql-query-connector/Cargo.toml +++ b/query-engine/connectors/sql-query-connector/Cargo.toml @@ -26,9 +26,14 @@ tracing-futures = "0.2" uuid.workspace = true opentelemetry = { version = "0.17", features = ["tokio"] } tracing-opentelemetry = "0.17.3" -quaint = { path = "../../../quaint" } cuid = { git = "https://github.com/prisma/cuid-rust", branch = "wasm32-support" } +[target.'cfg(not(target_arch = "wasm32"))'.dependencies] +quaint.workspace = true + +[target.'cfg(target_arch = "wasm32")'.dependencies] +quaint = { path = "../../../quaint" } + [dependencies.connector-interface] package = "query-connector" path = "../query-connector" From 3541054b2723c17a158e5e1990ca581461756f1b Mon Sep 17 00:00:00 2001 From: jkomyno Date: Fri, 17 Nov 2023 16:30:14 +0100 Subject: [PATCH 34/34] fix(driver-adapters): ci for "request-handlers" --- query-engine/connectors/sql-query-connector/Cargo.toml | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/query-engine/connectors/sql-query-connector/Cargo.toml b/query-engine/connectors/sql-query-connector/Cargo.toml index fa9c32ef88e1..9ed0b4070056 100644 --- a/query-engine/connectors/sql-query-connector/Cargo.toml +++ b/query-engine/connectors/sql-query-connector/Cargo.toml @@ -26,9 +26,14 @@ tracing-futures = "0.2" uuid.workspace = true opentelemetry = { version = "0.17", features = ["tokio"] } tracing-opentelemetry = "0.17.3" -quaint = { path = "../../../quaint" } cuid = { git = "https://github.com/prisma/cuid-rust", branch = "wasm32-support" } +[target.'cfg(not(target_arch = "wasm32"))'.dependencies] +quaint.workspace = true + +[target.'cfg(target_arch = "wasm32")'.dependencies] +quaint = { path = "../../../quaint" } + [dependencies.connector-interface] package = "query-connector" path = "../query-connector"