Skip to content

Commit

Permalink
Fix handling of grouped/ungrouped effects
Browse files Browse the repository at this point in the history
  • Loading branch information
milessabin committed Oct 5, 2023
1 parent 74b6b5d commit c5aefa6
Show file tree
Hide file tree
Showing 6 changed files with 105 additions and 60 deletions.
2 changes: 1 addition & 1 deletion modules/core/src/main/scala/compiler.scala
Original file line number Diff line number Diff line change
Expand Up @@ -902,7 +902,7 @@ object QueryCompiler {
} yield
emapping.get((ref, fieldName)) match {
case Some(handler) =>
Effect(handler, s.copy(child = ec))
Select(fieldName, resultName, Effect(handler, s.copy(child = ec)))
case None =>
s.copy(child = ec)
}
Expand Down
7 changes: 6 additions & 1 deletion modules/core/src/main/scala/query.scala
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ object Query {
}

trait EffectHandler[F[_]] {
def runEffects(queries: List[(Query, Cursor)]): F[Result[List[(Query, Cursor)]]]
def runEffects(queries: List[(Query, Cursor)]): F[Result[List[Cursor]]]
}

/** Evaluates an introspection query relative to `schema` */
Expand Down Expand Up @@ -261,6 +261,11 @@ object Query {
loop(q)
}

def childContext(c: Context, query: Query): Result[Context] =
rootName(query).toResultOrError(s"Query has the wrong shape").flatMap {
case (fieldName, resultName) => c.forField(fieldName, resultName)
}

/**
* Renames the root of `target` to match `source` if possible.
*/
Expand Down
34 changes: 10 additions & 24 deletions modules/core/src/main/scala/queryinterpreter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,11 @@ class QueryInterpreter[F[_]](mapping: Mapping[F]) {
else size(c0)
} yield List((sel.resultName, ProtoJson.fromJson(Json.fromInt(count))))

case (sel@Select(_, _, Effect(handler, cont)), _) =>
for {
value <- ProtoJson.effect(mapping, handler.asInstanceOf[EffectHandler[F]], cont, cursor).success
} yield List((sel.resultName, value))

case (sel@Select(fieldName, resultName, child), _) =>
val fieldTpe = tpe.field(fieldName).getOrElse(ScalarType.AttributeType)
for {
Expand All @@ -241,12 +246,6 @@ class QueryInterpreter[F[_]](mapping: Mapping[F]) {
value <- runValue(c, tpe, cursor)
} yield List((componentName, ProtoJson.select(value, componentName)))

case (e@Effect(_, cont), _) =>
for {
effectName <- resultName(cont).toResultOrError("Effect continuation has unexpected shape")
value <- runValue(e, tpe, cursor)
} yield List((effectName, value))

case (Group(siblings), _) =>
siblings.flatTraverse(query => runFields(query, tpe, cursor))

Expand Down Expand Up @@ -410,9 +409,6 @@ class QueryInterpreter[F[_]](mapping: Mapping[F]) {
} yield ProtoJson.component(mapping, renamedCont, cursor)
}

case (Effect(handler, cont), _) =>
ProtoJson.effect(mapping, handler.asInstanceOf[EffectHandler[F]], cont, cursor).success

case (Unique(child), _) =>
cursor.preunique.flatMap(c =>
runList(child, tpe.nonNull, c, true, tpe.isNullable)
Expand Down Expand Up @@ -629,19 +625,8 @@ object QueryInterpreter {
case p: Json => p
case d: DeferredJson => subst(d)
case ProtoObject(fields) =>
val newFields: Seq[(String, Json)] =
fields.flatMap { case (label, pvalue) =>
val value = loop(pvalue)
if (isDeferred(pvalue) && value.isObject) {
value.asObject.get.toList match {
case List((_, value)) => List((label, value))
case other => other
}
}
else List((label, value))
}
Json.fromFields(newFields)

val fields0 = fields.map { case (label, pvalue) => (label, loop(pvalue)) }
Json.fromFields(fields0)
case ProtoArray(elems) =>
val elems0 = elems.map(loop)
Json.fromValues(elems0)
Expand All @@ -667,8 +652,9 @@ object QueryInterpreter {
ResultT(mapping.combineAndRun(queries))
case Some(handler) =>
for {
conts <- ResultT(handler.runEffects(queries))
res <- ResultT(combineResults(conts.map { case (query, cursor) => mapping.interpreter.runValue(query, cursor.tpe, cursor) }).pure[F])
cs <- ResultT(handler.runEffects(queries))
conts <- ResultT(queries.traverse { case (q, _) => Query.extractChild(q).toResultOrError("Continuation query has the wrong shape") }.pure[F])
res <- ResultT(combineResults((conts, cs).parMapN { case (query, cursor) => mapping.interpreter.runValue(query, cursor.tpe, cursor) }).pure[F])
} yield res
}
next <- ResultT(completeAll[F](pnext))
Expand Down
30 changes: 16 additions & 14 deletions modules/sql/shared/src/main/scala/SqlMapping.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2916,6 +2916,21 @@ trait SqlMappingLike[F[_]] extends CirceMappingLike[F] with SqlModule[F] { self

case _: Count => Result.internalError("Count node must be a child of a Select node")

case Select(fieldName, _, Effect(_, _)) =>
columnsForLeaf(context, fieldName).flatMap {
case Nil => EmptySqlQuery.success
case cols =>
val constraintCols = if(exposeJoins) parentConstraints.lastOption.getOrElse(Nil).map(_._2) else Nil
val extraCols = keyColumnsForType(context) ++ constraintCols
for {
parentTable <- parentTableForType(context)
extraJoins <- parentConstraintsToSqlJoins(parentTable, parentConstraints)
} yield
SqlSelect(context, Nil, parentTable, (cols ++ extraCols).distinct, extraJoins, Nil, Nil, None, None, Nil, true, false)
}

case Effect(_, _) => Result.internalError("Effect node must be a child of a Select node")

// Non-leaf non-Json element: compile subobject queries
case s@Select(fieldName, resultName, child) =>
context.forField(fieldName, resultName).flatMap { fieldContext =>
Expand Down Expand Up @@ -2943,19 +2958,6 @@ trait SqlMappingLike[F[_]] extends CirceMappingLike[F] with SqlModule[F] { self
}
}

case Effect(_, Select(fieldName, _, _)) =>
columnsForLeaf(context, fieldName).flatMap {
case Nil => EmptySqlQuery.success
case cols =>
val constraintCols = if(exposeJoins) parentConstraints.lastOption.getOrElse(Nil).map(_._2) else Nil
val extraCols = keyColumnsForType(context) ++ constraintCols
for {
parentTable <- parentTableForType(context)
extraJoins <- parentConstraintsToSqlJoins(parentTable, parentConstraints)
} yield
SqlSelect(context, Nil, parentTable, (cols ++ extraCols).distinct, extraJoins, Nil, Nil, None, None, Nil, true, false)
}

case TypeCase(default, narrows) =>
def isSimple(query: Query): Boolean = {
def loop(query: Query): Boolean =
Expand Down Expand Up @@ -3116,7 +3118,7 @@ trait SqlMappingLike[F[_]] extends CirceMappingLike[F] with SqlModule[F] { self
case TransformCursor(_, child) =>
loop(child, context, parentConstraints, exposeJoins)

case Empty | Query.Component(_, _, _) | Query.Effect(_, _) | (_: UntypedSelect) | (_: UntypedFragmentSpread) | (_: UntypedInlineFragment) | (_: Select) =>
case Empty | Query.Component(_, _, _) | (_: UntypedSelect) | (_: UntypedFragmentSpread) | (_: UntypedInlineFragment) | (_: Select) =>
EmptySqlQuery.success
}
}
Expand Down
58 changes: 39 additions & 19 deletions modules/sql/shared/src/test/scala/SqlNestedEffectsMapping.scala
Original file line number Diff line number Diff line change
Expand Up @@ -199,14 +199,13 @@ trait SqlNestedEffectsMapping[F[_]] extends SqlTestMapping[F] {
)

object CurrencyQueryHandler extends EffectHandler[F] {
def runEffects(queries: List[(Query, Cursor)]): F[Result[List[(Query, Cursor)]]] = {
def runEffects(queries: List[(Query, Cursor)]): F[Result[List[Cursor]]] = {
val countryCodes = queries.map(_._2.fieldAs[String]("code2").toOption)
val distinctCodes = queries.flatMap(_._2.fieldAs[String]("code2").toList).distinct

val children = queries.flatMap {
case (Select(name, alias, child), parentCursor) =>
parentCursor.context.forField(name, alias).toList.map(ctx => (ctx, child, parentCursor))
case _ => Nil
val children0 = queries.traverse {
case (query, parentCursor) =>
Query.childContext(parentCursor.context, query).map(ctx => (ctx, parentCursor))
}

def unpackResults(res: Json): List[Json] =
Expand All @@ -225,39 +224,60 @@ trait SqlNestedEffectsMapping[F[_]] extends SqlTestMapping[F] {
case _ => Json.Null
}).getOrElse(Nil)

for {
res <- currencyService.get(distinctCodes)
(for {
children <- ResultT(children0.pure[F])
res <- ResultT(currencyService.get(distinctCodes).map(_.success))
} yield {
unpackResults(res).zip(children).map {
case (res, (ctx, child, parentCursor)) =>
val cursor = CirceCursor(ctx, res, Some(parentCursor), parentCursor.env)
(child, cursor)
case (res, (childContext, parentCursor)) =>
val cursor = CirceCursor(childContext, res, Some(parentCursor), parentCursor.env)
cursor
}
}.success
}).value.widen
}
}

object CountryQueryHandler extends EffectHandler[F] {
val toCode = Map("BR" -> "BRA", "GB" -> "GBR", "NL" -> "NLD")
def runEffects(queries: List[(Query, Cursor)]): F[Result[List[(Query, Cursor)]]] = {
def runEffects(queries: List[(Query, Cursor)]): F[Result[List[Cursor]]] = {

def mkListCursor(cursor: Cursor, fieldName: String, resultName: Option[String]): Result[Cursor] =
for {
c <- cursor.field(fieldName, resultName)
lc <- c.preunique
} yield lc

def extractCode(cursor: Cursor): Result[String] =
cursor.fieldAs[String]("code")

def partitionCursor(codes: List[String], cursor: Cursor): Result[List[Cursor]] = {
for {
cursors <- cursor.asList
tagged <- cursors.traverse(c => (extractCode(c).map { code => (code, c) }))
} yield {
val m = tagged.toMap
codes.map(code => m(code))
}
}

runGrouped(queries) {
case (Select("country", alias, child), cursors, indices) =>
case (Select(_, _, child), cursors, indices) =>
val codes = cursors.flatMap(_.fieldAs[Json]("countryCode").toOption.flatMap(_.asString).toList).map(toCode)
val combinedQuery = Select("country", alias, Filter(In(CountryType / "code", codes), child))
val combinedQuery = Select("country", None, Filter(In(CountryType / "code", codes), child))

(for {
cursor <- ResultT(sqlCursor(combinedQuery, Env.empty))
cursor <- ResultT(sqlCursor(combinedQuery, Env.empty))
cursor0 <- ResultT(mkListCursor(cursor, "country", None).pure[F])
pcs <- ResultT(partitionCursor(codes, cursor0).pure[F])
} yield {
codes.map { code =>
(Select("country", alias, Unique(Filter(Eql(CountryType / "code", Const(code)), child))), cursor)
}.zip(indices)
pcs.zip(indices)
}).value.widen

case _ => Result.internalError("Continuation query has the wrong shape").pure[F].widen
}
}

def runGrouped(ts: List[(Query, Cursor)])(op: (Query, List[Cursor], List[Int]) => F[Result[List[((Query, Cursor), Int)]]]): F[Result[List[(Query, Cursor)]]] = {
def runGrouped(ts: List[(Query, Cursor)])(op: (Query, List[Cursor], List[Int]) => F[Result[List[(Cursor, Int)]]]): F[Result[List[Cursor]]] = {
val groupedAndIndexed = ts.zipWithIndex.groupMap(_._1._1)(ti => (ti._1._2, ti._2)).toList
val groupedResults =
groupedAndIndexed.map { case (q, cis) =>
Expand Down
34 changes: 33 additions & 1 deletion modules/sql/shared/src/test/scala/SqlNestedEffectsSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,39 @@ trait SqlNestedEffectsSuite extends CatsEffectSuite {
assertWeaklyEqualIO(res, expected)
}

test("simple composed query") {
test("simple composed query (1)") {
val query = """
query {
country(code: "GBR") {
currencies {
code
exchangeRate
}
}
}
"""

val expected = json"""
{
"data" : {
"country" : {
"currencies": [
{
"code": "GBP",
"exchangeRate": 1.25
}
]
}
}
}
"""

val res = mapping.flatMap(_._2.compileAndRun(query))

assertWeaklyEqualIO(res, expected)
}

test("simple composed query (2)") {
val query = """
query {
country(code: "GBR") {
Expand Down

0 comments on commit c5aefa6

Please sign in to comment.