From f3ecca3c00dfec578916d20d80d0e7817771c727 Mon Sep 17 00:00:00 2001 From: Miles Sabin Date: Fri, 5 Aug 2022 15:40:04 +0100 Subject: [PATCH 1/2] Tweaked benchmark --- profile/src/main/scala/Bench.scala | 32 ++++++++++++++++++++---------- 1 file changed, 21 insertions(+), 11 deletions(-) diff --git a/profile/src/main/scala/Bench.scala b/profile/src/main/scala/Bench.scala index d10aaadb..c659fe68 100644 --- a/profile/src/main/scala/Bench.scala +++ b/profile/src/main/scala/Bench.scala @@ -249,28 +249,38 @@ object Bench extends IOApp { val mapping = WorldMapping.mkMapping(xa) - def runQuery: IO[Unit] = { - val query = """ - query { - cities { + val query = """ + query { + cities { + name + country { name - country { + cities { name } } } - """ + } + """ - mapping.compileAndRun(query).void + def runQuery(show: Boolean, time: Boolean): IO[Unit] = { + for { + st <- if(time) IO.realTime else IO.pure(0.millis) + res <- mapping.compileAndRun(query) + et <- if(time) IO.realTime else IO.pure(0.millis) + _ <- IO.println(s"Execution time: ${et-st}").whenA(time) + _ <- IO.println(res).whenA(show) + } yield () } def run(args: List[String]) = { for { + _ <- IO.println("-----------------") _ <- IO.println("Warmup ...") - _ <- runQuery.replicateA(5) - _ <- IO.sleep(5.seconds) - _ <- IO.println("Running query ...") - _ <- runQuery.replicateA(100) + _ <- runQuery(false, true).replicateA(100) + _ <- IO.println("-----------------") + _ <- IO.println("Executing query ...") + _ <- runQuery(false, true).replicateA(1000) } yield ExitCode.Success } } From c806e698f3be09ab293f31d5f9644630cfdefed5 Mon Sep 17 00:00:00 2001 From: Miles Sabin Date: Fri, 5 Aug 2022 15:59:20 +0100 Subject: [PATCH 2/2] Reduce use of Lists and miscellaneous performance tweaks + Uses of List replaced with Seq, Iterable or Iterator where possible. + Cursor asList now takes a Factory to avoid committing to List. + Context path is more efficiently represented. + Several internal repeated List traversals and transforms replaced by hidden imperative loops for performance. + SqlColumns now record their result path which is used as a key for column lookup during result contruction, eliminating many repeated linear traversals of the mappings. + Key columns and leaf encoders are memoized during result construction. --- build.sbt | 2 +- .../circe/src/main/scala/circemapping.scala | 6 +- modules/core/src/main/scala/cursor.scala | 37 ++- modules/core/src/main/scala/query.scala | 2 +- .../src/main/scala/queryinterpreter.scala | 156 +++++----- .../core/src/main/scala/valuemapping.scala | 5 +- .../src/main/scala/genericmapping.scala | 8 +- .../generic/src/test/scala/derivation.scala | 4 +- modules/sql/src/main/scala/SqlMapping.scala | 286 +++++++++++------- 9 files changed, 287 insertions(+), 219 deletions(-) diff --git a/build.sbt b/build.sbt index 4410da5b..85ac1dcc 100644 --- a/build.sbt +++ b/build.sbt @@ -26,7 +26,7 @@ val Scala3 = "3.1.3" ThisBuild / scalaVersion := Scala2 ThisBuild / crossScalaVersions := Seq(Scala2, Scala3) -ThisBuild / tlBaseVersion := "0.3" +ThisBuild / tlBaseVersion := "0.4" ThisBuild / organization := "edu.gemini" ThisBuild / organizationName := "Association of Universities for Research in Astronomy, Inc. (AURA)" ThisBuild / startYear := Some(2019) diff --git a/modules/circe/src/main/scala/circemapping.scala b/modules/circe/src/main/scala/circemapping.scala index cbc4d63d..a59076d3 100644 --- a/modules/circe/src/main/scala/circemapping.scala +++ b/modules/circe/src/main/scala/circemapping.scala @@ -4,6 +4,8 @@ package edu.gemini.grackle package circe +import scala.collection.Factory + import cats.Monad import cats.implicits._ import fs2.Stream @@ -78,9 +80,9 @@ abstract class CirceMapping[F[_]: Monad] extends Mapping[F] { def isList: Boolean = tpe.isList && focus.isArray - def asList: Result[List[Cursor]] = tpe match { + def asList[C](factory: Factory[Cursor, C]): Result[C] = tpe match { case ListType(elemTpe) if focus.isArray => - focus.asArray.map(_.map(e => mkChild(context.asType(elemTpe), e)).toList) + focus.asArray.map(_.view.map(e => mkChild(context.asType(elemTpe), e)).to(factory)) .toRightIor(mkOneError(s"Expected List type, found $tpe for focus ${focus.noSpaces}")) case _ => mkErrorResult(s"Expected List type, found $tpe for focus ${focus.noSpaces}") diff --git a/modules/core/src/main/scala/cursor.scala b/modules/core/src/main/scala/cursor.scala index 1352b6e4..cedda111 100644 --- a/modules/core/src/main/scala/cursor.scala +++ b/modules/core/src/main/scala/cursor.scala @@ -3,6 +3,7 @@ package edu.gemini.grackle +import scala.collection.Factory import scala.reflect.{classTag, ClassTag} import cats.data.Ior @@ -61,7 +62,14 @@ trait Cursor { * this `Cursor` if it is of a list type, or an error or the left hand side * otherwise. */ - def asList: Result[List[Cursor]] + final def asList: Result[List[Cursor]] = asList(List) + + /** + * Yield a collection of `Cursor`s corresponding to the elements of the value at + * this `Cursor` if it is of a list type, or an error or the left hand side + * otherwise. + */ + def asList[C](factory: Factory[Cursor, C]): Result[C] /** Is the value at this `Cursor` of a nullable type? */ def isNullable: Boolean @@ -231,28 +239,27 @@ object Cursor { */ case class Context( rootTpe: Type, - path0: List[(String, String, Type)] = Nil + path: List[String], + resultPath: List[String], + typePath: List[Type] ) { - lazy val path: List[String] = path0.map(_._1) - lazy val resultPath: List[String] = path0.map(_._2) - lazy val typePath = path0.map(_._3) - lazy val tpe: Type = path0.headOption.map(_._3).getOrElse(rootTpe) + lazy val tpe: Type = typePath.headOption.getOrElse(rootTpe) def asType(tpe: Type): Context = { - path0 match { + typePath match { case Nil => copy(rootTpe = tpe) - case hd :: tl => copy(path0 = (hd._1, hd._2, tpe) :: tl) + case _ :: tl => copy(typePath = tpe :: tl) } } def forField(fieldName: String, resultName: String): Option[Context] = tpe.underlyingField(fieldName).map { fieldTpe => - copy(path0 = (fieldName, resultName, fieldTpe) :: path0) + copy(path = fieldName :: path, resultPath = resultName :: resultPath, typePath = fieldTpe :: typePath) } def forField(fieldName: String, resultName: Option[String]): Option[Context] = tpe.underlyingField(fieldName).map { fieldTpe => - copy(path0 = (fieldName, resultName.getOrElse(fieldName), fieldTpe) :: path0) + copy(path = fieldName :: path, resultPath = resultName.getOrElse(fieldName) :: resultPath, typePath = fieldTpe :: typePath) } def forPath(path1: List[String]): Option[Context] = @@ -263,13 +270,13 @@ object Cursor { def forFieldOrAttribute(fieldName: String, resultName: Option[String]): Context = { val fieldTpe = tpe.underlyingField(fieldName).getOrElse(ScalarType.AttributeType) - copy(path0 = (fieldName, resultName.getOrElse(fieldName), fieldTpe) :: path0) + copy(path = fieldName :: path, resultPath = resultName.getOrElse(fieldName) :: resultPath, typePath = fieldTpe :: typePath) } override def equals(other: Any): Boolean = other match { - case Context(oRootTpe, oPath0) => - rootTpe =:= oRootTpe && path0.corresponds(oPath0)((x, y) => x._1 == y._1 && x._2 == y._2) + case Context(oRootTpe, oPath, oResultPath, _) => + rootTpe =:= oRootTpe && resultPath == oResultPath && path == oPath case _ => false } @@ -280,10 +287,10 @@ object Cursor { def apply(rootTpe: Type, fieldName: String, resultName: Option[String]): Option[Context] = { for { tpe <- rootTpe.underlyingField(fieldName) - } yield new Context(rootTpe, List((fieldName, resultName.getOrElse(fieldName), tpe))) + } yield new Context(rootTpe, List(fieldName), List(resultName.getOrElse(fieldName)), List(tpe)) } - def apply(rootTpe: Type): Context = Context(rootTpe, Nil) + def apply(rootTpe: Type): Context = Context(rootTpe, Nil, Nil, Nil) } def flatten(c: Cursor): Result[List[Cursor]] = diff --git a/modules/core/src/main/scala/query.scala b/modules/core/src/main/scala/query.scala index d29bdeca..04800926 100644 --- a/modules/core/src/main/scala/query.scala +++ b/modules/core/src/main/scala/query.scala @@ -139,7 +139,7 @@ object Query { } case class OrderSelections(selections: List[OrderSelection[_]]) { - def order(lc: List[Cursor]): List[Cursor] = { + def order(lc: Seq[Cursor]): Seq[Cursor] = { def cmp(x: Cursor, y: Cursor): Int = { @tailrec def loop(sels: List[OrderSelection[_]]): Int = diff --git a/modules/core/src/main/scala/queryinterpreter.scala b/modules/core/src/main/scala/queryinterpreter.scala index 4c5fb25b..433bcab4 100644 --- a/modules/core/src/main/scala/queryinterpreter.scala +++ b/modules/core/src/main/scala/queryinterpreter.scala @@ -8,7 +8,7 @@ import scala.collection.mutable import scala.jdk.CollectionConverters._ import cats.Monoid -import cats.data.{ Chain, Ior, IorT, Kleisli, NonEmptyChain } +import cats.data.{ Chain, Ior, IorT, NonEmptyChain } import cats.implicits._ import fs2.Stream import io.circe.Json @@ -235,10 +235,10 @@ class QueryInterpreter[F[_]](mapping: Mapping[F]) { c0.asNullable.flatMap { case None => 0.rightIor case Some(c1) => - if (c1.isList) c1.asList.map(c2 => c2.size) + if (c1.isList) c1.asList(Iterator).map(_.size) else 1.rightIor } - else if (c0.isList) c0.asList.map(c2 => c2.size) + else if (c0.isList) c0.asList(Iterator).map(_.size) else 1.rightIor }.map { value => List((fieldName, ProtoJson.fromJson(Json.fromInt(value)))) } @@ -253,29 +253,61 @@ class QueryInterpreter[F[_]](mapping: Mapping[F]) { } } - def runList(query: Query, tpe: Type, cursors: List[Cursor], f: Kleisli[Result, List[Cursor], List[Cursor]]): Result[ProtoJson] = - if (cursors.exists(cursor => !cursorCompatible(tpe, cursor.tpe))) - mkErrorResult(s"Mismatched query and cursor type in runList: $tpe ${cursors.map(_.tpe)}") - else { + def runList(query: Query, tpe: Type, cursors: Iterator[Cursor], unique: Boolean, nullable: Boolean): Result[ProtoJson] = { + val (child, ic) = query match { - case Filter(pred, child) => - runList(child, tpe, cursors, f.compose(_.filterA(pred(_)))) - - case Limit(num, child) => - runList(child, tpe, cursors, f.compose(_.take(num).rightIor)) + case FilterOrderByOffsetLimit(pred, selections, offset, limit, child) => + val filtered = + pred.map { p => + cursors.filter { c => + p(c) match { + case left@Ior.Left(_) => return left + case Ior.Right(c) => c + case Ior.Both(_, c) => c + } + } + }.getOrElse(cursors) + val sorted = selections.map(OrderSelections(_).order(filtered.toSeq).iterator).getOrElse(filtered) + val sliced = (offset, limit) match { + case (None, None) => sorted + case (Some(off), None) => sorted.drop(off) + case (None, Some(lim)) => sorted.take(lim) + case (Some(off), Some(lim)) => sorted.slice(off, off+lim) + } + (child, sliced) + case other => (other, cursors) + } - case Offset(num, child) => - runList(child, tpe, cursors, f.compose(_.drop(num).rightIor)) + val builder = Vector.newBuilder[ProtoJson] + var problems = Chain.empty[Problem] + builder.sizeHint(ic.knownSize) + while(ic.hasNext) { + val c = ic.next() + if (!cursorCompatible(tpe, c.tpe)) + return mkErrorResult(s"Mismatched query and cursor type in runList: $tpe ${cursors.map(_.tpe)}") + + runValue(child, tpe, c) match { + case left@Ior.Left(_) => return left + case Ior.Right(v) => builder.addOne(v) + case Ior.Both(ps, v) => + builder.addOne(v) + problems = problems.concat(ps.toChain) + } + } - case OrderBy(selections, child) => - runList(child, tpe, cursors, f.compose(selections.order(_).rightIor)) + def mkResult(j: ProtoJson): Result[ProtoJson] = + NonEmptyChain.fromChain(problems).map(neps => Ior.Both(neps, j)).getOrElse(j.rightIor) - case _ => - f.run(cursors).flatMap(lc => - lc.traverse(c => runValue(query, tpe, c)).map(ProtoJson.fromValues) - ) - } + if (!unique) mkResult(ProtoJson.fromValues(builder.result())) + else { + val size = builder.knownSize + if (size == 1) mkResult(builder.result()(0)) + else if (size == 0) { + if(nullable) mkResult(ProtoJson.fromJson(Json.Null)) + else mkErrorResult(s"No match") + } else mkErrorResult(s"Multiple matches") } + } /** * Interpret `query` against `cursor` with expected type `tpe`. @@ -298,9 +330,18 @@ class QueryInterpreter[F[_]](mapping: Mapping[F]) { case (Wrap(_, Component(_, _, _)), ListType(tpe)) => // Keep the wrapper with the component when going under the list - cursor.asList.flatMap(lc => - lc.traverse(c => runValue(query, tpe, c)).map(ProtoJson.fromValues) - ) + cursor.asList(Iterator).map { ic => + val builder = Vector.newBuilder[ProtoJson] + builder.sizeHint(ic.knownSize) + while(ic.hasNext) { + val c = ic.next() + runValue(query, tpe, c) match { + case Ior.Right(v) => builder.addOne(v) + case left => return left + } + } + ProtoJson.fromValues(builder.result()) + } case (Wrap(_, Defer(_, _, _)), _) if cursor.isNull => ProtoJson.fromJson(Json.Null).rightIor @@ -346,24 +387,6 @@ class QueryInterpreter[F[_]](mapping: Mapping[F]) { } else stage(cursor) - case (Unique(Filter(pred, child)), _) => - val cursors = - if (cursor.isNullable) - cursor.asNullable.flatMap { - case None => Nil.rightIor - case Some(c) => c.asList - } - else cursor.asList - - cursors.flatMap(_.filterA(pred(_))).flatMap(lc => - lc match { - case List(c) => runValue(child, tpe.nonNull, c) - case Nil if tpe.isNullable => ProtoJson.fromJson(Json.Null).rightIor - case Nil => mkErrorResult(s"No match") - case _ => mkErrorResult(s"Multiple matches") - } - ) - case (Unique(child), _) => val oc = if (cursor.isNullable) cursor.asNullable @@ -371,21 +394,16 @@ class QueryInterpreter[F[_]](mapping: Mapping[F]) { oc.flatMap { case Some(c) => - runValue(child, tpe.nonNull.list, c).flatMap { pj => - ProtoJson.unpackList(pj).map { - case Nil if tpe.isNullable => ProtoJson.fromJson(Json.Null).rightIor - case Nil => mkErrorResult(s"No match") - case List(elem) => elem.rightIor - case _ => mkErrorResult(s"Multiple matches") - }.getOrElse(mkErrorResult(s"Unique result of the wrong shape: $pj")) + c.asList(Iterator).flatMap { cursors => + runList(child, tpe.nonNull, cursors, true, tpe.isNullable) } case None => ProtoJson.fromJson(Json.Null).rightIor } case (_, ListType(tpe)) => - cursor.asList.flatMap { cursors => - runList(query, tpe, cursors, Kleisli(_.rightIor)) + cursor.asList(Iterator).flatMap { cursors => + runList(query, tpe, cursors, false, false) } case (_, NullableType(tpe)) => @@ -423,9 +441,9 @@ object QueryInterpreter { // A result which is deferred to the next stage or component of this interpreter. private[QueryInterpreter] case class StagedJson[F[_]](interpreter: QueryInterpreter[F], query: Query, rootTpe: Type, env: Env) extends DeferredJson // A partially constructed object which has at least one deferred subtree. - private[QueryInterpreter] case class ProtoObject(fields: List[(String, ProtoJson)]) + private[QueryInterpreter] case class ProtoObject(fields: Seq[(String, ProtoJson)]) // A partially constructed array which has at least one deferred element. - private[QueryInterpreter] case class ProtoArray(elems: List[ProtoJson]) + private[QueryInterpreter] case class ProtoArray(elems: Seq[ProtoJson]) // A result which will yield a selection from its child private[QueryInterpreter] case class ProtoSelect(elem: ProtoJson, fieldName: String) @@ -444,9 +462,9 @@ object QueryInterpreter { * If all fields are complete then they will be combined as a complete * Json object. */ - def fromFields(fields: List[(String, ProtoJson)]): ProtoJson = + def fromFields(fields: Seq[(String, ProtoJson)]): ProtoJson = if(fields.forall(_._2.isInstanceOf[Json])) - wrap(Json.fromFields(fields.asInstanceOf[List[(String, Json)]])) + wrap(Json.fromFields(fields.asInstanceOf[Seq[(String, Json)]])) else wrap(ProtoObject(fields)) @@ -456,9 +474,9 @@ object QueryInterpreter { * If all values are complete then they will be combined as a complete * Json array. */ - def fromValues(elems: List[ProtoJson]): ProtoJson = + def fromValues(elems: Seq[ProtoJson]): ProtoJson = if(elems.forall(_.isInstanceOf[Json])) - wrap(Json.fromValues(elems.asInstanceOf[List[Json]])) + wrap(Json.fromValues(elems.asInstanceOf[Seq[Json]])) else wrap(ProtoArray(elems)) @@ -508,29 +526,7 @@ object QueryInterpreter { } } - def unpackObject(p: ProtoJson): Option[List[ProtoJson]] = - p match { - case ProtoObject(List((_, packedElems))) => unpackList(packedElems) - case j: Json if j.isObject => - j.asObject.flatMap(jo => - if (jo.size != 1) None - else { - val List((_, packedElems)) = jo.toList - packedElems.asArray.map(v => wrapList(v.toList)) - } - ) - case _ => None - } - - def unpackList(p: ProtoJson): Option[List[ProtoJson]] = - p match { - case ProtoArray(elems) => Some(elems) - case j: Json if j.isArray => j.asArray.map(ja => wrapList(ja.toList)) - case _ => None - } - private def wrap(j: AnyRef): ProtoJson = j.asInstanceOf[ProtoJson] - private def wrapList(l: List[AnyRef]): List[ProtoJson] = l.asInstanceOf[List[ProtoJson]] } import ProtoJson._ @@ -600,7 +596,7 @@ object QueryInterpreter { case p: Json => p case d: DeferredJson => subst(d) case ProtoObject(fields) => - val newFields: List[(String, Json)] = + val newFields: Seq[(String, Json)] = fields.flatMap { case (label, pvalue) => val value = loop(pvalue) if (isDeferred(pvalue) && value.isObject) { diff --git a/modules/core/src/main/scala/valuemapping.scala b/modules/core/src/main/scala/valuemapping.scala index 5213c3a4..b8363282 100644 --- a/modules/core/src/main/scala/valuemapping.scala +++ b/modules/core/src/main/scala/valuemapping.scala @@ -3,6 +3,7 @@ package edu.gemini.grackle +import scala.collection.Factory import scala.reflect.ClassTag import cats.Monad @@ -118,8 +119,8 @@ abstract class ValueMapping[F[_]: Monad] extends Mapping[F] { case _ => false } - def asList: Result[List[Cursor]] = (tpe, focus) match { - case (ListType(tpe), it: List[_]) => it.map(f => mkChild(context.asType(tpe), f)).rightIor + def asList[C](factory: Factory[Cursor, C]): Result[C] = (tpe, focus) match { + case (ListType(tpe), it: List[_]) => it.view.map(f => mkChild(context.asType(tpe), f)).to(factory).rightIor case _ => mkErrorResult(s"Expected List type, found $tpe") } diff --git a/modules/generic/src/main/scala/genericmapping.scala b/modules/generic/src/main/scala/genericmapping.scala index b4ba6697..a121e0bd 100644 --- a/modules/generic/src/main/scala/genericmapping.scala +++ b/modules/generic/src/main/scala/genericmapping.scala @@ -4,6 +4,8 @@ package edu.gemini.grackle package generic +import scala.collection.Factory + import cats.Monad import cats.implicits._ import fs2.Stream @@ -169,8 +171,8 @@ object CursorBuilder { def withEnv(env0: Env): Cursor = copy(env = env.add(env0)) override def isList: Boolean = true - override def asList: Result[List[Cursor]] = { - focus.traverse(elem => elemBuilder.build(context, elem, Some(this), env)) + override def asList[C](factory: Factory[Cursor, C]): Result[C] = { + focus.traverse(elem => elemBuilder.build(context, elem, Some(this), env)).map(_.to(factory)) } } @@ -207,7 +209,7 @@ abstract class AbstractCursor[T] extends Cursor { def isList: Boolean = false - def asList: Result[List[Cursor]] = + def asList[C](factory: Factory[Cursor, C]): Result[C] = mkErrorResult(s"Expected List type, found $tpe") def isNullable: Boolean = false diff --git a/modules/generic/src/test/scala/derivation.scala b/modules/generic/src/test/scala/derivation.scala index 8d0d8709..ca88463b 100644 --- a/modules/generic/src/test/scala/derivation.scala +++ b/modules/generic/src/test/scala/derivation.scala @@ -295,7 +295,7 @@ final class DerivationSpec extends CatsSuite { c <- CursorBuilder[Character].build(Context(CharacterType), lukeSkywalker) f <- c.field("appearsIn", None) n <- f.asNullable.flatMap(_.toRightIor(mkOneError("missing"))) - l <- n.asList + l <- n.asList(List) s <- l.traverse(_.asLeaf) } yield s assert(appearsIn == Ior.Right(List(Json.fromString("NEWHOPE"), Json.fromString("EMPIRE"), Json.fromString("JEDI")))) @@ -307,7 +307,7 @@ final class DerivationSpec extends CatsSuite { c <- CursorBuilder[Human].build(Context(HumanType), lukeSkywalker) f <- c.field("friends", None) n <- f.asNullable.flatMap(_.toRightIor(mkOneError("missing"))) - l <- n.asList + l <- n.asList(List) m <- l.traverse(_.field("name", None)) p <- m.traverse(_.asNullable.flatMap(_.toRightIor(mkOneError("missing")))) q <- p.traverse(_.asLeaf) diff --git a/modules/sql/src/main/scala/SqlMapping.scala b/modules/sql/src/main/scala/SqlMapping.scala index 320233e4..5a4eaec1 100644 --- a/modules/sql/src/main/scala/SqlMapping.scala +++ b/modules/sql/src/main/scala/SqlMapping.scala @@ -5,6 +5,7 @@ package edu.gemini.grackle package sql import scala.annotation.tailrec +import scala.collection.Factory import cats.data.{NonEmptyList, State} import cats.implicits._ @@ -211,6 +212,8 @@ trait SqlMapping[F[_]] extends CirceMapping[F] with SqlModule[F] { self => def scalaTypeName: String def pos: SourcePos + def resultPath: Option[List[String]] + /** The named owner of this column, if any */ def namedOwner: Option[TableExpr] = owner.findNamedOwner(this) @@ -306,7 +309,7 @@ trait SqlMapping[F[_]] extends CirceMapping[F] with SqlModule[F] { self => } /** Representation of a column of a table/view */ - case class TableColumn(owner: ColumnOwner, cr: ColumnRef) extends SqlColumn { + case class TableColumn(owner: ColumnOwner, cr: ColumnRef, resultPath: Option[List[String]]) extends SqlColumn { def column: String = cr.column def codec: Codec = cr.codec def scalaTypeName: String = cr.scalaTypeName @@ -334,8 +337,8 @@ trait SqlMapping[F[_]] extends CirceMapping[F] with SqlModule[F] { self => } object TableColumn { - def apply(context: Context, cr: ColumnRef): TableColumn = - TableColumn(TableRef(context, cr.table), cr) + def apply(context: Context, cr: ColumnRef, resultPath: Option[List[String]]): TableColumn = + TableColumn(TableRef(context, cr.table), cr, resultPath) } /** Representation of a synthetic null column @@ -348,6 +351,8 @@ trait SqlMapping[F[_]] extends CirceMapping[F] with SqlModule[F] { self => def scalaTypeName: String = col.scalaTypeName def pos: SourcePos = col.pos + def resultPath: Option[List[String]] = col.resultPath + override def underlying: SqlColumn = col.underlying def subst(from: ColumnOwner, to: ColumnOwner): SqlColumn = @@ -378,6 +383,8 @@ trait SqlMapping[F[_]] extends CirceMapping[F] with SqlModule[F] { self => def scalaTypeName: String = col.scalaTypeName def pos: SourcePos = col.pos + def resultPath: Option[List[String]] = col.resultPath + def subst(from: ColumnOwner, to: ColumnOwner): SqlColumn = { val subquery0 = (from, to) match { @@ -407,6 +414,8 @@ trait SqlMapping[F[_]] extends CirceMapping[F] with SqlModule[F] { self => def scalaTypeName: String = col.scalaTypeName def pos: SourcePos = col.pos + def resultPath: Option[List[String]] = col.resultPath + def subst(from: ColumnOwner, to: ColumnOwner): SqlColumn = copy(col.subst(from, to), cols.map(_.subst(from, to))) @@ -431,6 +440,8 @@ trait SqlMapping[F[_]] extends CirceMapping[F] with SqlModule[F] { self => def scalaTypeName: String = "Int" def pos: SourcePos = null + def resultPath: Option[List[String]] = None + def subst(from: ColumnOwner, to: ColumnOwner): SqlColumn = copy(owner = if(owner.isSameOwner(from)) to else owner, partitionCols = partitionCols.map(_.subst(from, to))) @@ -472,6 +483,8 @@ trait SqlMapping[F[_]] extends CirceMapping[F] with SqlModule[F] { self => def scalaTypeName: String = col.scalaTypeName def pos: SourcePos = col.pos + def resultPath: Option[List[String]] = col.resultPath + override def underlying: SqlColumn = col.underlying def subst(from: ColumnOwner, to: ColumnOwner): SqlColumn = @@ -503,6 +516,8 @@ trait SqlMapping[F[_]] extends CirceMapping[F] with SqlModule[F] { self => def scalaTypeName: String = col.scalaTypeName def pos: SourcePos = col.pos + def resultPath: Option[List[String]] = col.resultPath + override def underlying: SqlColumn = col.underlying def subst(from: ColumnOwner, to: ColumnOwner): SqlColumn = @@ -690,7 +705,7 @@ trait SqlMapping[F[_]] extends CirceMapping[F] with SqlModule[F] { self => /** Returns the discriminator columns for the context type */ def discriminatorColumnsForType(context: Context): List[SqlColumn] = objectMapping(context).map(_.fieldMappings.collect { - case cm: SqlField if cm.discriminator => SqlColumn.TableColumn(context, cm.columnRef) + case cm: SqlField if cm.discriminator => SqlColumn.TableColumn(context, cm.columnRef, mkResultPath(context, cm.fieldName)) }).getOrElse(Nil) /** Returns the key columns for the context type */ @@ -698,7 +713,7 @@ trait SqlMapping[F[_]] extends CirceMapping[F] with SqlModule[F] { self => val cols = objectMapping(context).map { obj => val objectKeys = obj.fieldMappings.collect { - case cm: SqlField if cm.key => SqlColumn.TableColumn(context, cm.columnRef) + case cm: SqlField if cm.key => SqlColumn.TableColumn(context, cm.columnRef, mkResultPath(context, cm.fieldName)) } val interfaceKeys = context.tpe.underlyingObject match { @@ -715,11 +730,14 @@ trait SqlMapping[F[_]] extends CirceMapping[F] with SqlModule[F] { self => cols } + def mkResultPath(context: Context, fieldName: String): Option[List[String]] = + Some(fieldName :: context.resultPath) + /** Returns the columns for leaf field `fieldName` in `context` */ def columnsForLeaf(context: Context, fieldName: String): List[SqlColumn] = fieldMapping(context, fieldName) match { - case Some(SqlField(_, cr, _, _, _, _)) => List(SqlColumn.TableColumn(context, cr)) - case Some(SqlJson(_, cr)) => List(SqlColumn.TableColumn(context, cr)) + case Some(SqlField(_, cr, _, _, _, _)) => List(SqlColumn.TableColumn(context, cr, mkResultPath(context, fieldName))) + case Some(SqlJson(_, cr)) => List(SqlColumn.TableColumn(context, cr, mkResultPath(context, fieldName))) case Some(CursorFieldJson(_, _, _, required, _)) => required.flatMap(r => columnsForLeaf(context, r)) case Some(CursorField(_, _, _, required, _)) => @@ -744,8 +762,8 @@ trait SqlMapping[F[_]] extends CirceMapping[F] with SqlModule[F] { self => /** Returns the aliased column corresponding to the atomic field `fieldName` in `context` */ def columnForAtomicField(context: Context, fieldName: String): Option[SqlColumn] = { fieldMapping(context, fieldName) match { - case Some(SqlField(_, cr, _, _, _, _)) => Some(SqlColumn.TableColumn(context, cr)) - case Some(SqlJson(_, cr)) => Some(SqlColumn.TableColumn(context, cr)) + case Some(SqlField(_, cr, _, _, _, _)) => Some(SqlColumn.TableColumn(context, cr, mkResultPath(context, fieldName))) + case Some(SqlJson(_, cr)) => Some(SqlColumn.TableColumn(context, cr, mkResultPath(context, fieldName))) case _ => None } } @@ -1568,14 +1586,14 @@ trait SqlMapping[F[_]] extends CirceMapping[F] with SqlModule[F] { self => } def mkJoins(joins0: List[Join], multiTable: Boolean): SqlSelect = { - val base = mkSubquery(multiTable, this, SqlColumn.TableColumn(table, joins0.last.child), "_nested") + val base = mkSubquery(multiTable, this, SqlColumn.TableColumn(table, joins0.last.child, None), "_nested") val initialJoins = joins0.init.map { j => val parentTable = TableRef(parentContext, j.parent.table) - val parentCol = SqlColumn.TableColumn(parentTable, j.parent) + val parentCol = SqlColumn.TableColumn(parentTable, j.parent, None) val childTable = TableRef(parentContext, j.child.table) - val childCol = SqlColumn.TableColumn(childTable, j.child) + val childCol = SqlColumn.TableColumn(childTable, j.child, None) SqlJoin( parentTable, childTable, @@ -1587,10 +1605,10 @@ trait SqlMapping[F[_]] extends CirceMapping[F] with SqlModule[F] { self => val finalJoins = { val lastJoin = joins0.last val parentTable = TableRef(parentContext, lastJoin.parent.table) - val parentCol = SqlColumn.TableColumn(parentTable, lastJoin.parent) + val parentCol = SqlColumn.TableColumn(parentTable, lastJoin.parent, None) if(!isAssociative(context)) { - val childCol = SqlColumn.TableColumn(base.table, lastJoin.child) + val childCol = SqlColumn.TableColumn(base.table, lastJoin.child, None) val finalJoin = SqlJoin( parentTable, @@ -1601,7 +1619,7 @@ trait SqlMapping[F[_]] extends CirceMapping[F] with SqlModule[F] { self => finalJoin :: Nil } else { val assocTable = TableExpr.DerivedTableRef(context, Some(base.table.name+"_assoc"), base.table, true) - val childCol = SqlColumn.TableColumn(assocTable, lastJoin.child) + val childCol = SqlColumn.TableColumn(assocTable, lastJoin.child, None) val assocJoin = SqlJoin( @@ -1669,11 +1687,11 @@ trait SqlMapping[F[_]] extends CirceMapping[F] with SqlModule[F] { self => case Some(SqlObject(_, firstJoin :: tail)) => val nested = mkJoins(tail, true) - val base = mkSubquery(false, nested, SqlColumn.TableColumn(nested.table, firstJoin.child), "_multi") + val base = mkSubquery(false, nested, SqlColumn.TableColumn(nested.table, firstJoin.child, None), "_multi") val initialJoin = { - val parentCol = SqlColumn.TableColumn(parentTable, firstJoin.parent) - val childCol = SqlColumn.TableColumn(base.table, firstJoin.child) + val parentCol = SqlColumn.TableColumn(parentTable, firstJoin.parent, None) + val childCol = SqlColumn.TableColumn(base.table, firstJoin.child, None) SqlJoin( parentTable, base.table, @@ -2511,11 +2529,16 @@ trait SqlMapping[F[_]] extends CirceMapping[F] with SqlModule[F] { self => final class MappedQuery( query: SqlQuery ) { + val index: Map[SqlColumn, Int] = query.cols.zipWithIndex.toMap + + val colsByResultPath: Map[List[String], List[(SqlColumn, Int)]] = + query.cols.filter(_.resultPath.isDefined).groupMap(_.resultPath.getOrElse(???))(col => (col, index(col))) + /** Execute this query in `F` */ def fetch: F[Table] = { for { rows <- self.fetch(fragment, query.codecs) - } yield Table(query.cols, rows) + } yield Table(rows) } /** The query rendered as a `Fragment` with all table and column aliases applied */ @@ -2523,29 +2546,16 @@ trait SqlMapping[F[_]] extends CirceMapping[F] with SqlModule[F] { self => /** Return the value of the field `fieldName` in `context` from `table` */ def selectAtomicField(context: Context, fieldName: String, table: Table): Result[Any] = - columnForAtomicField(context, fieldName) match { - case Some(col) => - table.select(col).filterNot(_ == FailedJoin).distinct match { - case Nil => FailedJoin.rightIor - case value :: Nil => value.rightIor - case multi => - val obj = context.tpe.dealias - if (obj.variantField(fieldName) || obj.field(fieldName).map(_.isNullable).getOrElse(true)) - // if the field is a non-schema attribute we won't be able to discover whether - // or not it's nullable. Instead we assume that the presense of a None implies - // nullability, hence stripping out Nones is justified. - multi.filterNot(_ == None) match { - case Nil => None.rightIor - case value :: Nil => value.rightIor - case multi => - mkErrorResult(s"Expected single value for field '$fieldName' of type $obj at ${context.path}, found $multi") - } - else - mkErrorResult(s"Expected single value for field '$fieldName' of type $obj at ${context.path}, found $multi") - } - case None => + leafIndex(context, fieldName) match { + case -1 => val obj = context.tpe.dealias mkErrorResult(s"Expected mapping for field '$fieldName' of type $obj") + + case col => + table.select(col).toRightIor( + mkOneError(s"Expected single value for field '$fieldName' of type ${context.tpe.dealias} at ${context.path}, found many") + ) + } /** Does `table` contain subobjects of the type of the `narrowedContext` type */ @@ -2570,6 +2580,70 @@ trait SqlMapping[F[_]] extends CirceMapping[F] with SqlModule[F] { self => */ def group(context: Context, table: Table): Iterator[Table] = table.group(keyColumnsForType(context)) + + def keyColumnsForType(context: Context): List[Int] = { + val key = context.resultPath + keyColumnsMemo.get(context.resultPath) match { + case Some(cols) => cols + case None => + val keys = SqlMapping.this.keyColumnsForType(context).map(index) + keyColumnsMemo.put(key, keys) + keys + } + } + + val keyColumnsMemo: scala.collection.mutable.HashMap[List[String], List[Int]] = + scala.collection.mutable.HashMap.empty[List[String], List[Int]] + + def leafIndex(context: Context, fieldName: String): Int = + colsByResultPath.get(fieldName :: context.resultPath) match { + case None => + columnForAtomicField(context, fieldName).flatMap(index.get).getOrElse(-1) + case Some(Nil) => -1 + case Some(List((_, i))) => i + case Some(cols) => + columnForAtomicField(context, fieldName).flatMap(cursorCol => + cols.find(_._1 == cursorCol).map(_._2) + ).getOrElse(-1) + } + + def encoderForLeaf(tpe: Type): Option[io.circe.Encoder[Any]] = + encoderMemo.get(tpe).orElse { + SqlMapping.this.encoderForLeaf(tpe) match { + case oe@Some(enc) => + encoderMemo.put(tpe, enc) + oe + case None => None + } + } + + val intTypeEncoder: io.circe.Encoder[Any] = + new io.circe.Encoder[Any] { + def apply(i: Any): Json = i match { + case i: Int => Json.fromInt(i) + case l: Long => Json.fromLong(l) + case other => sys.error(s"Not an Int: $other") + } + } + + val floatTypeEncoder: io.circe.Encoder[Any] = + new io.circe.Encoder[Any] { + def apply(f: Any): Json = f match { + case f: Float => Json.fromFloat(f).getOrElse(sys.error(s"Unrepresentable float $f")) + case d: Double => Json.fromDouble(d).getOrElse(sys.error(s"Unrepresentable double $d")) + case d: BigDecimal => Json.fromBigDecimal(d) + case other => sys.error(s"Not a Float: $other") + } + } + + val encoderMemo: scala.collection.mutable.HashMap[Type, io.circe.Encoder[Any]] = + scala.collection.mutable.HashMap( + ScalarType.StringType -> io.circe.Encoder[String].asInstanceOf[io.circe.Encoder[Any]], + ScalarType.IntType -> intTypeEncoder, + ScalarType.FloatType -> floatTypeEncoder, + ScalarType.BooleanType -> io.circe.Encoder[Boolean].asInstanceOf[io.circe.Encoder[Any]], + ScalarType.IDType -> io.circe.Encoder[String].asInstanceOf[io.circe.Encoder[Any]] + ) } object MappedQuery { @@ -2604,15 +2678,15 @@ trait SqlMapping[F[_]] extends CirceMapping[F] with SqlModule[F] { self => parentContext.forField(fieldName, resultName). getOrElse(sys.error(s"No field '$fieldName' of type ${context.tpe}")) - List((SqlColumn.TableColumn(parentContext, join.parent), SqlColumn.TableColumn(childContext, join.child))) + List((SqlColumn.TableColumn(parentContext, join.parent, None), SqlColumn.TableColumn(childContext, join.child, None))) case Some(SqlObject(_, joins)) => val init = joins.init.map { join => val parentTable = TableRef(parentContext, join.parent.table) - val parentCol = SqlColumn.TableColumn(parentTable, join.parent) + val parentCol = SqlColumn.TableColumn(parentTable, join.parent, None) val childTable = TableRef(parentContext, join.child.table) - val childCol = SqlColumn.TableColumn(childTable, join.child) + val childCol = SqlColumn.TableColumn(childTable, join.child, None) (parentCol, childCol) } @@ -2623,9 +2697,9 @@ trait SqlMapping[F[_]] extends CirceMapping[F] with SqlModule[F] { self => val lastJoin = joins.last val parentTable = TableRef(parentContext, lastJoin.parent.table) - val parentCol = SqlColumn.TableColumn(parentTable, lastJoin.parent) + val parentCol = SqlColumn.TableColumn(parentTable, lastJoin.parent, None) val childTable = TableRef(childContext, lastJoin.child.table) - val childCol = SqlColumn.TableColumn(childTable, lastJoin.child) + val childCol = SqlColumn.TableColumn(childTable, lastJoin.child, None) (parentCol, childCol) } @@ -2865,26 +2939,26 @@ trait SqlMapping[F[_]] extends CirceMapping[F] with SqlModule[F] { self => def numRows: Int def numCols: Int - /** Yield the values of the given column */ - def select(col: SqlColumn): List[Any] + /** Yield the value of the given column */ + def select(col: Int): Option[Any] /** A copy of this `Table` containing only the rows for which all the given columns are defined */ - def filterDefined(cols: List[SqlColumn]): Table + + def filterDefined(cols: List[Int]): Table + /** True if all the given columns are defined, false otherwise */ - def definesAll(cols: List[SqlColumn]): Boolean + def definesAll(cols: List[Int]): Boolean + /** Group this `Table` by the values of the given columns */ - def group(cols: List[SqlColumn]): Iterator[Table] + def group(cols: List[Int]): Iterator[Table] def isEmpty: Boolean = false } object Table { - def apply(cols: List[SqlColumn], rows: Vector[Array[Any]]): Table = - apply(cols.zipWithIndex.toMap, rows) - - def apply(index: SqlColumn => Int, rows: Vector[Array[Any]]): Table = { - if (rows.sizeCompare(1) == 0) OneRowTable(index, rows.head) + def apply(rows: Vector[Array[Any]]): Table = { + if (rows.sizeCompare(1) == 0) OneRowTable(rows.head) else if (rows.isEmpty) EmptyTable - else MultiRowTable(index, rows) + else MultiRowTable(rows) } /** Specialized representation of an empty table */ @@ -2892,36 +2966,31 @@ trait SqlMapping[F[_]] extends CirceMapping[F] with SqlModule[F] { self => def numRows: Int = 0 def numCols: Int = 0 - def select(col: SqlColumn): List[Any] = Nil - def filterDefined(cols: List[SqlColumn]): Table = this - def definesAll(cols: List[SqlColumn]): Boolean = false - def group(cols: List[SqlColumn]): Iterator[Table] = Iterator.empty[Table] + def select(col: Int): Option[Any] = Some(FailedJoin) + def filterDefined(cols: List[Int]): Table = this + def definesAll(cols: List[Int]): Boolean = false + def group(cols: List[Int]): Iterator[Table] = Iterator.empty[Table] override def isEmpty = true } /** Specialized representation of a table with exactly one row */ - case class OneRowTable(index: SqlColumn => Int, row: Array[Any]) extends Table { + case class OneRowTable(row: Array[Any]) extends Table { def numRows: Int = 1 def numCols = row.size - def select(col: SqlColumn): List[Any] = { - val c = index(col) - row(c) match { - case FailedJoin => Nil - case other => other :: Nil - } - } + def select(col: Int): Option[Any] = + Some(row(col)) - def filterDefined(cols: List[SqlColumn]): Table = + def filterDefined(cols: List[Int]): Table = if(definesAll(cols)) this else EmptyTable - def definesAll(cols: List[SqlColumn]): Boolean = { - val cs = cols.map(index) + def definesAll(cols: List[Int]): Boolean = { + val cs = cols cs.forall(c => row(c) != FailedJoin) } - def group(cols: List[SqlColumn]): Iterator[Table] = { + def group(cols: List[Int]): Iterator[Table] = { cols match { case Nil => Iterator.single(this) case cols => @@ -2931,30 +3000,42 @@ trait SqlMapping[F[_]] extends CirceMapping[F] with SqlModule[F] { self => } } - case class MultiRowTable(index: SqlColumn => Int, rows: Vector[Array[Any]]) extends Table { + case class MultiRowTable(rows: Vector[Array[Any]]) extends Table { def numRows = rows.size def numCols = rows.headOption.map(_.size).getOrElse(0) - def select(col: SqlColumn): List[Any] = { - val c = index(col) - rows.iterator.map(r => r(c)).filterNot(_ == FailedJoin).distinct.toList + def select(col: Int): Option[Any] = { + val c = col + var value: Any = FailedJoin + val ir = rows.iterator + while(ir.hasNext) { + ir.next()(c) match { + case FailedJoin => + case v if value == FailedJoin => value = v + case v if value == v => + case None => + case v@Some(_) if value == None => value = v + case _ => return None + } + } + Some(value) } - def filterDefined(cols: List[SqlColumn]): Table = { - val cs = cols.map(index) - Table(index, rows.filter(r => cs.forall(c => r(c) != FailedJoin))) + def filterDefined(cols: List[Int]): Table = { + val cs = cols + Table(rows.filter(r => cs.forall(c => r(c) != FailedJoin))) } - def definesAll(cols: List[SqlColumn]): Boolean = { - val cs = cols.map(index) + def definesAll(cols: List[Int]): Boolean = { + val cs = cols rows.exists(r => cs.forall(c => r(c) != FailedJoin)) } - def group(cols: List[SqlColumn]): Iterator[Table] = { + def group(cols: List[Int]): Iterator[Table] = { cols match { - case Nil => rows.iterator.map(r => OneRowTable(index, r)) + case Nil => rows.iterator.map(r => OneRowTable(r)) case cols => - val cs = cols.map(index) + val cs = cols val discrim: Array[Any] => Any = cs match { @@ -2965,7 +3046,7 @@ trait SqlMapping[F[_]] extends CirceMapping[F] with SqlModule[F] { self => val nonNull = rows.filter(r => cs.forall(c => r(c) != FailedJoin)) val grouped = nonNull.groupBy(discrim) - grouped.iterator.map { case (_, rows) => Table(index, rows) } + grouped.iterator.map { case (_, rows) => Table(rows) } } } } @@ -2983,30 +3064,7 @@ trait SqlMapping[F[_]] extends CirceMapping[F] with SqlModule[F] { self => def isLeaf: Boolean = tpe.isLeaf def asLeaf: Result[Json] = - encoderForLeaf(tpe).map(enc => enc(focus).rightIor).getOrElse( - focus match { - case s: String => Json.fromString(s).rightIor - case i: Int => Json.fromInt(i).rightIor - case l: Long => Json.fromLong(l).rightIor - case f: Float => Json.fromFloat(f) match { - case Some(j) => j.rightIor - case None => mkErrorResult(s"Unrepresentable float %d") - } - case d: Double => Json.fromDouble(d) match { - case Some(j) => j.rightIor - case None => mkErrorResult(s"Unrepresentable double %d") - } - case d: BigDecimal => Json.fromBigDecimal(d).rightIor - case b: Boolean => Json.fromBoolean(b).rightIor - - // This means we are looking at a column with no value because it's the result of a failed - // outer join. This is an implementation error. - case FailedJoin => sys.error("Unhandled failed join.") - - case other => - mkErrorResult(s"Not a leaf: $other") - } - ) + mapped.encoderForLeaf(tpe).map(enc => enc(focus).rightIor).getOrElse(mkErrorResult(s"Not a leaf: $focus")) def isList: Boolean = tpe match { @@ -3014,8 +3072,8 @@ trait SqlMapping[F[_]] extends CirceMapping[F] with SqlModule[F] { self => case _ => false } - def asList: Result[List[Cursor]] = (tpe, focus) match { - case (ListType(tpe), it: List[_]) => it.map(f => mkChild(context.asType(tpe), focus = f)).rightIor + def asList[C](factory: Factory[Cursor, C]): Result[C] = (tpe, focus) match { + case (ListType(tpe), it: List[_]) => it.view.map(f => mkChild(context.asType(tpe), focus = f)).to(factory).rightIor case _ => mkErrorResult(s"Expected List type, found $tpe") } @@ -3043,6 +3101,8 @@ trait SqlMapping[F[_]] extends CirceMapping[F] with SqlModule[F] { self => /** Cursor positioned at a GraphQL result non-leaf */ case class SqlCursor(context: Context, focus: Any, mapped: MappedQuery, parent: Option[Cursor], env: Env) extends Cursor { + assert(focus != Table.EmptyTable || context.tpe.isNullable || context.tpe.isList) + def withEnv(env0: Env): Cursor = copy(env = env.add(env0)) def mkChild(context: Context = context, focus: Any = focus): SqlCursor = @@ -3060,11 +3120,11 @@ trait SqlMapping[F[_]] extends CirceMapping[F] with SqlModule[F] { self => def isList: Boolean = tpe.isList - def asList: Result[List[Cursor]] = + def asList[C](factory: Factory[Cursor, C]): Result[C] = tpe.item.map(_.dealias).map(itemTpe => asTable.map { table => val itemContext = context.asType(itemTpe) - mapped.group(itemContext, table).map(table => mkChild(itemContext, focus = table)).to(List) + mapped.group(itemContext, table).map(table => mkChild(itemContext, focus = table)).to(factory) } ).getOrElse(mkErrorResult(s"Not a list: $tpe"))