From 5849e9fd3b1e2efbccda61c63faab7843bc62ede Mon Sep 17 00:00:00 2001 From: Caleb Schoepp Date: Mon, 15 Jul 2024 15:04:48 -0600 Subject: [PATCH] factors: Add more tests to factor-outbound-pg and refactor it to be generic across pg impl Signed-off-by: Caleb Schoepp --- Cargo.lock | 1 - crates/factor-outbound-pg/Cargo.toml | 1 - crates/factor-outbound-pg/src/client.rs | 284 ++++++++++++++++++ crates/factor-outbound-pg/src/host.rs | 284 +----------------- crates/factor-outbound-pg/src/lib.rs | 32 +- .../factor-outbound-pg/tests/factor_test.rs | 123 +++++++- 6 files changed, 434 insertions(+), 291 deletions(-) create mode 100644 crates/factor-outbound-pg/src/client.rs diff --git a/Cargo.lock b/Cargo.lock index 891e4802c..f2cae7b54 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -7615,7 +7615,6 @@ dependencies = [ "spin-core", "spin-factor-outbound-networking", "spin-factor-variables", - "spin-factor-wasi", "spin-factors", "spin-factors-test", "spin-world", diff --git a/crates/factor-outbound-pg/Cargo.toml b/crates/factor-outbound-pg/Cargo.toml index ca18e93a1..cd8681a4a 100644 --- a/crates/factor-outbound-pg/Cargo.toml +++ b/crates/factor-outbound-pg/Cargo.toml @@ -19,7 +19,6 @@ tracing = { workspace = true } [dev-dependencies] spin-factor-variables = { path = "../factor-variables" } -spin-factor-wasi = { path = "../factor-wasi" } spin-factors-test = { path = "../factors-test" } tokio = { version = "1", features = ["macros", "rt"] } diff --git a/crates/factor-outbound-pg/src/client.rs b/crates/factor-outbound-pg/src/client.rs new file mode 100644 index 000000000..06a93a631 --- /dev/null +++ b/crates/factor-outbound-pg/src/client.rs @@ -0,0 +1,284 @@ +use anyhow::{anyhow, Result}; +use native_tls::TlsConnector; +use postgres_native_tls::MakeTlsConnector; +use spin_world::async_trait; +use spin_world::v2::postgres::{self as v2}; +use spin_world::v2::rdbms_types::{Column, DbDataType, DbValue, ParameterValue, RowSet}; +use tokio_postgres::types::Type; +use tokio_postgres::{config::SslMode, types::ToSql, Row}; +use tokio_postgres::{Client as TokioClient, NoTls, Socket}; + +#[async_trait] +pub trait Client { + async fn build_client(address: &str) -> Result + where + Self: Sized; + + async fn execute( + &self, + statement: String, + params: Vec, + ) -> Result; + + async fn query( + &self, + statement: String, + params: Vec, + ) -> Result; +} + +#[async_trait] +impl Client for TokioClient { + async fn build_client(address: &str) -> Result + where + Self: Sized, + { + let config = address.parse::()?; + + tracing::debug!("Build new connection: {}", address); + + if config.get_ssl_mode() == SslMode::Disable { + let (client, connection) = config.connect(NoTls).await?; + spawn_connection(connection); + Ok(client) + } else { + let builder = TlsConnector::builder(); + let connector = MakeTlsConnector::new(builder.build()?); + let (client, connection) = config.connect(connector).await?; + spawn_connection(connection); + Ok(client) + } + } + + async fn execute( + &self, + statement: String, + params: Vec, + ) -> Result { + let params: Vec<&(dyn ToSql + Sync)> = params + .iter() + .map(to_sql_parameter) + .collect::>>() + .map_err(|e| v2::Error::ValueConversionFailed(format!("{:?}", e)))?; + + self.execute(&statement, params.as_slice()) + .await + .map_err(|e| v2::Error::QueryFailed(format!("{:?}", e))) + } + + async fn query( + &self, + statement: String, + params: Vec, + ) -> Result { + let params: Vec<&(dyn ToSql + Sync)> = params + .iter() + .map(to_sql_parameter) + .collect::>>() + .map_err(|e| v2::Error::BadParameter(format!("{:?}", e)))?; + + let results = self + .query(&statement, params.as_slice()) + .await + .map_err(|e| v2::Error::QueryFailed(format!("{:?}", e)))?; + + if results.is_empty() { + return Ok(RowSet { + columns: vec![], + rows: vec![], + }); + } + + let columns = infer_columns(&results[0]); + let rows = results + .iter() + .map(convert_row) + .collect::, _>>() + .map_err(|e| v2::Error::QueryFailed(format!("{:?}", e)))?; + + Ok(RowSet { columns, rows }) + } +} + +fn spawn_connection(connection: tokio_postgres::Connection) +where + T: tokio_postgres::tls::TlsStream + std::marker::Unpin + std::marker::Send + 'static, +{ + tokio::spawn(async move { + if let Err(e) = connection.await { + tracing::error!("Postgres connection error: {}", e); + } + }); +} + +fn to_sql_parameter(value: &ParameterValue) -> Result<&(dyn ToSql + Sync)> { + match value { + ParameterValue::Boolean(v) => Ok(v), + ParameterValue::Int32(v) => Ok(v), + ParameterValue::Int64(v) => Ok(v), + ParameterValue::Int8(v) => Ok(v), + ParameterValue::Int16(v) => Ok(v), + ParameterValue::Floating32(v) => Ok(v), + ParameterValue::Floating64(v) => Ok(v), + ParameterValue::Uint8(_) + | ParameterValue::Uint16(_) + | ParameterValue::Uint32(_) + | ParameterValue::Uint64(_) => Err(anyhow!("Postgres does not support unsigned integers")), + ParameterValue::Str(v) => Ok(v), + ParameterValue::Binary(v) => Ok(v), + ParameterValue::DbNull => Ok(&PgNull), + } +} + +fn infer_columns(row: &Row) -> Vec { + let mut result = Vec::with_capacity(row.len()); + for index in 0..row.len() { + result.push(infer_column(row, index)); + } + result +} + +fn infer_column(row: &Row, index: usize) -> Column { + let column = &row.columns()[index]; + let name = column.name().to_owned(); + let data_type = convert_data_type(column.type_()); + Column { name, data_type } +} + +fn convert_data_type(pg_type: &Type) -> DbDataType { + match *pg_type { + Type::BOOL => DbDataType::Boolean, + Type::BYTEA => DbDataType::Binary, + Type::FLOAT4 => DbDataType::Floating32, + Type::FLOAT8 => DbDataType::Floating64, + Type::INT2 => DbDataType::Int16, + Type::INT4 => DbDataType::Int32, + Type::INT8 => DbDataType::Int64, + Type::TEXT | Type::VARCHAR | Type::BPCHAR => DbDataType::Str, + _ => { + tracing::debug!("Couldn't convert Postgres type {} to WIT", pg_type.name(),); + DbDataType::Other + } + } +} + +fn convert_row(row: &Row) -> Result, tokio_postgres::Error> { + let mut result = Vec::with_capacity(row.len()); + for index in 0..row.len() { + result.push(convert_entry(row, index)?); + } + Ok(result) +} + +fn convert_entry(row: &Row, index: usize) -> Result { + let column = &row.columns()[index]; + let value = match column.type_() { + &Type::BOOL => { + let value: Option = row.try_get(index)?; + match value { + Some(v) => DbValue::Boolean(v), + None => DbValue::DbNull, + } + } + &Type::BYTEA => { + let value: Option> = row.try_get(index)?; + match value { + Some(v) => DbValue::Binary(v), + None => DbValue::DbNull, + } + } + &Type::FLOAT4 => { + let value: Option = row.try_get(index)?; + match value { + Some(v) => DbValue::Floating32(v), + None => DbValue::DbNull, + } + } + &Type::FLOAT8 => { + let value: Option = row.try_get(index)?; + match value { + Some(v) => DbValue::Floating64(v), + None => DbValue::DbNull, + } + } + &Type::INT2 => { + let value: Option = row.try_get(index)?; + match value { + Some(v) => DbValue::Int16(v), + None => DbValue::DbNull, + } + } + &Type::INT4 => { + let value: Option = row.try_get(index)?; + match value { + Some(v) => DbValue::Int32(v), + None => DbValue::DbNull, + } + } + &Type::INT8 => { + let value: Option = row.try_get(index)?; + match value { + Some(v) => DbValue::Int64(v), + None => DbValue::DbNull, + } + } + &Type::TEXT | &Type::VARCHAR | &Type::BPCHAR => { + let value: Option = row.try_get(index)?; + match value { + Some(v) => DbValue::Str(v), + None => DbValue::DbNull, + } + } + t => { + tracing::debug!( + "Couldn't convert Postgres type {} in column {}", + t.name(), + column.name() + ); + DbValue::Unsupported + } + }; + Ok(value) +} + +/// Although the Postgres crate converts Rust Option::None to Postgres NULL, +/// it enforces the type of the Option as it does so. (For example, trying to +/// pass an Option::::None to a VARCHAR column fails conversion.) As we +/// do not know expected column types, we instead use a "neutral" custom type +/// which allows conversion to any type but always tells the Postgres crate to +/// treat it as a SQL NULL. +struct PgNull; + +impl ToSql for PgNull { + fn to_sql( + &self, + _ty: &Type, + _out: &mut tokio_postgres::types::private::BytesMut, + ) -> Result> + where + Self: Sized, + { + Ok(tokio_postgres::types::IsNull::Yes) + } + + fn accepts(_ty: &Type) -> bool + where + Self: Sized, + { + true + } + + fn to_sql_checked( + &self, + _ty: &Type, + _out: &mut tokio_postgres::types::private::BytesMut, + ) -> Result> { + Ok(tokio_postgres::types::IsNull::Yes) + } +} + +impl std::fmt::Debug for PgNull { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("NULL").finish() + } +} diff --git a/crates/factor-outbound-pg/src/host.rs b/crates/factor-outbound-pg/src/host.rs index 63bc9ac91..1f7be3570 100644 --- a/crates/factor-outbound-pg/src/host.rs +++ b/crates/factor-outbound-pg/src/host.rs @@ -1,27 +1,21 @@ -use anyhow::{anyhow, Result}; -use native_tls::TlsConnector; -use postgres_native_tls::MakeTlsConnector; +use anyhow::Result; use spin_core::{async_trait, wasmtime::component::Resource}; use spin_world::v1::postgres as v1; use spin_world::v1::rdbms_types as v1_types; use spin_world::v2::postgres::{self as v2, Connection}; use spin_world::v2::rdbms_types; -use spin_world::v2::rdbms_types::{Column, DbDataType, DbValue, ParameterValue, RowSet}; -use tokio_postgres::{ - config::SslMode, - types::{ToSql, Type}, - Client, NoTls, Row, Socket, -}; +use spin_world::v2::rdbms_types::{ParameterValue, RowSet}; use tracing::instrument; use tracing::Level; +use crate::client::Client; use crate::InstanceState; -impl InstanceState { +impl InstanceState { async fn open_connection(&mut self, address: &str) -> Result, v2::Error> { self.connections .push( - build_client(address) + C::build_client(address) .await .map_err(|e| v2::Error::ConnectionFailed(format!("{e:?}")))?, ) @@ -29,7 +23,7 @@ impl InstanceState { .map(Resource::new_own) } - async fn get_client(&mut self, connection: Resource) -> Result<&Client, v2::Error> { + async fn get_client(&mut self, connection: Resource) -> Result<&C, v2::Error> { self.connections .get(connection.rep()) .ok_or_else(|| v2::Error::ConnectionFailed("no connection found".into())) @@ -52,7 +46,6 @@ impl InstanceState { .or_else(|| if ports.len() == 1 { ports.get(1) } else { None }); let port_str = port.map(|p| format!(":{}", p)).unwrap_or_default(); let url = format!("{address}{port_str}"); - // TODO: Should I be unwrapping this? if !self.allowed_hosts.check_url(&url, "postgres").await? { return Ok(false); } @@ -66,10 +59,10 @@ impl InstanceState { } #[async_trait] -impl v2::Host for InstanceState {} +impl v2::Host for InstanceState {} #[async_trait] -impl v2::HostConnection for InstanceState { +impl v2::HostConnection for InstanceState { #[instrument(name = "spin_outbound_pg.open_connection", skip(self), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql"))] async fn open(&mut self, address: String) -> Result, v2::Error> { if !self @@ -91,20 +84,11 @@ impl v2::HostConnection for InstanceState { statement: String, params: Vec, ) -> Result { - let params: Vec<&(dyn ToSql + Sync)> = params - .iter() - .map(to_sql_parameter) - .collect::>>() - .map_err(|e| v2::Error::ValueConversionFailed(format!("{:?}", e)))?; - - let nrow = self + Ok(self .get_client(connection) .await? - .execute(&statement, params.as_slice()) - .await - .map_err(|e| v2::Error::QueryFailed(format!("{:?}", e)))?; - - Ok(nrow) + .execute(statement, params) + .await?) } #[instrument(name = "spin_outbound_pg.query", skip(self, connection), err(level = Level::INFO), fields(otel.kind = "client", db.system = "postgresql", otel.name = statement))] @@ -114,34 +98,11 @@ impl v2::HostConnection for InstanceState { statement: String, params: Vec, ) -> Result { - let params: Vec<&(dyn ToSql + Sync)> = params - .iter() - .map(to_sql_parameter) - .collect::>>() - .map_err(|e| v2::Error::BadParameter(format!("{:?}", e)))?; - - let results = self + Ok(self .get_client(connection) .await? - .query(&statement, params.as_slice()) - .await - .map_err(|e| v2::Error::QueryFailed(format!("{:?}", e)))?; - - if results.is_empty() { - return Ok(RowSet { - columns: vec![], - rows: vec![], - }); - } - - let columns = infer_columns(&results[0]); - let rows = results - .iter() - .map(convert_row) - .collect::, _>>() - .map_err(|e| v2::Error::QueryFailed(format!("{:?}", e)))?; - - Ok(RowSet { columns, rows }) + .query(statement, params) + .await?) } fn drop(&mut self, connection: Resource) -> anyhow::Result<()> { @@ -150,225 +111,12 @@ impl v2::HostConnection for InstanceState { } } -impl rdbms_types::Host for InstanceState { +impl rdbms_types::Host for InstanceState { fn convert_error(&mut self, error: v2::Error) -> Result { Ok(error) } } -fn to_sql_parameter(value: &ParameterValue) -> anyhow::Result<&(dyn ToSql + Sync)> { - match value { - ParameterValue::Boolean(v) => Ok(v), - ParameterValue::Int32(v) => Ok(v), - ParameterValue::Int64(v) => Ok(v), - ParameterValue::Int8(v) => Ok(v), - ParameterValue::Int16(v) => Ok(v), - ParameterValue::Floating32(v) => Ok(v), - ParameterValue::Floating64(v) => Ok(v), - ParameterValue::Uint8(_) - | ParameterValue::Uint16(_) - | ParameterValue::Uint32(_) - | ParameterValue::Uint64(_) => Err(anyhow!("Postgres does not support unsigned integers")), - ParameterValue::Str(v) => Ok(v), - ParameterValue::Binary(v) => Ok(v), - ParameterValue::DbNull => Ok(&PgNull), - } -} - -fn infer_columns(row: &Row) -> Vec { - let mut result = Vec::with_capacity(row.len()); - for index in 0..row.len() { - result.push(infer_column(row, index)); - } - result -} - -fn infer_column(row: &Row, index: usize) -> Column { - let column = &row.columns()[index]; - let name = column.name().to_owned(); - let data_type = convert_data_type(column.type_()); - Column { name, data_type } -} - -fn convert_data_type(pg_type: &Type) -> DbDataType { - match *pg_type { - Type::BOOL => DbDataType::Boolean, - Type::BYTEA => DbDataType::Binary, - Type::FLOAT4 => DbDataType::Floating32, - Type::FLOAT8 => DbDataType::Floating64, - Type::INT2 => DbDataType::Int16, - Type::INT4 => DbDataType::Int32, - Type::INT8 => DbDataType::Int64, - Type::TEXT | Type::VARCHAR | Type::BPCHAR => DbDataType::Str, - _ => { - tracing::debug!("Couldn't convert Postgres type {} to WIT", pg_type.name(),); - DbDataType::Other - } - } -} - -fn convert_row(row: &Row) -> Result, tokio_postgres::Error> { - let mut result = Vec::with_capacity(row.len()); - for index in 0..row.len() { - result.push(convert_entry(row, index)?); - } - Ok(result) -} - -fn convert_entry(row: &Row, index: usize) -> Result { - let column = &row.columns()[index]; - let value = match column.type_() { - &Type::BOOL => { - let value: Option = row.try_get(index)?; - match value { - Some(v) => DbValue::Boolean(v), - None => DbValue::DbNull, - } - } - &Type::BYTEA => { - let value: Option> = row.try_get(index)?; - match value { - Some(v) => DbValue::Binary(v), - None => DbValue::DbNull, - } - } - &Type::FLOAT4 => { - let value: Option = row.try_get(index)?; - match value { - Some(v) => DbValue::Floating32(v), - None => DbValue::DbNull, - } - } - &Type::FLOAT8 => { - let value: Option = row.try_get(index)?; - match value { - Some(v) => DbValue::Floating64(v), - None => DbValue::DbNull, - } - } - &Type::INT2 => { - let value: Option = row.try_get(index)?; - match value { - Some(v) => DbValue::Int16(v), - None => DbValue::DbNull, - } - } - &Type::INT4 => { - let value: Option = row.try_get(index)?; - match value { - Some(v) => DbValue::Int32(v), - None => DbValue::DbNull, - } - } - &Type::INT8 => { - let value: Option = row.try_get(index)?; - match value { - Some(v) => DbValue::Int64(v), - None => DbValue::DbNull, - } - } - &Type::TEXT | &Type::VARCHAR | &Type::BPCHAR => { - let value: Option = row.try_get(index)?; - match value { - Some(v) => DbValue::Str(v), - None => DbValue::DbNull, - } - } - t => { - tracing::debug!( - "Couldn't convert Postgres type {} in column {}", - t.name(), - column.name() - ); - DbValue::Unsupported - } - }; - Ok(value) -} - -async fn build_client(address: &str) -> anyhow::Result { - let config = address.parse::()?; - - tracing::debug!("Build new connection: {}", address); - - if config.get_ssl_mode() == SslMode::Disable { - connect(config).await - } else { - connect_tls(config).await - } -} - -async fn connect(config: tokio_postgres::Config) -> anyhow::Result { - let (client, connection) = config.connect(NoTls).await?; - - spawn(connection); - - Ok(client) -} - -async fn connect_tls(config: tokio_postgres::Config) -> anyhow::Result { - let builder = TlsConnector::builder(); - let connector = MakeTlsConnector::new(builder.build()?); - let (client, connection) = config.connect(connector).await?; - - spawn(connection); - - Ok(client) -} - -fn spawn(connection: tokio_postgres::Connection) -where - T: tokio_postgres::tls::TlsStream + std::marker::Unpin + std::marker::Send + 'static, -{ - tokio::spawn(async move { - if let Err(e) = connection.await { - tracing::error!("Postgres connection error: {}", e); - } - }); -} - -/// Although the Postgres crate converts Rust Option::None to Postgres NULL, -/// it enforces the type of the Option as it does so. (For example, trying to -/// pass an Option::::None to a VARCHAR column fails conversion.) As we -/// do not know expected column types, we instead use a "neutral" custom type -/// which allows conversion to any type but always tells the Postgres crate to -/// treat it as a SQL NULL. -struct PgNull; - -impl ToSql for PgNull { - fn to_sql( - &self, - _ty: &Type, - _out: &mut tokio_postgres::types::private::BytesMut, - ) -> Result> - where - Self: Sized, - { - Ok(tokio_postgres::types::IsNull::Yes) - } - - fn accepts(_ty: &Type) -> bool - where - Self: Sized, - { - true - } - - fn to_sql_checked( - &self, - _ty: &Type, - _out: &mut tokio_postgres::types::private::BytesMut, - ) -> Result> { - Ok(tokio_postgres::types::IsNull::Yes) - } -} - -impl std::fmt::Debug for PgNull { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - f.debug_struct("NULL").finish() - } -} - /// Delegate a function call to the v2::HostConnection implementation macro_rules! delegate { ($self:ident.$name:ident($address:expr, $($arg:expr),*)) => {{ @@ -388,7 +136,7 @@ macro_rules! delegate { } #[async_trait] -impl v1::Host for InstanceState { +impl v1::Host for InstanceState { async fn execute( &mut self, address: String, diff --git a/crates/factor-outbound-pg/src/lib.rs b/crates/factor-outbound-pg/src/lib.rs index 143666932..484cc68c3 100644 --- a/crates/factor-outbound-pg/src/lib.rs +++ b/crates/factor-outbound-pg/src/lib.rs @@ -1,18 +1,22 @@ +pub mod client; mod host; +use client::Client; use spin_factor_outbound_networking::{OutboundAllowedHosts, OutboundNetworkingFactor}; use spin_factors::{ anyhow, ConfigureAppContext, Factor, InstanceBuilders, PrepareContext, RuntimeFactors, SelfInstanceBuilder, }; -use tokio_postgres::Client; +use tokio_postgres::Client as PgClient; -pub struct OutboundPgFactor; +pub struct OutboundPgFactor { + _phantom: std::marker::PhantomData, +} -impl Factor for OutboundPgFactor { +impl Factor for OutboundPgFactor { type RuntimeConfig = (); type AppState = (); - type InstanceBuilder = InstanceState; + type InstanceBuilder = InstanceState; fn init( &mut self, @@ -45,9 +49,23 @@ impl Factor for OutboundPgFactor { } } -pub struct InstanceState { +impl Default for OutboundPgFactor { + fn default() -> Self { + Self { + _phantom: Default::default(), + } + } +} + +impl OutboundPgFactor { + pub fn new() -> Self { + Self::default() + } +} + +pub struct InstanceState { allowed_hosts: OutboundAllowedHosts, - connections: table::Table, + connections: table::Table, } -impl SelfInstanceBuilder for InstanceState {} +impl SelfInstanceBuilder for InstanceState {} diff --git a/crates/factor-outbound-pg/tests/factor_test.rs b/crates/factor-outbound-pg/tests/factor_test.rs index 4f2f78852..07f47cc0c 100644 --- a/crates/factor-outbound-pg/tests/factor_test.rs +++ b/crates/factor-outbound-pg/tests/factor_test.rs @@ -1,39 +1,48 @@ -use anyhow::bail; +use anyhow::{bail, Result}; use spin_factor_outbound_networking::OutboundNetworkingFactor; +use spin_factor_outbound_pg::client::Client; use spin_factor_outbound_pg::OutboundPgFactor; use spin_factor_variables::{StaticVariables, VariablesFactor}; -use spin_factor_wasi::{DummyFilesMounter, WasiFactor}; use spin_factors::{anyhow, RuntimeFactors}; use spin_factors_test::{toml, TestEnvironment}; +use spin_world::async_trait; use spin_world::v2::postgres::HostConnection; +use spin_world::v2::postgres::{self as v2}; use spin_world::v2::rdbms_types::Error as PgError; +use spin_world::v2::rdbms_types::{ParameterValue, RowSet}; #[derive(RuntimeFactors)] struct TestFactors { - wasi: WasiFactor, variables: VariablesFactor, networking: OutboundNetworkingFactor, - pg: OutboundPgFactor, + pg: OutboundPgFactor, +} + +fn factors() -> Result { + let mut f = TestFactors { + variables: VariablesFactor::default(), + networking: OutboundNetworkingFactor, + pg: OutboundPgFactor::::new(), + }; + f.variables.add_provider_type(StaticVariables)?; + Ok(f) } fn test_env() -> TestEnvironment { TestEnvironment::default_manifest_extend(toml! { [component.test-component] source = "does-not-exist.wasm" + allowed_outbound_hosts = ["postgres://*:*"] }) } #[tokio::test] async fn disallowed_host_fails() -> anyhow::Result<()> { - let mut factors = TestFactors { - wasi: WasiFactor::new(DummyFilesMounter), - variables: VariablesFactor::default(), - networking: OutboundNetworkingFactor, - pg: OutboundPgFactor, - }; - factors.variables.add_provider_type(StaticVariables)?; - - let env = test_env(); + let factors = factors()?; + let env = TestEnvironment::default_manifest_extend(toml! { + [component.test-component] + source = "does-not-exist.wasm" + }); let mut state = env.build_instance_state(factors).await?; let res = state @@ -43,8 +52,94 @@ async fn disallowed_host_fails() -> anyhow::Result<()> { let Err(err) = res else { bail!("expected Err, got Ok"); }; - println!("err: {:?}", err); assert!(matches!(err, PgError::ConnectionFailed(_))); Ok(()) } + +#[tokio::test] +async fn allowed_host_succeeds() -> anyhow::Result<()> { + let factors = factors()?; + let env = test_env(); + let mut state = env.build_instance_state(factors).await?; + + let res = state + .pg + .open("postgres://localhost:5432/test".to_string()) + .await; + let Ok(_) = res else { + bail!("expected Ok, got Err"); + }; + + Ok(()) +} + +#[tokio::test] +async fn exercise_execute() -> anyhow::Result<()> { + let factors = factors()?; + let env = test_env(); + let mut state = env.build_instance_state(factors).await?; + + let connection = state + .pg + .open("postgres://localhost:5432/test".to_string()) + .await?; + + state + .pg + .execute(connection, "SELECT * FROM test".to_string(), vec![]) + .await?; + + Ok(()) +} + +#[tokio::test] +async fn exercise_query() -> anyhow::Result<()> { + let factors = factors()?; + let env = test_env(); + let mut state = env.build_instance_state(factors).await?; + + let connection = state + .pg + .open("postgres://localhost:5432/test".to_string()) + .await?; + + state + .pg + .query(connection, "SELECT * FROM test".to_string(), vec![]) + .await?; + + Ok(()) +} + +// TODO: We can expand this mock to track calls and simulate return values +pub struct MockClient {} + +#[async_trait] +impl Client for MockClient { + async fn build_client(_address: &str) -> anyhow::Result + where + Self: Sized, + { + Ok(MockClient {}) + } + + async fn execute( + &self, + _statement: String, + _params: Vec, + ) -> Result { + Ok(0) + } + + async fn query( + &self, + _statement: String, + _params: Vec, + ) -> Result { + Ok(RowSet { + columns: vec![], + rows: vec![], + }) + } +}