Skip to content

Commit

Permalink
Merge pull request #52 from tarao/fix-opaque-dealiasing
Browse files Browse the repository at this point in the history
Avoid opaque type dealiasing
  • Loading branch information
tarao authored Dec 16, 2023
2 parents dc12136 + 9f1b7d6 commit c5efae8
Show file tree
Hide file tree
Showing 6 changed files with 280 additions and 107 deletions.
2 changes: 1 addition & 1 deletion build.sbt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import ProjectKeys._
import Implicits._

ThisBuild / tlBaseVersion := "0.9"
ThisBuild / tlBaseVersion := "0.10"

ThisBuild / projectName := "record4s"
ThisBuild / groupId := "com.github.tarao"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -116,25 +116,10 @@ object ArrayRecordMacros {
def derivedRecordLikeImpl[R: Type](using
Quotes,
): Expr[RecordLike[ArrayRecord[R]]] = withInternal {
import quotes.reflect.*
import internal.*

val schema = schemaOfRecord[R]
val base = (Type.of[EmptyTuple]: Type[?], Type.of[EmptyTuple]: Type[?])
val (elemLabels, elemTypes) = schema
.fieldTypes
.foldRight(base) { case ((label, tpe), (baseLabels, baseTypes)) =>
val labels =
(baseLabels, ConstantType(StringConstant(label)).asType) match {
case ('[EmptyTuple], '[label]) => Type.of[label *: EmptyTuple]
case ('[head *: tail], '[label]) => Type.of[label *: head *: tail]
}
val types = (baseTypes, tpe) match {
case ('[EmptyTuple], '[tpe]) => Type.of[tpe *: EmptyTuple]
case ('[head *: tail], '[tpe]) => Type.of[tpe *: head *: tail]
}
(labels, types)
}
val (elemLabels, elemTypes) = schema.asUnzippedTupleType

(elemLabels, elemTypes, schema.tagsAsType, schema.asTupleType) match {
case ('[elemLabels], '[elemTypes], '[tagsType], '[tupleType]) =>
Expand Down Expand Up @@ -167,7 +152,7 @@ object ArrayRecordMacros {
val newSchema = (schema1 ++ schema2).deduped
if (schema1.size + schema2.size != newSchema.size)
deduped = true
(schema1 ++ schema2).deduped.asTupleType
newSchema.asTupleType
}

val needDedupType =
Expand Down Expand Up @@ -220,12 +205,11 @@ object ArrayRecordMacros {
}

val schema = schemaOfRecord[R]
val ((_, tpe), index) =
schema.fieldTypes.zipWithIndex.find(_._1._1 == label).getOrElse {
errorAndAbort(
s"Value '${label}' is not a member of ${Type.show[R]}",
)
}
val (tpe, index) = schema.findWithIndex(label).getOrElse {
errorAndAbort(
s"Value '${label}' is not a member of ${Type.show[R]}",
)
}
val indexType = ConstantType(IntConstant(index)).asType

(tpe, indexType) match {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,20 +55,43 @@ private[record4s] class InternalMacros(using
TypingResult.error(e.getMessage())
}

case class Schema(
fieldTypes: Seq[(String, Type[?])],
private def typeOf(tpr: TypeRepr): Type[?] =
tpr.asType match { case '[tpe] => Type.of[tpe] }

private def typeReprOf(tpe: Type[?]): TypeRepr =
tpe match { case '[tpe] => TypeRepr.of[tpe] }

case class Schema private[InternalMacros] (
private[InternalMacros] val fieldTypes: Seq[(String, TypeRepr)],
tags: Seq[Type[?]],
) {
def size: Int = fieldTypes.size

private[InternalMacros] def appended(label: String, tpe: TypeRepr): Schema =
copy(fieldTypes = fieldTypes :+ (label, tpe))

private[InternalMacros] def prepended(
label: String,
tpe: TypeRepr,
): Schema =
copy(fieldTypes = (label, tpe) +: fieldTypes)

def ++(other: Schema): Schema = copy(
fieldTypes = fieldTypes ++ other.fieldTypes,
tags = tags ++ other.tags,
)

def ++(other: Seq[(String, Type[?])]): Schema = copy(
fieldTypes = fieldTypes ++ other,
)
def find(label: String): Option[Type[?]] =
fieldTypes.find(_._1 == label).map(f => typeOf(f._2))

def findWithIndex(label: String): Option[(Type[?], Int)] =
fieldTypes.zipWithIndex.find(_._1._1 == label).map {
case ((_, tpr), index) =>
(typeOf(tpr), index)
}

def filterByLabel(pred: String => Boolean): Schema =
copy(fieldTypes = fieldTypes.filter(f => pred(f._1)))

def deduped: Schema =
copy(fieldTypes = fieldTypes.deduped.iterator.toSeq)
Expand All @@ -90,25 +113,50 @@ private[record4s] class InternalMacros(using
// & { val ${schema(1)._1}: ${schema(1)._2} })
// ...)
val record = fieldTypes
.foldLeft(baseRepr) { case (base, (label, '[tpe])) =>
Refinement(base, label, TypeRepr.of[tpe])
.foldLeft(baseRepr) { case (base, (label, tpr)) =>
Refinement(base, label, tpr)
}
tagsWith(record).asType
}

def asTupleType: Type[?] = {
val cons = TypeRepr.of[Nothing *: EmptyTuple] match {
case AppliedType(c, _) => c
}
def makeCons(car: TypeRepr, cdr: TypeRepr): TypeRepr =
AppliedType(cons, List(car, cdr))

val tuple2 = TypeRepr.of[(Nothing, Nothing)] match {
case AppliedType(c, _) => c
}
def makeTuple2(fst: TypeRepr, snd: TypeRepr): TypeRepr =
AppliedType(tuple2, List(fst, snd))

val record = fieldTypes.foldRight(TypeRepr.of[EmptyTuple]) {
case ((label, '[tpe]), rest) =>
(ConstantType(StringConstant(label)).asType, rest.asType) match {
case ('[labelType], '[head *: tail]) =>
TypeRepr.of[(labelType, tpe) *: head *: tail]
case ('[labelType], '[EmptyTuple]) =>
TypeRepr.of[(labelType, tpe) *: EmptyTuple]
}
case ((label, tpr), rest) =>
makeCons(makeTuple2(ConstantType(StringConstant(label)), tpr), rest)
}
tagsWith(record).asType
}

def asUnzippedTupleType: (Type[?], Type[?]) = {
val base = (Type.of[EmptyTuple]: Type[?], Type.of[EmptyTuple]: Type[?])
fieldTypes.foldRight(base) {
case ((label, tpr), (baseLabels, baseTypes)) =>
val tpe = typeOf(tpr)
val labels =
(baseLabels, ConstantType(StringConstant(label)).asType) match {
case ('[EmptyTuple], '[label]) => Type.of[label *: EmptyTuple]
case ('[head *: tail], '[label]) => Type.of[label *: head *: tail]
}
val types = (baseTypes, tpe) match {
case ('[EmptyTuple], '[tpe]) => Type.of[tpe *: EmptyTuple]
case ('[head *: tail], '[tpe]) => Type.of[tpe *: head *: tail]
}
(labels, types)
}
}

def tagsAsType: Type[?] = tagsWith(TypeRepr.of[Any]).asType

private def tagsWith(tpr: TypeRepr): TypeRepr =
Expand Down Expand Up @@ -164,25 +212,71 @@ private[record4s] class InternalMacros(using
false
}

private def isTuple(tpr: TypeRepr): Boolean =
tpr.asType match {
case '[_ *: _] => true
case _ => false
}

private def isOpaqueAlias(tpr: TypeRepr): Boolean =
tpr match {
case ref @ TypeRef(_, _) => ref.isOpaqueAlias
case _ => false
}

private def fixupOpaqueAlias(tpr: TypeRepr): TypeRepr = {
def rec(tpr: TypeRepr): TypeRepr =
tpr match {
case ref @ TypeRef(_, _) if ref.isOpaqueAlias =>
// Resolve `$proxy1.SomeOpaqueAlias` to fully qualified RefType
tpr.typeSymbol.typeRef
case SuperType(thisTpr, superTpr) =>
SuperType(rec(thisTpr), rec(superTpr))
case Refinement(parent, name, info) =>
Refinement(rec(parent), name, rec(info))
case AppliedType(tycon, args) =>
AppliedType(rec(tycon), args.map(rec(_)))
case AnnotatedType(underlying, annot) =>
AnnotatedType(rec(underlying), annot)
case AndType(lhs, rhs) =>
AndType(rec(lhs), rec(rhs))
case OrType(lhs, rhs) =>
OrType(rec(lhs), rec(rhs))
case MatchType(bound, scrutinee, cases) =>
MatchType(rec(bound), rec(scrutinee), cases.map(rec(_)))
case ByNameType(underlying) =>
ByNameType(rec(underlying))
case MatchCase(pattern, rhs) =>
MatchCase(rec(pattern), rec(rhs))
case TypeBounds(low, hi) =>
TypeBounds(rec(low), rec(hi))
case _ =>
tpr
}

rec(tpr)
}

private def isAlias(tpr: TypeRepr): Boolean =
tpr.typeSymbol.isAliasType && !isOpaqueAlias(tpr)

def traverse[R: Type, Acc](acc: Acc, f: (Acc, Type[?]) => Acc): Acc = {
def safeDealias(tpr: TypeRepr): TypeRepr =
if (isTag(tpr)) tpr
else tpr.dealias
if (isAlias(tpr)) tpr.dealias
else tpr

val nothing = TypeRepr.of[Nothing]

@tailrec def traverseTuple(
tpe: Type[?],
acc: Acc,
): Acc = tpe match {
case '[(labelType, valueType) *: rest]
case '[head *: rest]
// Type variable or Nothing always matches with `Nothing *: Nothing`
if TypeRepr.of[labelType] != nothing
&& TypeRepr.of[valueType] != nothing
&& TypeRepr.of[rest] != nothing =>
if TypeRepr.of[head] != nothing && TypeRepr.of[rest] != nothing =>
traverseTuple(
Type.of[rest],
f(acc, Type.of[(labelType, valueType)]),
f(acc, Type.of[head]),
)
case _ =>
f(acc, tpe)
Expand Down Expand Up @@ -239,38 +333,53 @@ private[record4s] class InternalMacros(using
acc
}

traverse1(List(safeDealias(TypeRepr.of[R])), acc)
traverse1(List(safeDealias(fixupOpaqueAlias(TypeRepr.of[R]))), acc)
}

def schemaOfRecord[R: Type]: Schema = {
def unapplyTuple2(tpr: TypeRepr): Option[(TypeRepr, TypeRepr)] =
// We can't do
//
// ```
// tpr.asType match {
// case '[fst *: snd] =>
// Some((TypeRepr.of[fst], TypeRepr.of[snd]))
// }
// ```
//
// because that will dealiases opaque type aliases
tpr match {
case AppliedType(c, fst :: snd :: _)
if c.typeSymbol.fullName == "scala.Tuple2" =>
Some((fst, snd))
case AppliedType(c, fst :: AppliedType(_, snd :: _) :: _)
if c.typeSymbol.fullName == "scala.*:" =>
Some((fst, snd))
case _ =>
None
}

traverse[R, Schema](
Schema.empty,
(acc: Schema, tpe: Type[?]) => {
tpe match {
case '[(labelType, valueType)] =>
TypeRepr.of[labelType] match {
case ConstantType(StringConstant(label)) =>
acc.copy(fieldTypes =
acc.fieldTypes :+ (validatedLabel(label), Type.of[valueType]),
)
typeReprOf(tpe) match {
case Refinement(_, label, valueType) =>
acc.prepended(validatedLabel(label), valueType)

case tpr if isTuple(tpr) =>
unapplyTuple2(tpr) match {
case Some((ConstantType(StringConstant(label)), valueType)) =>
acc.appended(validatedLabel(label), valueType)
case _ =>
acc
}

case '[tpe] =>
TypeRepr.of[tpe] match {
case Refinement(_, label, valueType) =>
acc.copy(fieldTypes =
(validatedLabel(label), valueType.asType) +: acc.fieldTypes,
)

// Tag[T]
case tpr @ AppliedType(_, List(tag)) if isTag(tpr) =>
acc.copy(tags = tag.asType +: acc.tags)
// Tag[T]
case tpr @ AppliedType(_, List(tag)) if isTag(tpr) =>
acc.copy(tags = tag.asType +: acc.tags)

case _ =>
acc
}
case _ =>
acc
}
},
)
Expand Down Expand Up @@ -395,9 +504,8 @@ private[record4s] class InternalMacros(using
}
}

def fieldSelectionsOf[S: Type](
schema: Schema,
): Seq[(String, String, Type[?])] = {
def selectedSchemaOf[R: Type, S: Type]: Schema = {
val schema = schemaOf[R]
val fieldTypeMap = schema.fieldTypes.toMap

def normalize(t: Type[?]): (TypeRepr, TypeRepr) = t match {
Expand All @@ -407,8 +515,8 @@ private[record4s] class InternalMacros(using

@tailrec def fieldTypes(
t: Type[?],
acc: Seq[(String, String, Type[?])],
): Seq[(String, String, Type[?])] =
acc: Seq[(String, TypeRepr)],
): Seq[(String, TypeRepr)] =
t match {
case '[head *: tail] =>
normalize(Type.of[head]) match {
Expand All @@ -420,7 +528,7 @@ private[record4s] class InternalMacros(using
label,
errorAndAbort(s"Missing key '${label}'"),
)
fieldTypes(Type.of[tail], acc :+ (label, renamed, fieldType))
fieldTypes(Type.of[tail], acc :+ (renamed, fieldType))
case _ =>
errorAndAbort(
"Selector type element must be a literal (possibly paired) label",
Expand All @@ -432,12 +540,12 @@ private[record4s] class InternalMacros(using
errorAndAbort("Selector type must be a Tuple")
}

fieldTypes(Type.of[S], Seq.empty)
schema.copy(fieldTypes = fieldTypes(Type.of[S], Seq.empty))
}

def fieldUnselectionsOf[U <: Tuple: Type](
schema: Schema,
): Seq[(String, Type[?])] = {
def unselectedSchemaOf[R: Type, U <: Tuple: Type]: Schema = {
val schema = schemaOf[R]

@tailrec def unselectedLabelsOf[U <: Tuple: Type](
acc: Set[String],
): Set[String] =
Expand All @@ -456,7 +564,7 @@ private[record4s] class InternalMacros(using
}
val unselected = unselectedLabelsOf[U](Set.empty)

schema.fieldTypes.filterNot((label, _) => unselected.contains(label))
schema.filterByLabel(label => !unselected.contains(label))
}
}

Expand Down
Loading

0 comments on commit c5efae8

Please sign in to comment.