Skip to content

Commit

Permalink
fix: do not cast enum to text within min/max (#4453)
Browse files Browse the repository at this point in the history
  • Loading branch information
Weakky authored Nov 27, 2023
1 parent 31deaad commit 0a59e97
Show file tree
Hide file tree
Showing 6 changed files with 141 additions and 15 deletions.
20 changes: 16 additions & 4 deletions quaint/src/visitor.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1004,6 +1004,20 @@ pub trait Visitor<'a> {
Ok(())
}

fn visit_min(&mut self, min: Minimum<'a>) -> Result {
self.write("MIN")?;
self.surround_with("(", ")", |ref mut s| s.visit_column(min.column))?;

Ok(())
}

fn visit_max(&mut self, max: Maximum<'a>) -> Result {
self.write("MAX")?;
self.surround_with("(", ")", |ref mut s| s.visit_column(max.column))?;

Ok(())
}

fn visit_function(&mut self, fun: Function<'a>) -> Result {
match fun.typ_ {
FunctionType::RowNumber(fun_rownum) => {
Expand Down Expand Up @@ -1046,12 +1060,10 @@ pub trait Visitor<'a> {
self.surround_with("(", ")", |ref mut s| s.visit_expression(*upper.expression))?;
}
FunctionType::Minimum(min) => {
self.write("MIN")?;
self.surround_with("(", ")", |ref mut s| s.visit_column(min.column))?;
self.visit_min(min)?;
}
FunctionType::Maximum(max) => {
self.write("MAX")?;
self.surround_with("(", ")", |ref mut s| s.visit_column(max.column))?;
self.visit_max(max)?;
}
FunctionType::Coalesce(coalesce) => {
self.write("COALESCE")?;
Expand Down
39 changes: 39 additions & 0 deletions quaint/src/visitor/postgres.rs
Original file line number Diff line number Diff line change
Expand Up @@ -627,6 +627,34 @@ impl<'a> Visitor<'a> for Postgres<'a> {

Ok(())
}

fn visit_min(&mut self, min: Minimum<'a>) -> visitor::Result {
// If the inner column is a selected enum, then we cast the result of MIN(enum)::text instead of casting the inner enum column, which changes the behavior of MIN.
let should_cast = min.column.is_enum && min.column.is_selected;

self.write("MIN")?;
self.surround_with("(", ")", |ref mut s| s.visit_column(min.column.set_is_selected(false)))?;

if should_cast {
self.write("::text")?;
}

Ok(())
}

fn visit_max(&mut self, max: Maximum<'a>) -> visitor::Result {
// If the inner column is a selected enum, then we cast the result of MAX(enum)::text instead of casting the inner enum column, which changes the behavior of MAX.
let should_cast = max.column.is_enum && max.column.is_selected;

self.write("MAX")?;
self.surround_with("(", ")", |ref mut s| s.visit_column(max.column.set_is_selected(false)))?;

if should_cast {
self.write("::text")?;
}

Ok(())
}
}

#[cfg(test)]
Expand Down Expand Up @@ -1157,4 +1185,15 @@ mod tests {

assert_eq!("SELECT \"User\".*, \"Toto\".* FROM \"User\" LEFT JOIN \"Post\" AS \"p\" ON \"p\".\"userId\" = \"User\".\"id\", \"Toto\"", sql);
}

#[test]
fn enum_cast_text_in_min_max_should_be_outside() {
let enum_col = Column::from("enum").set_is_enum(true).set_is_selected(true);
let q = Select::from_table("User")
.value(min(enum_col.clone()))
.value(max(enum_col));
let (sql, _) = Postgres::build(q).unwrap();

assert_eq!("SELECT MIN(\"enum\")::text, MAX(\"enum\")::text FROM \"User\"", sql);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -515,6 +515,54 @@ mod aggregation_group_by {
Ok(())
}

fn schema_21789() -> String {
let schema = indoc! {
r#"model Test {
#id(id, Int, @id)
group Int
color Color
}
enum Color {
blue
red
green
}
"#
};

schema.to_owned()
}

// regression test for https://github.com/prisma/prisma/issues/21789
#[connector_test(schema(schema_21789), only(Postgres, CockroachDB))]
async fn regression_21789(runner: Runner) -> TestResult<()> {
run_query!(
&runner,
r#"mutation { createOneTest(data: { id: 1, group: 1, color: "red" }) { id } }"#
);
run_query!(
&runner,
r#"mutation { createOneTest(data: { id: 2, group: 2, color: "green" }) { id } }"#
);
run_query!(
&runner,
r#"mutation { createOneTest(data: { id: 3, group: 1, color: "blue" }) { id } }"#
);

insta::assert_snapshot!(
run_query!(&runner, r#"{ aggregateTest { _max { color } _min { color } } }"#),
@r###"{"data":{"aggregateTest":{"_max":{"color":"green"},"_min":{"color":"blue"}}}}"###
);

insta::assert_snapshot!(
run_query!(&runner, r#"{ groupByTest(by: [group], orderBy: { group: asc }) { group _max { color } _min { color } } }"#),
@r###"{"data":{"groupByTest":[{"group":1,"_max":{"color":"red"},"_min":{"color":"blue"}},{"group":2,"_max":{"color":"green"},"_min":{"color":"green"}}]}}"###
);

Ok(())
}

/// Error cases
#[connector_test]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,13 @@ pub(crate) async fn get_single_record(
aggr_selections: &[RelAggregationSelection],
ctx: &Context<'_>,
) -> crate::Result<Option<SingleRecord>> {
let query = read::get_records(model, selected_fields.as_columns(ctx), aggr_selections, filter, ctx);
let query = read::get_records(
model,
selected_fields.as_columns(ctx).mark_all_selected(),
aggr_selections,
filter,
ctx,
);

let mut field_names: Vec<_> = selected_fields.db_names().collect();
let mut aggr_field_names: Vec<_> = aggr_selections.iter().map(|aggr_sel| aggr_sel.db_alias()).collect();
Expand Down Expand Up @@ -104,7 +110,13 @@ pub(crate) async fn get_many_records(
let mut futures = FuturesUnordered::new();

for args in batches.into_iter() {
let query = read::get_records(model, selected_fields.as_columns(ctx), aggr_selections, args, ctx);
let query = read::get_records(
model,
selected_fields.as_columns(ctx).mark_all_selected(),
aggr_selections,
args,
ctx,
);

futures.push(conn.filter(query.into(), meta.as_slice(), ctx));
}
Expand All @@ -122,7 +134,7 @@ pub(crate) async fn get_many_records(
_ => {
let query = read::get_records(
model,
selected_fields.as_columns(ctx),
selected_fields.as_columns(ctx).mark_all_selected(),
aggr_selections,
query_arguments,
ctx,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,15 @@ pub struct ColumnIterator {
inner: Box<dyn Iterator<Item = Column<'static>> + 'static>,
}

impl ColumnIterator {
/// Sets all columns as selected. This is a hack that we use to help the Postgres SQL visitor cast enum columns to text to avoid some driver roundtrips otherwise needed to resolve enum types.
pub fn mark_all_selected(self) -> Self {
ColumnIterator {
inner: Box::new(self.inner.map(|c| c.set_is_selected(true))),
}
}
}

impl Iterator for ColumnIterator {
type Item = Column<'static>;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -124,9 +124,7 @@ where
T: SelectDefinition,
{
let (select, additional_selection_set) = query.into_select(model, aggr_selections, ctx);
let select = columns
.map(|c| c.set_is_selected(true))
.fold(select, |acc, col| acc.column(col));
let select = columns.fold(select, |acc, col| acc.column(col));

let select = select.append_trace(&Span::current()).add_trace_id(ctx.trace_id);

Expand Down Expand Up @@ -176,7 +174,11 @@ pub(crate) fn aggregate(
.append_trace(&Span::current())
.add_trace_id(ctx.trace_id),
|select, next_op| match next_op {
AggregationSelection::Field(field) => select.column(Column::from(field.db_name().to_owned())),
AggregationSelection::Field(field) => select.column(
Column::from(field.db_name().to_owned())
.set_is_enum(field.type_identifier().is_enum())
.set_is_selected(true),
),

AggregationSelection::Count { all, fields } => {
let select = fields.iter().fold(select, |select, next_field| {
Expand All @@ -199,11 +201,15 @@ pub(crate) fn aggregate(
}),

AggregationSelection::Min(fields) => fields.iter().fold(select, |select, next_field| {
select.value(min(Column::from(next_field.db_name().to_owned())))
select.value(min(Column::from(next_field.db_name().to_owned())
.set_is_enum(next_field.type_identifier().is_enum())
.set_is_selected(true)))
}),

AggregationSelection::Max(fields) => fields.iter().fold(select, |select, next_field| {
select.value(max(Column::from(next_field.db_name().to_owned())))
select.value(max(Column::from(next_field.db_name().to_owned())
.set_is_enum(next_field.type_identifier().is_enum())
.set_is_selected(true)))
}),
},
)
Expand Down Expand Up @@ -243,11 +249,11 @@ pub(crate) fn group_by_aggregate(
}),

AggregationSelection::Min(fields) => fields.iter().fold(select, |select, next_field| {
select.value(min(next_field.as_column(ctx)))
select.value(min(next_field.as_column(ctx).set_is_selected(true)))
}),

AggregationSelection::Max(fields) => fields.iter().fold(select, |select, next_field| {
select.value(max(next_field.as_column(ctx)))
select.value(max(next_field.as_column(ctx).set_is_selected(true)))
}),
});

Expand Down

0 comments on commit 0a59e97

Please sign in to comment.