diff --git a/quaint/src/visitor.rs b/quaint/src/visitor.rs index 8424bc7fbb2b..c205b49dd279 100644 --- a/quaint/src/visitor.rs +++ b/quaint/src/visitor.rs @@ -1004,6 +1004,20 @@ pub trait Visitor<'a> { Ok(()) } + fn visit_min(&mut self, min: Minimum<'a>) -> Result { + self.write("MIN")?; + self.surround_with("(", ")", |ref mut s| s.visit_column(min.column))?; + + Ok(()) + } + + fn visit_max(&mut self, max: Maximum<'a>) -> Result { + self.write("MAX")?; + self.surround_with("(", ")", |ref mut s| s.visit_column(max.column))?; + + Ok(()) + } + fn visit_function(&mut self, fun: Function<'a>) -> Result { match fun.typ_ { FunctionType::RowNumber(fun_rownum) => { @@ -1046,12 +1060,10 @@ pub trait Visitor<'a> { self.surround_with("(", ")", |ref mut s| s.visit_expression(*upper.expression))?; } FunctionType::Minimum(min) => { - self.write("MIN")?; - self.surround_with("(", ")", |ref mut s| s.visit_column(min.column))?; + self.visit_min(min)?; } FunctionType::Maximum(max) => { - self.write("MAX")?; - self.surround_with("(", ")", |ref mut s| s.visit_column(max.column))?; + self.visit_max(max)?; } FunctionType::Coalesce(coalesce) => { self.write("COALESCE")?; diff --git a/quaint/src/visitor/postgres.rs b/quaint/src/visitor/postgres.rs index fda8a6132037..740860e53979 100644 --- a/quaint/src/visitor/postgres.rs +++ b/quaint/src/visitor/postgres.rs @@ -633,6 +633,34 @@ impl<'a> Visitor<'a> for Postgres<'a> { Ok(()) } + + fn visit_min(&mut self, min: Minimum<'a>) -> visitor::Result { + // If the inner column is a selected enum, then we cast the result of MIN(enum)::text instead of casting the inner enum column, which changes the behavior of MIN. + let should_cast = min.column.is_enum && min.column.is_selected; + + self.write("MIN")?; + self.surround_with("(", ")", |ref mut s| s.visit_column(min.column.set_is_selected(false)))?; + + if should_cast { + self.write("::text")?; + } + + Ok(()) + } + + fn visit_max(&mut self, max: Maximum<'a>) -> visitor::Result { + // If the inner column is a selected enum, then we cast the result of MAX(enum)::text instead of casting the inner enum column, which changes the behavior of MAX. + let should_cast = max.column.is_enum && max.column.is_selected; + + self.write("MAX")?; + self.surround_with("(", ")", |ref mut s| s.visit_column(max.column.set_is_selected(false)))?; + + if should_cast { + self.write("::text")?; + } + + Ok(()) + } } #[cfg(test)] diff --git a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/aggregation/group_by.rs b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/aggregation/group_by.rs index 5abbbfe4bdf4..9fbc5da304a7 100644 --- a/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/aggregation/group_by.rs +++ b/query-engine/connector-test-kit-rs/query-engine-tests/tests/queries/aggregation/group_by.rs @@ -515,6 +515,53 @@ mod aggregation_group_by { Ok(()) } + fn schema_21789() -> String { + let schema = indoc! { + r#"model Test { + #id(id, Int, @id) + color Color + } + + enum Color { + blue + red + green + } + "# + }; + + schema.to_owned() + } + + // regression test for https://github.com/prisma/prisma/issues/21789 + #[connector_test(schema(schema_21789), capabilities(Enums))] + async fn regression_21789(runner: Runner) -> TestResult<()> { + run_query!( + &runner, + r#"mutation { createOneTest(data: { id: 1, color: "red" }) { id } }"# + ); + run_query!( + &runner, + r#"mutation { createOneTest(data: { id: 2, color: "green" }) { id } }"# + ); + run_query!( + &runner, + r#"mutation { createOneTest(data: { id: 3, color: "blue" }) { id } }"# + ); + + insta::assert_snapshot!( + run_query!(&runner, r#"{ aggregateTest { _max { color } _min { color } } }"#), + @r###"{"data":{"aggregateTest":{"_max":{"color":"green"},"_min":{"color":"blue"}}}}"### + ); + + insta::assert_snapshot!( + run_query!(&runner, r#"{ groupByTest(by: [color]) { color _max { color } _min { color } } }"#), + @r###"{"data":{"groupByTest":[{"color":"green","_max":{"color":"green"},"_min":{"color":"green"}},{"color":"blue","_max":{"color":"blue"},"_min":{"color":"blue"}},{"color":"red","_max":{"color":"red"},"_min":{"color":"red"}}]}"### + ); + + Ok(()) + } + /// Error cases #[connector_test] diff --git a/query-engine/connectors/sql-query-connector/src/database/operations/read.rs b/query-engine/connectors/sql-query-connector/src/database/operations/read.rs index 470628de1132..8b3bf9031019 100644 --- a/query-engine/connectors/sql-query-connector/src/database/operations/read.rs +++ b/query-engine/connectors/sql-query-connector/src/database/operations/read.rs @@ -18,7 +18,13 @@ pub(crate) async fn get_single_record( aggr_selections: &[RelAggregationSelection], ctx: &Context<'_>, ) -> crate::Result> { - let query = read::get_records(model, selected_fields.as_columns(ctx), aggr_selections, filter, ctx); + let query = read::get_records( + model, + selected_fields.as_columns(ctx).mark_all_selected(), + aggr_selections, + filter, + ctx, + ); let mut field_names: Vec<_> = selected_fields.db_names().collect(); let mut aggr_field_names: Vec<_> = aggr_selections.iter().map(|aggr_sel| aggr_sel.db_alias()).collect(); @@ -104,7 +110,13 @@ pub(crate) async fn get_many_records( let mut futures = FuturesUnordered::new(); for args in batches.into_iter() { - let query = read::get_records(model, selected_fields.as_columns(ctx), aggr_selections, args, ctx); + let query = read::get_records( + model, + selected_fields.as_columns(ctx).mark_all_selected(), + aggr_selections, + args, + ctx, + ); futures.push(conn.filter(query.into(), meta.as_slice(), ctx)); } @@ -122,7 +134,7 @@ pub(crate) async fn get_many_records( _ => { let query = read::get_records( model, - selected_fields.as_columns(ctx), + selected_fields.as_columns(ctx).mark_all_selected(), aggr_selections, query_arguments, ctx, diff --git a/query-engine/connectors/sql-query-connector/src/model_extensions/column.rs b/query-engine/connectors/sql-query-connector/src/model_extensions/column.rs index 445bada9c45c..d3139082975b 100644 --- a/query-engine/connectors/sql-query-connector/src/model_extensions/column.rs +++ b/query-engine/connectors/sql-query-connector/src/model_extensions/column.rs @@ -7,6 +7,15 @@ pub struct ColumnIterator { inner: Box> + 'static>, } +impl ColumnIterator { + /// Sets all columns as selected. This is a hack that we use to help the Postgres SQL visitor cast enum columns to text to avoid some driver roundtrips otherwise needed to resolve enum types. + pub fn mark_all_selected(self) -> Self { + ColumnIterator { + inner: Box::new(self.inner.map(|c| c.set_is_selected(true))), + } + } +} + impl Iterator for ColumnIterator { type Item = Column<'static>; diff --git a/query-engine/connectors/sql-query-connector/src/query_builder/read.rs b/query-engine/connectors/sql-query-connector/src/query_builder/read.rs index a5385f1dd56a..3f73bb51b2d5 100644 --- a/query-engine/connectors/sql-query-connector/src/query_builder/read.rs +++ b/query-engine/connectors/sql-query-connector/src/query_builder/read.rs @@ -124,9 +124,7 @@ where T: SelectDefinition, { let (select, additional_selection_set) = query.into_select(model, aggr_selections, ctx); - let select = columns - .map(|c| c.set_is_selected(true)) - .fold(select, |acc, col| acc.column(col)); + let select = columns.fold(select, |acc, col| acc.column(col)); let select = select.append_trace(&Span::current()).add_trace_id(ctx.trace_id); @@ -176,7 +174,11 @@ pub(crate) fn aggregate( .append_trace(&Span::current()) .add_trace_id(ctx.trace_id), |select, next_op| match next_op { - AggregationSelection::Field(field) => select.column(Column::from(field.db_name().to_owned())), + AggregationSelection::Field(field) => select.column( + Column::from(field.db_name().to_owned()) + .set_is_enum(field.type_identifier().is_enum()) + .set_is_selected(true), + ), AggregationSelection::Count { all, fields } => { let select = fields.iter().fold(select, |select, next_field| { @@ -199,11 +201,15 @@ pub(crate) fn aggregate( }), AggregationSelection::Min(fields) => fields.iter().fold(select, |select, next_field| { - select.value(min(Column::from(next_field.db_name().to_owned()))) + select.value(min(Column::from(next_field.db_name().to_owned()) + .set_is_enum(next_field.type_identifier().is_enum()) + .set_is_selected(true))) }), AggregationSelection::Max(fields) => fields.iter().fold(select, |select, next_field| { - select.value(max(Column::from(next_field.db_name().to_owned()))) + select.value(max(Column::from(next_field.db_name().to_owned()) + .set_is_enum(next_field.type_identifier().is_enum()) + .set_is_selected(true))) }), }, ) @@ -243,11 +249,11 @@ pub(crate) fn group_by_aggregate( }), AggregationSelection::Min(fields) => fields.iter().fold(select, |select, next_field| { - select.value(min(next_field.as_column(ctx))) + select.value(min(next_field.as_column(ctx).set_is_selected(true))) }), AggregationSelection::Max(fields) => fields.iter().fold(select, |select, next_field| { - select.value(max(next_field.as_column(ctx))) + select.value(max(next_field.as_column(ctx).set_is_selected(true))) }), });