Skip to content

Commit

Permalink
fix: handle native types for joined queries (#4546)
Browse files Browse the repository at this point in the history
  • Loading branch information
Weakky committed Dec 18, 2023
1 parent cc4d187 commit 0ca5ccb
Show file tree
Hide file tree
Showing 15 changed files with 518 additions and 17 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions psl/builtin-connectors/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,5 @@ indoc.workspace = true
lsp-types = "0.91.1"
once_cell = "1.3"
regex = "1"
chrono = { version = "0.4.6", default-features = false }

24 changes: 24 additions & 0 deletions psl/builtin-connectors/src/cockroach_datamodel_connector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ mod validations;

pub use native_types::CockroachType;

use chrono::*;
use enumflags2::BitFlags;
use lsp_types::{CompletionItem, CompletionItemKind, CompletionList};
use psl_core::{
Expand Down Expand Up @@ -307,6 +308,29 @@ impl Connector for CockroachDatamodelConnector {
fn flavour(&self) -> Flavour {
Flavour::Cockroach
}

fn parse_json_datetime(
&self,
str: &str,
nt: Option<NativeTypeInstance>,
) -> chrono::ParseResult<chrono::DateTime<FixedOffset>> {
let native_type: Option<&CockroachType> = nt.as_ref().map(|nt| nt.downcast_ref());

match native_type {
Some(ct) => match ct {
CockroachType::Timestamptz(_) => crate::utils::parse_timestamptz(str),
CockroachType::Timestamp(_) => crate::utils::parse_timestamp(str),
CockroachType::Date => crate::utils::parse_date(str),
CockroachType::Time(_) => crate::utils::parse_time(str),
CockroachType::Timetz(_) => crate::utils::parse_timetz(str),
_ => unreachable!(),
},
None => self.parse_json_datetime(
str,
Some(self.default_native_type_for_scalar_type(&ScalarType::DateTime)),
),
}
}
}

/// An `@default(sequence())` function.
Expand Down
1 change: 1 addition & 0 deletions psl/builtin-connectors/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ mod mysql_datamodel_connector;
mod native_type_definition;
mod postgres_datamodel_connector;
mod sqlite_datamodel_connector;
mod utils;

use psl_core::{datamodel_connector::Connector, ConnectorRegistry};

Expand Down
24 changes: 24 additions & 0 deletions psl/builtin-connectors/src/postgres_datamodel_connector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ mod validations;

pub use native_types::PostgresType;

use chrono::*;
use enumflags2::BitFlags;
use lsp_types::{CompletionItem, CompletionItemKind, CompletionList, InsertTextFormat};
use psl_core::{
Expand Down Expand Up @@ -567,6 +568,29 @@ impl Connector for PostgresDatamodelConnector {
fn flavour(&self) -> Flavour {
Flavour::Postgres
}

fn parse_json_datetime(
&self,
str: &str,
nt: Option<NativeTypeInstance>,
) -> chrono::ParseResult<chrono::DateTime<FixedOffset>> {
let native_type: Option<&PostgresType> = nt.as_ref().map(|nt| nt.downcast_ref());

match native_type {
Some(pt) => match pt {
Timestamptz(_) => crate::utils::parse_timestamptz(str),
Timestamp(_) => crate::utils::parse_timestamp(str),
Date => crate::utils::parse_date(str),
Time(_) => crate::utils::parse_time(str),
Timetz(_) => crate::utils::parse_timetz(str),
_ => unreachable!(),
},
None => self.parse_json_datetime(
str,
Some(self.default_native_type_for_scalar_type(&ScalarType::DateTime)),
),
}
}
}

fn allowed_index_operator_classes(algo: IndexAlgorithm, field: walkers::ScalarFieldWalker<'_>) -> Vec<OperatorClass> {
Expand Down
37 changes: 37 additions & 0 deletions psl/builtin-connectors/src/utils.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
use chrono::*;

pub(crate) fn parse_date(str: &str) -> Result<DateTime<FixedOffset>, chrono::ParseError> {
chrono::NaiveDate::parse_from_str(str, "%Y-%m-%d")
.map(|date| DateTime::<Utc>::from_utc(date.and_hms_opt(0, 0, 0).unwrap(), Utc))
.map(DateTime::<FixedOffset>::from)
}

pub(crate) fn parse_timestamptz(str: &str) -> Result<DateTime<FixedOffset>, chrono::ParseError> {
DateTime::parse_from_rfc3339(str)
}

pub(crate) fn parse_timestamp(str: &str) -> Result<DateTime<FixedOffset>, chrono::ParseError> {
NaiveDateTime::parse_from_str(str, "%Y-%m-%dT%H:%M:%S%.f")
.map(|dt| DateTime::from_utc(dt, Utc))
.or_else(|_| DateTime::parse_from_rfc3339(str).map(DateTime::<Utc>::from))
.map(DateTime::<FixedOffset>::from)
}

pub(crate) fn parse_time(str: &str) -> Result<DateTime<FixedOffset>, chrono::ParseError> {
chrono::NaiveTime::parse_from_str(str, "%H:%M:%S%.f")
.map(|time| {
let base_date = chrono::NaiveDate::from_ymd_opt(1970, 1, 1).unwrap();

DateTime::<Utc>::from_utc(base_date.and_time(time), Utc)
})
.map(DateTime::<FixedOffset>::from)
}

pub(crate) fn parse_timetz(str: &str) -> Result<DateTime<FixedOffset>, chrono::ParseError> {
// We currently don't support time with timezone.
// We strip the timezone information and parse it as a time.
// This is inline with what Quaint does already.
let time_without_tz = str.split('+').next().unwrap();

parse_time(time_without_tz)
}
9 changes: 9 additions & 0 deletions psl/psl-core/src/datamodel_connector.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ pub use self::{
};

use crate::{configuration::DatasourceConnectorData, Configuration, Datasource, PreviewFeature};
use chrono::{DateTime, FixedOffset};
use diagnostics::{DatamodelError, Diagnostics, NativeTypeErrorFactory, Span};
use enumflags2::BitFlags;
use lsp_types::CompletionList;
Expand Down Expand Up @@ -359,6 +360,14 @@ pub trait Connector: Send + Sync {
) -> DatasourceConnectorData {
Default::default()
}

fn parse_json_datetime(
&self,
_str: &str,
_nt: Option<NativeTypeInstance>,
) -> chrono::ParseResult<DateTime<FixedOffset>> {
unreachable!("This method is only implemented on connectors with lateral join support.")
}
}

#[derive(Copy, Clone, Debug, PartialEq)]
Expand Down
9 changes: 8 additions & 1 deletion quaint/src/ast/column.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use super::Aliasable;
use super::{values::NativeColumnType, Aliasable};
use crate::{
ast::{Expression, ExpressionKind, Table},
Value,
Expand Down Expand Up @@ -32,6 +32,8 @@ pub struct Column<'a> {
pub(crate) alias: Option<Cow<'a, str>>,
pub(crate) default: Option<DefaultValue<'a>>,
pub(crate) type_family: Option<TypeFamily>,
/// The underlying native type of the column.
pub(crate) native_type: Option<NativeColumnType<'a>>,
/// Whether the column is an enum.
pub(crate) is_enum: bool,
/// Whether the column is a (scalar) list.
Expand Down Expand Up @@ -130,6 +132,11 @@ impl<'a> Column<'a> {
.map(|d| d == &DefaultValue::Generated)
.unwrap_or(false)
}

pub fn native_column_type<T: Into<NativeColumnType<'a>>>(mut self, native_type: Option<T>) -> Column<'a> {
self.native_type = native_type.map(|nt| nt.into());
self
}
}

impl<'a> From<Column<'a>> for Expression<'a> {
Expand Down
34 changes: 33 additions & 1 deletion quaint/src/visitor/postgres.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,23 @@ pub struct Postgres<'a> {
parameters: Vec<Value<'a>>,
}

impl<'a> Postgres<'a> {
fn visit_json_build_obj_expr(&mut self, expr: Expression<'a>) -> crate::Result<()> {
match expr.kind() {
ExpressionKind::Column(col) => match (col.type_family.as_ref(), col.native_type.as_deref()) {
(Some(TypeFamily::Decimal(_)), Some("MONEY")) => {
self.visit_expression(expr)?;
self.write("::numeric")?;

Ok(())
}
_ => self.visit_expression(expr),
},
_ => self.visit_expression(expr),
}
}
}

impl<'a> Visitor<'a> for Postgres<'a> {
const C_BACKTICK_OPEN: &'static str = "\"";
const C_BACKTICK_CLOSE: &'static str = "\"";
Expand Down Expand Up @@ -534,7 +551,7 @@ impl<'a> Visitor<'a> for Postgres<'a> {
while let Some((name, expr)) = chunk.next() {
s.visit_raw_value(Value::text(name))?;
s.write(", ")?;
s.visit_expression(expr)?;
s.visit_json_build_obj_expr(expr)?;
if chunk.peek().is_some() {
s.write(", ")?;
}
Expand Down Expand Up @@ -1290,6 +1307,21 @@ mod tests {
);
}

#[test]
fn money() {
let build_json = json_build_object(vec![(
"money".into(),
Column::from("money")
.native_column_type(Some("money"))
.type_family(TypeFamily::Decimal(None))
.into(),
)]);
let query = Select::default().value(build_json);
let (sql, _) = Postgres::build(query).unwrap();

assert_eq!(sql, "SELECT JSONB_BUILD_OBJECT('money', \"money\"::numeric)");
}

fn build_json_object(num_fields: u32) -> JsonBuildObject<'static> {
let fields = (1..=num_fields)
.map(|i| (format!("f{i}").into(), Expression::from(i as i64)))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,6 @@ mod enum_type;
mod float;
mod int;
mod json;
mod native;
mod string;
mod through_relation;
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
mod postgres;
Loading

0 comments on commit 0ca5ccb

Please sign in to comment.