From c5aefa6f4c510887c04d75efe3f1f15d015c1709 Mon Sep 17 00:00:00 2001 From: Miles Sabin Date: Thu, 5 Oct 2023 15:44:30 +0100 Subject: [PATCH] Fix handling of grouped/ungrouped effects --- modules/core/src/main/scala/compiler.scala | 2 +- modules/core/src/main/scala/query.scala | 7 ++- .../src/main/scala/queryinterpreter.scala | 34 ++++------- .../shared/src/main/scala/SqlMapping.scala | 30 +++++----- .../test/scala/SqlNestedEffectsMapping.scala | 58 +++++++++++++------ .../test/scala/SqlNestedEffectsSuite.scala | 34 ++++++++++- 6 files changed, 105 insertions(+), 60 deletions(-) diff --git a/modules/core/src/main/scala/compiler.scala b/modules/core/src/main/scala/compiler.scala index d6ef3d3b..b234f56d 100644 --- a/modules/core/src/main/scala/compiler.scala +++ b/modules/core/src/main/scala/compiler.scala @@ -902,7 +902,7 @@ object QueryCompiler { } yield emapping.get((ref, fieldName)) match { case Some(handler) => - Effect(handler, s.copy(child = ec)) + Select(fieldName, resultName, Effect(handler, s.copy(child = ec))) case None => s.copy(child = ec) } diff --git a/modules/core/src/main/scala/query.scala b/modules/core/src/main/scala/query.scala index 7c4762e2..2819bc7c 100644 --- a/modules/core/src/main/scala/query.scala +++ b/modules/core/src/main/scala/query.scala @@ -93,7 +93,7 @@ object Query { } trait EffectHandler[F[_]] { - def runEffects(queries: List[(Query, Cursor)]): F[Result[List[(Query, Cursor)]]] + def runEffects(queries: List[(Query, Cursor)]): F[Result[List[Cursor]]] } /** Evaluates an introspection query relative to `schema` */ @@ -261,6 +261,11 @@ object Query { loop(q) } + def childContext(c: Context, query: Query): Result[Context] = + rootName(query).toResultOrError(s"Query has the wrong shape").flatMap { + case (fieldName, resultName) => c.forField(fieldName, resultName) + } + /** * Renames the root of `target` to match `source` if possible. */ diff --git a/modules/core/src/main/scala/queryinterpreter.scala b/modules/core/src/main/scala/queryinterpreter.scala index 1f15599a..5831523c 100644 --- a/modules/core/src/main/scala/queryinterpreter.scala +++ b/modules/core/src/main/scala/queryinterpreter.scala @@ -228,6 +228,11 @@ class QueryInterpreter[F[_]](mapping: Mapping[F]) { else size(c0) } yield List((sel.resultName, ProtoJson.fromJson(Json.fromInt(count)))) + case (sel@Select(_, _, Effect(handler, cont)), _) => + for { + value <- ProtoJson.effect(mapping, handler.asInstanceOf[EffectHandler[F]], cont, cursor).success + } yield List((sel.resultName, value)) + case (sel@Select(fieldName, resultName, child), _) => val fieldTpe = tpe.field(fieldName).getOrElse(ScalarType.AttributeType) for { @@ -241,12 +246,6 @@ class QueryInterpreter[F[_]](mapping: Mapping[F]) { value <- runValue(c, tpe, cursor) } yield List((componentName, ProtoJson.select(value, componentName))) - case (e@Effect(_, cont), _) => - for { - effectName <- resultName(cont).toResultOrError("Effect continuation has unexpected shape") - value <- runValue(e, tpe, cursor) - } yield List((effectName, value)) - case (Group(siblings), _) => siblings.flatTraverse(query => runFields(query, tpe, cursor)) @@ -410,9 +409,6 @@ class QueryInterpreter[F[_]](mapping: Mapping[F]) { } yield ProtoJson.component(mapping, renamedCont, cursor) } - case (Effect(handler, cont), _) => - ProtoJson.effect(mapping, handler.asInstanceOf[EffectHandler[F]], cont, cursor).success - case (Unique(child), _) => cursor.preunique.flatMap(c => runList(child, tpe.nonNull, c, true, tpe.isNullable) @@ -629,19 +625,8 @@ object QueryInterpreter { case p: Json => p case d: DeferredJson => subst(d) case ProtoObject(fields) => - val newFields: Seq[(String, Json)] = - fields.flatMap { case (label, pvalue) => - val value = loop(pvalue) - if (isDeferred(pvalue) && value.isObject) { - value.asObject.get.toList match { - case List((_, value)) => List((label, value)) - case other => other - } - } - else List((label, value)) - } - Json.fromFields(newFields) - + val fields0 = fields.map { case (label, pvalue) => (label, loop(pvalue)) } + Json.fromFields(fields0) case ProtoArray(elems) => val elems0 = elems.map(loop) Json.fromValues(elems0) @@ -667,8 +652,9 @@ object QueryInterpreter { ResultT(mapping.combineAndRun(queries)) case Some(handler) => for { - conts <- ResultT(handler.runEffects(queries)) - res <- ResultT(combineResults(conts.map { case (query, cursor) => mapping.interpreter.runValue(query, cursor.tpe, cursor) }).pure[F]) + cs <- ResultT(handler.runEffects(queries)) + conts <- ResultT(queries.traverse { case (q, _) => Query.extractChild(q).toResultOrError("Continuation query has the wrong shape") }.pure[F]) + res <- ResultT(combineResults((conts, cs).parMapN { case (query, cursor) => mapping.interpreter.runValue(query, cursor.tpe, cursor) }).pure[F]) } yield res } next <- ResultT(completeAll[F](pnext)) diff --git a/modules/sql/shared/src/main/scala/SqlMapping.scala b/modules/sql/shared/src/main/scala/SqlMapping.scala index 51002ea5..fd2381de 100644 --- a/modules/sql/shared/src/main/scala/SqlMapping.scala +++ b/modules/sql/shared/src/main/scala/SqlMapping.scala @@ -2916,6 +2916,21 @@ trait SqlMappingLike[F[_]] extends CirceMappingLike[F] with SqlModule[F] { self case _: Count => Result.internalError("Count node must be a child of a Select node") + case Select(fieldName, _, Effect(_, _)) => + columnsForLeaf(context, fieldName).flatMap { + case Nil => EmptySqlQuery.success + case cols => + val constraintCols = if(exposeJoins) parentConstraints.lastOption.getOrElse(Nil).map(_._2) else Nil + val extraCols = keyColumnsForType(context) ++ constraintCols + for { + parentTable <- parentTableForType(context) + extraJoins <- parentConstraintsToSqlJoins(parentTable, parentConstraints) + } yield + SqlSelect(context, Nil, parentTable, (cols ++ extraCols).distinct, extraJoins, Nil, Nil, None, None, Nil, true, false) + } + + case Effect(_, _) => Result.internalError("Effect node must be a child of a Select node") + // Non-leaf non-Json element: compile subobject queries case s@Select(fieldName, resultName, child) => context.forField(fieldName, resultName).flatMap { fieldContext => @@ -2943,19 +2958,6 @@ trait SqlMappingLike[F[_]] extends CirceMappingLike[F] with SqlModule[F] { self } } - case Effect(_, Select(fieldName, _, _)) => - columnsForLeaf(context, fieldName).flatMap { - case Nil => EmptySqlQuery.success - case cols => - val constraintCols = if(exposeJoins) parentConstraints.lastOption.getOrElse(Nil).map(_._2) else Nil - val extraCols = keyColumnsForType(context) ++ constraintCols - for { - parentTable <- parentTableForType(context) - extraJoins <- parentConstraintsToSqlJoins(parentTable, parentConstraints) - } yield - SqlSelect(context, Nil, parentTable, (cols ++ extraCols).distinct, extraJoins, Nil, Nil, None, None, Nil, true, false) - } - case TypeCase(default, narrows) => def isSimple(query: Query): Boolean = { def loop(query: Query): Boolean = @@ -3116,7 +3118,7 @@ trait SqlMappingLike[F[_]] extends CirceMappingLike[F] with SqlModule[F] { self case TransformCursor(_, child) => loop(child, context, parentConstraints, exposeJoins) - case Empty | Query.Component(_, _, _) | Query.Effect(_, _) | (_: UntypedSelect) | (_: UntypedFragmentSpread) | (_: UntypedInlineFragment) | (_: Select) => + case Empty | Query.Component(_, _, _) | (_: UntypedSelect) | (_: UntypedFragmentSpread) | (_: UntypedInlineFragment) | (_: Select) => EmptySqlQuery.success } } diff --git a/modules/sql/shared/src/test/scala/SqlNestedEffectsMapping.scala b/modules/sql/shared/src/test/scala/SqlNestedEffectsMapping.scala index 0e45a292..545102a7 100644 --- a/modules/sql/shared/src/test/scala/SqlNestedEffectsMapping.scala +++ b/modules/sql/shared/src/test/scala/SqlNestedEffectsMapping.scala @@ -199,14 +199,13 @@ trait SqlNestedEffectsMapping[F[_]] extends SqlTestMapping[F] { ) object CurrencyQueryHandler extends EffectHandler[F] { - def runEffects(queries: List[(Query, Cursor)]): F[Result[List[(Query, Cursor)]]] = { + def runEffects(queries: List[(Query, Cursor)]): F[Result[List[Cursor]]] = { val countryCodes = queries.map(_._2.fieldAs[String]("code2").toOption) val distinctCodes = queries.flatMap(_._2.fieldAs[String]("code2").toList).distinct - val children = queries.flatMap { - case (Select(name, alias, child), parentCursor) => - parentCursor.context.forField(name, alias).toList.map(ctx => (ctx, child, parentCursor)) - case _ => Nil + val children0 = queries.traverse { + case (query, parentCursor) => + Query.childContext(parentCursor.context, query).map(ctx => (ctx, parentCursor)) } def unpackResults(res: Json): List[Json] = @@ -225,39 +224,60 @@ trait SqlNestedEffectsMapping[F[_]] extends SqlTestMapping[F] { case _ => Json.Null }).getOrElse(Nil) - for { - res <- currencyService.get(distinctCodes) + (for { + children <- ResultT(children0.pure[F]) + res <- ResultT(currencyService.get(distinctCodes).map(_.success)) } yield { unpackResults(res).zip(children).map { - case (res, (ctx, child, parentCursor)) => - val cursor = CirceCursor(ctx, res, Some(parentCursor), parentCursor.env) - (child, cursor) + case (res, (childContext, parentCursor)) => + val cursor = CirceCursor(childContext, res, Some(parentCursor), parentCursor.env) + cursor } - }.success + }).value.widen } } object CountryQueryHandler extends EffectHandler[F] { val toCode = Map("BR" -> "BRA", "GB" -> "GBR", "NL" -> "NLD") - def runEffects(queries: List[(Query, Cursor)]): F[Result[List[(Query, Cursor)]]] = { + def runEffects(queries: List[(Query, Cursor)]): F[Result[List[Cursor]]] = { + + def mkListCursor(cursor: Cursor, fieldName: String, resultName: Option[String]): Result[Cursor] = + for { + c <- cursor.field(fieldName, resultName) + lc <- c.preunique + } yield lc + + def extractCode(cursor: Cursor): Result[String] = + cursor.fieldAs[String]("code") + + def partitionCursor(codes: List[String], cursor: Cursor): Result[List[Cursor]] = { + for { + cursors <- cursor.asList + tagged <- cursors.traverse(c => (extractCode(c).map { code => (code, c) })) + } yield { + val m = tagged.toMap + codes.map(code => m(code)) + } + } + runGrouped(queries) { - case (Select("country", alias, child), cursors, indices) => + case (Select(_, _, child), cursors, indices) => val codes = cursors.flatMap(_.fieldAs[Json]("countryCode").toOption.flatMap(_.asString).toList).map(toCode) - val combinedQuery = Select("country", alias, Filter(In(CountryType / "code", codes), child)) + val combinedQuery = Select("country", None, Filter(In(CountryType / "code", codes), child)) (for { - cursor <- ResultT(sqlCursor(combinedQuery, Env.empty)) + cursor <- ResultT(sqlCursor(combinedQuery, Env.empty)) + cursor0 <- ResultT(mkListCursor(cursor, "country", None).pure[F]) + pcs <- ResultT(partitionCursor(codes, cursor0).pure[F]) } yield { - codes.map { code => - (Select("country", alias, Unique(Filter(Eql(CountryType / "code", Const(code)), child))), cursor) - }.zip(indices) + pcs.zip(indices) }).value.widen case _ => Result.internalError("Continuation query has the wrong shape").pure[F].widen } } - def runGrouped(ts: List[(Query, Cursor)])(op: (Query, List[Cursor], List[Int]) => F[Result[List[((Query, Cursor), Int)]]]): F[Result[List[(Query, Cursor)]]] = { + def runGrouped(ts: List[(Query, Cursor)])(op: (Query, List[Cursor], List[Int]) => F[Result[List[(Cursor, Int)]]]): F[Result[List[Cursor]]] = { val groupedAndIndexed = ts.zipWithIndex.groupMap(_._1._1)(ti => (ti._1._2, ti._2)).toList val groupedResults = groupedAndIndexed.map { case (q, cis) => diff --git a/modules/sql/shared/src/test/scala/SqlNestedEffectsSuite.scala b/modules/sql/shared/src/test/scala/SqlNestedEffectsSuite.scala index 9c2d0d06..fb4e119d 100644 --- a/modules/sql/shared/src/test/scala/SqlNestedEffectsSuite.scala +++ b/modules/sql/shared/src/test/scala/SqlNestedEffectsSuite.scala @@ -35,7 +35,39 @@ trait SqlNestedEffectsSuite extends CatsEffectSuite { assertWeaklyEqualIO(res, expected) } - test("simple composed query") { + test("simple composed query (1)") { + val query = """ + query { + country(code: "GBR") { + currencies { + code + exchangeRate + } + } + } + """ + + val expected = json""" + { + "data" : { + "country" : { + "currencies": [ + { + "code": "GBP", + "exchangeRate": 1.25 + } + ] + } + } + } + """ + + val res = mapping.flatMap(_._2.compileAndRun(query)) + + assertWeaklyEqualIO(res, expected) + } + + test("simple composed query (2)") { val query = """ query { country(code: "GBR") {