Skip to content

Commit

Permalink
Merge pull request #634 from typelevel/topic/poly-fields
Browse files Browse the repository at this point in the history
Fix for polymorphic fields in SqlMapping
  • Loading branch information
milessabin authored Sep 14, 2024
2 parents 413807c + 6f5a7dd commit 5df3486
Show file tree
Hide file tree
Showing 19 changed files with 781 additions and 214 deletions.
2 changes: 1 addition & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ ThisBuild / scalaVersion := Scala2
ThisBuild / crossScalaVersions := Seq(Scala2, Scala3)
ThisBuild / tlJdkRelease := Some(11)

ThisBuild / tlBaseVersion := "0.21"
ThisBuild / tlBaseVersion := "0.22"
ThisBuild / startYear := Some(2019)
ThisBuild / licenses := Seq(License.Apache2)
ThisBuild / developers := List(
Expand Down
16 changes: 9 additions & 7 deletions modules/circe/src/main/scala/circemapping.scala
Original file line number Diff line number Diff line change
Expand Up @@ -152,22 +152,24 @@ trait CirceMappingLike[F[_]] extends Mapping[F] {
case _ => Result.internalError(s"Expected Nullable type, found $focus for $tpe")
}

def narrowsTo(subtpe: TypeRef): Boolean =
subtpe <:< tpe &&
def narrowsTo(subtpe: TypeRef): Result[Boolean] =
(subtpe <:< tpe &&
((subtpe.dealias, focus.asObject) match {
case (nt: TypeWithFields, Some(obj)) =>
nt.fields.forall { f =>
f.tpe.isNullable || obj.contains(f.name)
} && obj.keys.forall(nt.hasField)

case _ => false
})
})).success

def narrow(subtpe: TypeRef): Result[Cursor] =
if (narrowsTo(subtpe))
mkChild(context.asType(subtpe)).success
else
Result.internalError(s"Focus ${focus} of static type $tpe cannot be narrowed to $subtpe")
narrowsTo(subtpe).flatMap { n =>
if (n)
mkChild(context.asType(subtpe)).success
else
Result.internalError(s"Focus ${focus} of static type $tpe cannot be narrowed to $subtpe")
}

def field(fieldName: String, resultName: Option[String]): Result[Cursor] = {
val localField =
Expand Down
6 changes: 3 additions & 3 deletions modules/core/src/main/scala/cursor.scala
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ trait Cursor {
def isDefined: Result[Boolean]

/** Is the value at this `Cursor` narrowable to `subtpe`? */
def narrowsTo(subtpe: TypeRef): Boolean
def narrowsTo(subtpe: TypeRef): Result[Boolean]

/**
* Yield a `Cursor` corresponding to the value at this `Cursor` narrowed to
Expand Down Expand Up @@ -251,7 +251,7 @@ object Cursor {
def isDefined: Result[Boolean] =
Result.internalError(s"Expected Nullable type, found $focus for $tpe")

def narrowsTo(subtpe: TypeRef): Boolean = false
def narrowsTo(subtpe: TypeRef): Result[Boolean] = false.success

def narrow(subtpe: TypeRef): Result[Cursor] =
Result.internalError(s"Focus ${focus} of static type $tpe cannot be narrowed to $subtpe")
Expand Down Expand Up @@ -290,7 +290,7 @@ object Cursor {

def isDefined: Result[Boolean] = underlying.isDefined

def narrowsTo(subtpe: TypeRef): Boolean = underlying.narrowsTo(subtpe)
def narrowsTo(subtpe: TypeRef): Result[Boolean] = underlying.narrowsTo(subtpe)

def narrow(subtpe: TypeRef): Result[Cursor] = underlying.narrow(subtpe)

Expand Down
54 changes: 33 additions & 21 deletions modules/core/src/main/scala/mapping.scala
Original file line number Diff line number Diff line change
Expand Up @@ -134,12 +134,13 @@ abstract class Mapping[F[_]] {
* the `fieldName` child of `parent`.
*/
protected final def mkCursorForField(parent: Cursor, fieldName: String, resultName: Option[String]): Result[Cursor] = {
val context = parent.context
val fieldContext = context.forFieldOrAttribute(fieldName, resultName)

typeMappings.fieldMapping(parent, fieldName).
toResultOrError(s"No mapping for field '$fieldName' for type ${parent.tpe}").
flatMap(mkCursorForMappedField(parent, fieldContext, _))
flatMap(_.toResultOrError(s"No mapping for field '$fieldName' for type ${parent.tpe}")).
flatMap {
case (np, fm) =>
val fieldContext = np.context.forFieldOrAttribute(fieldName, resultName)
mkCursorForMappedField(np, fieldContext, fm)
}
}

final class TypeMappings private (
Expand Down Expand Up @@ -200,18 +201,25 @@ abstract class Mapping[F[_]] {
* Yields the `FieldMapping` associated with `fieldName` in the runtime context
* determined by the given `Cursor`, if any.
*/
def fieldMapping(parent: Cursor, fieldName: String): Option[FieldMapping] = {
def fieldMapping(parent: Cursor, fieldName: String): Result[Option[(Cursor, FieldMapping)]] = {
val context = parent.context
fieldIndex(context).flatMap(_.get(fieldName)).flatMap {
case ifm: InheritedFieldMapping =>
ifm.select(parent.context)
case pfm: PolymorphicFieldMapping =>
fieldIndex(context).flatMap(_.get(fieldName)) match {
case Some(ifm: InheritedFieldMapping) =>
ifm.select(parent.context).map((parent, _)).success
case Some(pfm: PolymorphicFieldMapping) =>
pfm.select(parent)
case fm =>
Some(fm)
case Some(fm) =>
Option((parent, fm)).success
case None => None.success
}
}

def fieldIsPolymorphic(context: Context, fieldName: String): Boolean =
rawFieldMapping(context, fieldName).exists {
case _: PolymorphicFieldMapping => true
case _ => false
}

/** Yields the `FieldMapping` directly or ancestrally associated with `fieldName` in `context`, if any. */
def ancestralFieldMapping(context: Context, fieldName: String): Option[FieldMapping] =
fieldMapping(context, fieldName).orElse {
Expand Down Expand Up @@ -532,16 +540,20 @@ abstract class Mapping[F[_]] {
def hidden: Boolean = false
def subtree: Boolean = false

def select(cursor: Cursor): Option[FieldMapping] = {
val context = cursor.context
def select(cursor: Cursor): Result[Option[(Cursor, FieldMapping)]] = {
val applicable =
candidates.mapFilter {
case (pred, fm) if cursor.narrowsTo(schema.uncheckedRef(pred.tpe)) =>
pred(context.asType(pred.tpe)).map(prio => (prio, fm))
case _ =>
None
candidates.traverseFilter {
case (pred, fm) =>
cursor.narrowsTo(schema.uncheckedRef(pred.tpe)).flatMap { narrows =>
if (narrows)
for {
nc <- cursor.narrow(schema.uncheckedRef(pred.tpe))
} yield pred(nc.context).map(prio => (prio, (nc, fm)))
else
None.success
}
}
applicable.maxByOption(_._1).map(_._2)
applicable.map(_.maxByOption(_._1).map(_._2))
}

def select(context: Context): Option[FieldMapping] = {
Expand Down Expand Up @@ -1266,7 +1278,7 @@ abstract class Mapping[F[_]] {
case _ => Result.internalError(s"Not nullable at ${context.path}")
}

def narrowsTo(subtpe: TypeRef): Boolean = false
def narrowsTo(subtpe: TypeRef): Result[Boolean] = false.success
def narrow(subtpe: TypeRef): Result[Cursor] =
Result.failure(s"Cannot narrow $tpe to $subtpe")

Expand Down
29 changes: 0 additions & 29 deletions modules/core/src/main/scala/query.scala
Original file line number Diff line number Diff line change
Expand Up @@ -456,35 +456,6 @@ object Query {
}
}

/** Extractor for grouped Narrow patterns in the query algebra */
object TypeCase {
def unapply(q: Query): Option[(Query, List[Narrow])] = {
def branch(q: Query): Option[TypeRef] =
q match {
case Narrow(subtpe, _) => Some(subtpe)
case _ => None
}

val grouped = ungroup(q).groupBy(branch).toList
val (default0, narrows0) = grouped.partition(_._1.isEmpty)
if (narrows0.isEmpty) None
else {
val default = default0.flatMap(_._2) match {
case Nil => Empty
case children => Group(children)
}
val narrows = narrows0.collect {
case (Some(subtpe), narrows) =>
narrows.collect { case Narrow(_, child) => child } match {
case List(child) => Narrow(subtpe, child)
case children => Narrow(subtpe, Group(children))
}
}
Some((default, narrows))
}
}
}

/** Construct a query which yields all the supplied paths */
def mkPathQuery(paths: List[List[String]]): List[Query] =
paths match {
Expand Down
45 changes: 24 additions & 21 deletions modules/core/src/main/scala/queryinterpreter.scala
Original file line number Diff line number Diff line change
Expand Up @@ -201,22 +201,23 @@ class QueryInterpreter[F[_]](mapping: Mapping[F]) {
siblings.flatTraverse(query => runFields(query, tpe, cursor))

case Introspect(schema, s@Select("__typename", _, Empty)) if tpe.isNamed =>
(tpe.dealias match {
case o: ObjectType => Some(o.name)
val fail = Result.failure(s"'__typename' cannot be applied to non-selectable type '$tpe'")
def mkTypeNameFields(name: String) =
List((s.resultName, ProtoJson.fromJson(Json.fromString(name)))).success
def mkTypeNameFieldsOrFail(name: Option[String]) =
name.map(mkTypeNameFields).getOrElse(fail)

tpe.dealias match {
case o: ObjectType => mkTypeNameFields(o.name)
case i: InterfaceType =>
(schema.implementations(i).collectFirst {
case o if cursor.narrowsTo(schema.uncheckedRef(o)) => o.name
})
schema.implementations(i).collectFirstSomeM { o =>
cursor.narrowsTo(schema.uncheckedRef(o)).ifF(Some(o.name), None)
}.flatMap(mkTypeNameFieldsOrFail)
case u: UnionType =>
(u.members.map(_.dealias).collectFirst {
case nt: NamedType if cursor.narrowsTo(schema.uncheckedRef(nt)) => nt.name
})
case _ => None
}) match {
case Some(name) =>
List((s.resultName, ProtoJson.fromJson(Json.fromString(name)))).success
case None =>
Result.failure(s"'__typename' cannot be applied to non-selectable type '$tpe'")
u.members.map(_.dealias).collectFirstSomeM { nt =>
cursor.narrowsTo(schema.uncheckedRef(nt)).ifF(Some(nt.name), None)
}.flatMap(mkTypeNameFieldsOrFail)
case _ => fail
}

case sel: Select if tpe.isNullable =>
Expand Down Expand Up @@ -250,13 +251,15 @@ class QueryInterpreter[F[_]](mapping: Mapping[F]) {
value <- runValue(child, fieldTpe, c)
} yield List((sel.resultName, value))

case Narrow(tp1, child) if cursor.narrowsTo(tp1) =>
for {
c <- cursor.narrow(tp1)
fields <- runFields(child, tp1, c)
} yield fields

case _: Narrow => Nil.success
case Narrow(tp1, child) =>
cursor.narrowsTo(tp1).flatMap { n =>
if (!n) Nil.success
else
for {
c <- cursor.narrow(tp1)
fields <- runFields(child, tp1, c)
} yield fields
}

case c@Component(_, _, cont) =>
for {
Expand Down
16 changes: 9 additions & 7 deletions modules/core/src/main/scala/valuemapping.scala
Original file line number Diff line number Diff line change
Expand Up @@ -175,20 +175,22 @@ trait ValueMappingLike[F[_]] extends Mapping[F] {
case _ => Result.internalError(s"Expected Nullable type, found $focus for $tpe")
}

def narrowsTo(subtpe: TypeRef): Boolean =
subtpe <:< tpe &&
def narrowsTo(subtpe: TypeRef): Result[Boolean] =
(subtpe <:< tpe &&
objectMapping(context.asType(subtpe)).exists {
case ValueObjectMapping(_, _, classTag) =>
classTag.runtimeClass.isInstance(focus)
case _ => false
}
}).success


def narrow(subtpe: TypeRef): Result[Cursor] =
if (narrowsTo(subtpe))
mkChild(context.asType(subtpe)).success
else
Result.internalError(s"Focus ${focus} of static type $tpe cannot be narrowed to $subtpe")
narrowsTo(subtpe).flatMap { n =>
if(n)
mkChild(context.asType(subtpe)).success
else
Result.internalError(s"Focus ${focus} of static type $tpe cannot be narrowed to $subtpe")
}

def field(fieldName: String, resultName: Option[String]): Result[Cursor] =
mkCursorForField(this, fieldName, resultName)
Expand Down
10 changes: 6 additions & 4 deletions modules/generic/src/main/scala-2/genericmapping2.scala
Original file line number Diff line number Diff line change
Expand Up @@ -137,12 +137,14 @@ trait ScalaVersionSpecificGenericMappingLike[F[_]] extends Mapping[F] { self: Ge
override def field(fieldName: String, resultName: Option[String]): Result[Cursor] =
cursor.field(fieldName, resultName) orElse mkCursorForField(this, fieldName, resultName)

override def narrowsTo(subtpe: TypeRef): Boolean =
subtpe <:< tpe && rtpe <:< subtpe
override def narrowsTo(subtpe: TypeRef): Result[Boolean] =
(subtpe <:< tpe && rtpe <:< subtpe).success

override def narrow(subtpe: TypeRef): Result[Cursor] =
if (narrowsTo(subtpe)) copy(tpe0 = subtpe).success
else Result.internalError(s"Focus ${focus} of static type $tpe cannot be narrowed to $subtpe")
narrowsTo(subtpe).flatMap { n =>
if (n) copy(tpe0 = subtpe).success
else Result.internalError(s"Focus ${focus} of static type $tpe cannot be narrowed to $subtpe")
}
}
}
}
Expand Down
10 changes: 6 additions & 4 deletions modules/generic/src/main/scala-3/genericmapping3.scala
Original file line number Diff line number Diff line change
Expand Up @@ -116,12 +116,14 @@ trait ScalaVersionSpecificGenericMappingLike[F[_]] extends Mapping[F] { self: Ge
override def field(fieldName: String, resultName: Option[String]): Result[Cursor] =
cursor.field(fieldName, resultName) orElse mkCursorForField(this, fieldName, resultName)

override def narrowsTo(subtpe: TypeRef): Boolean =
subtpe <:< tpe && rtpe <:< subtpe
override def narrowsTo(subtpe: TypeRef): Result[Boolean] =
(subtpe <:< tpe && rtpe <:< subtpe).success

override def narrow(subtpe: TypeRef): Result[Cursor] =
if (narrowsTo(subtpe)) copy(tpe0 = subtpe).success
else Result.internalError(s"Focus ${focus} of static type $tpe cannot be narrowed to $subtpe")
narrowsTo(subtpe).flatMap { n =>
if (n) copy(tpe0 = subtpe).success
else Result.internalError(s"Focus ${focus} of static type $tpe cannot be narrowed to $subtpe")
}
}
}
}
Loading

0 comments on commit 5df3486

Please sign in to comment.