Skip to content

Commit

Permalink
Merge pull request #2651 from calebschoepp/abstract-pg-client
Browse files Browse the repository at this point in the history
Abstract pg client
  • Loading branch information
lann authored Jul 25, 2024
2 parents 2d709df + 5849e9f commit 774ac14
Show file tree
Hide file tree
Showing 6 changed files with 434 additions and 291 deletions.
1 change: 0 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion crates/factor-outbound-pg/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ tracing = { workspace = true }

[dev-dependencies]
spin-factor-variables = { path = "../factor-variables" }
spin-factor-wasi = { path = "../factor-wasi" }
spin-factors-test = { path = "../factors-test" }
tokio = { version = "1", features = ["macros", "rt"] }

Expand Down
284 changes: 284 additions & 0 deletions crates/factor-outbound-pg/src/client.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,284 @@
use anyhow::{anyhow, Result};
use native_tls::TlsConnector;
use postgres_native_tls::MakeTlsConnector;
use spin_world::async_trait;
use spin_world::v2::postgres::{self as v2};
use spin_world::v2::rdbms_types::{Column, DbDataType, DbValue, ParameterValue, RowSet};
use tokio_postgres::types::Type;
use tokio_postgres::{config::SslMode, types::ToSql, Row};
use tokio_postgres::{Client as TokioClient, NoTls, Socket};

#[async_trait]
pub trait Client {
async fn build_client(address: &str) -> Result<Self>
where
Self: Sized;

async fn execute(
&self,
statement: String,
params: Vec<ParameterValue>,
) -> Result<u64, v2::Error>;

async fn query(
&self,
statement: String,
params: Vec<ParameterValue>,
) -> Result<RowSet, v2::Error>;
}

#[async_trait]
impl Client for TokioClient {
async fn build_client(address: &str) -> Result<Self>
where
Self: Sized,
{
let config = address.parse::<tokio_postgres::Config>()?;

tracing::debug!("Build new connection: {}", address);

if config.get_ssl_mode() == SslMode::Disable {
let (client, connection) = config.connect(NoTls).await?;
spawn_connection(connection);
Ok(client)
} else {
let builder = TlsConnector::builder();
let connector = MakeTlsConnector::new(builder.build()?);
let (client, connection) = config.connect(connector).await?;
spawn_connection(connection);
Ok(client)
}
}

async fn execute(
&self,
statement: String,
params: Vec<ParameterValue>,
) -> Result<u64, v2::Error> {
let params: Vec<&(dyn ToSql + Sync)> = params
.iter()
.map(to_sql_parameter)
.collect::<Result<Vec<_>>>()
.map_err(|e| v2::Error::ValueConversionFailed(format!("{:?}", e)))?;

self.execute(&statement, params.as_slice())
.await
.map_err(|e| v2::Error::QueryFailed(format!("{:?}", e)))
}

async fn query(
&self,
statement: String,
params: Vec<ParameterValue>,
) -> Result<RowSet, v2::Error> {
let params: Vec<&(dyn ToSql + Sync)> = params
.iter()
.map(to_sql_parameter)
.collect::<Result<Vec<_>>>()
.map_err(|e| v2::Error::BadParameter(format!("{:?}", e)))?;

let results = self
.query(&statement, params.as_slice())
.await
.map_err(|e| v2::Error::QueryFailed(format!("{:?}", e)))?;

if results.is_empty() {
return Ok(RowSet {
columns: vec![],
rows: vec![],
});
}

let columns = infer_columns(&results[0]);
let rows = results
.iter()
.map(convert_row)
.collect::<Result<Vec<_>, _>>()
.map_err(|e| v2::Error::QueryFailed(format!("{:?}", e)))?;

Ok(RowSet { columns, rows })
}
}

fn spawn_connection<T>(connection: tokio_postgres::Connection<Socket, T>)
where
T: tokio_postgres::tls::TlsStream + std::marker::Unpin + std::marker::Send + 'static,
{
tokio::spawn(async move {
if let Err(e) = connection.await {
tracing::error!("Postgres connection error: {}", e);
}
});
}

fn to_sql_parameter(value: &ParameterValue) -> Result<&(dyn ToSql + Sync)> {
match value {
ParameterValue::Boolean(v) => Ok(v),
ParameterValue::Int32(v) => Ok(v),
ParameterValue::Int64(v) => Ok(v),
ParameterValue::Int8(v) => Ok(v),
ParameterValue::Int16(v) => Ok(v),
ParameterValue::Floating32(v) => Ok(v),
ParameterValue::Floating64(v) => Ok(v),
ParameterValue::Uint8(_)
| ParameterValue::Uint16(_)
| ParameterValue::Uint32(_)
| ParameterValue::Uint64(_) => Err(anyhow!("Postgres does not support unsigned integers")),
ParameterValue::Str(v) => Ok(v),
ParameterValue::Binary(v) => Ok(v),
ParameterValue::DbNull => Ok(&PgNull),
}
}

fn infer_columns(row: &Row) -> Vec<Column> {
let mut result = Vec::with_capacity(row.len());
for index in 0..row.len() {
result.push(infer_column(row, index));
}
result
}

fn infer_column(row: &Row, index: usize) -> Column {
let column = &row.columns()[index];
let name = column.name().to_owned();
let data_type = convert_data_type(column.type_());
Column { name, data_type }
}

fn convert_data_type(pg_type: &Type) -> DbDataType {
match *pg_type {
Type::BOOL => DbDataType::Boolean,
Type::BYTEA => DbDataType::Binary,
Type::FLOAT4 => DbDataType::Floating32,
Type::FLOAT8 => DbDataType::Floating64,
Type::INT2 => DbDataType::Int16,
Type::INT4 => DbDataType::Int32,
Type::INT8 => DbDataType::Int64,
Type::TEXT | Type::VARCHAR | Type::BPCHAR => DbDataType::Str,
_ => {
tracing::debug!("Couldn't convert Postgres type {} to WIT", pg_type.name(),);
DbDataType::Other
}
}
}

fn convert_row(row: &Row) -> Result<Vec<DbValue>, tokio_postgres::Error> {
let mut result = Vec::with_capacity(row.len());
for index in 0..row.len() {
result.push(convert_entry(row, index)?);
}
Ok(result)
}

fn convert_entry(row: &Row, index: usize) -> Result<DbValue, tokio_postgres::Error> {
let column = &row.columns()[index];
let value = match column.type_() {
&Type::BOOL => {
let value: Option<bool> = row.try_get(index)?;
match value {
Some(v) => DbValue::Boolean(v),
None => DbValue::DbNull,
}
}
&Type::BYTEA => {
let value: Option<Vec<u8>> = row.try_get(index)?;
match value {
Some(v) => DbValue::Binary(v),
None => DbValue::DbNull,
}
}
&Type::FLOAT4 => {
let value: Option<f32> = row.try_get(index)?;
match value {
Some(v) => DbValue::Floating32(v),
None => DbValue::DbNull,
}
}
&Type::FLOAT8 => {
let value: Option<f64> = row.try_get(index)?;
match value {
Some(v) => DbValue::Floating64(v),
None => DbValue::DbNull,
}
}
&Type::INT2 => {
let value: Option<i16> = row.try_get(index)?;
match value {
Some(v) => DbValue::Int16(v),
None => DbValue::DbNull,
}
}
&Type::INT4 => {
let value: Option<i32> = row.try_get(index)?;
match value {
Some(v) => DbValue::Int32(v),
None => DbValue::DbNull,
}
}
&Type::INT8 => {
let value: Option<i64> = row.try_get(index)?;
match value {
Some(v) => DbValue::Int64(v),
None => DbValue::DbNull,
}
}
&Type::TEXT | &Type::VARCHAR | &Type::BPCHAR => {
let value: Option<String> = row.try_get(index)?;
match value {
Some(v) => DbValue::Str(v),
None => DbValue::DbNull,
}
}
t => {
tracing::debug!(
"Couldn't convert Postgres type {} in column {}",
t.name(),
column.name()
);
DbValue::Unsupported
}
};
Ok(value)
}

/// Although the Postgres crate converts Rust Option::None to Postgres NULL,
/// it enforces the type of the Option as it does so. (For example, trying to
/// pass an Option::<i32>::None to a VARCHAR column fails conversion.) As we
/// do not know expected column types, we instead use a "neutral" custom type
/// which allows conversion to any type but always tells the Postgres crate to
/// treat it as a SQL NULL.
struct PgNull;

impl ToSql for PgNull {
fn to_sql(
&self,
_ty: &Type,
_out: &mut tokio_postgres::types::private::BytesMut,
) -> Result<tokio_postgres::types::IsNull, Box<dyn std::error::Error + Sync + Send>>
where
Self: Sized,
{
Ok(tokio_postgres::types::IsNull::Yes)
}

fn accepts(_ty: &Type) -> bool
where
Self: Sized,
{
true
}

fn to_sql_checked(
&self,
_ty: &Type,
_out: &mut tokio_postgres::types::private::BytesMut,
) -> Result<tokio_postgres::types::IsNull, Box<dyn std::error::Error + Sync + Send>> {
Ok(tokio_postgres::types::IsNull::Yes)
}
}

impl std::fmt::Debug for PgNull {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("NULL").finish()
}
}
Loading

0 comments on commit 774ac14

Please sign in to comment.