Skip to content

Commit

Permalink
Separate traversal and accumulation.
Browse files Browse the repository at this point in the history
  • Loading branch information
tarao committed Nov 28, 2023
1 parent 948017f commit 8008f4e
Showing 1 changed file with 68 additions and 43 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -151,52 +151,47 @@ private[record4s] class InternalMacros(using
)
}

def schemaOfRecord[R: Type]: Schema = {
// Check if tpr represents Tag[T]: we need to check IsTag[Tag[T]] given instance
// because representation of opaque type varies among different package names such as
// Tag$package.Tag[T] or $proxyN.Tag[T].
def isTag(tpr: TypeRepr): Boolean =
tpr match {
case AppliedType(t, _) =>
tpr.asType match {
case '[tpe] => Expr.summon[Tag.IsTag[tpe]].nonEmpty
}
case _ =>
false
}
// Check if tpr represents Tag[T]: we need to check IsTag[Tag[T]] given instance
// because representation of opaque type varies among different package names such as
// Tag$package.Tag[T] or $proxyN.Tag[T].
private def isTag(tpr: TypeRepr): Boolean =
tpr match {
case AppliedType(_, _) =>
tpr.asType match {
case '[tpe] => Expr.summon[Tag.IsTag[tpe]].nonEmpty
}
case _ =>
false
}

def traverse[R: Type, Acc](acc: Acc, f: (Acc, Type[?]) => Acc): Acc = {
def safeDealias(tpr: TypeRepr): TypeRepr =
if (isTag(tpr)) tpr
else tpr.dealias

val nothing = TypeRepr.of[Nothing]

@tailrec def collectTupledFieldTypes(
@tailrec def traverseTuple(
tpe: Type[?],
acc: Seq[(String, Type[?])],
): Seq[(String, Type[?])] = tpe match {
acc: Acc,
): Acc = tpe match {
case '[(labelType, valueType) *: rest]
// Type variable or Nothing always matches with `Nothing *: Nothing`
if TypeRepr.of[labelType] != nothing
&& TypeRepr.of[valueType] != nothing
&& TypeRepr.of[rest] != nothing =>
TypeRepr.of[labelType] match {
case ConstantType(StringConstant(label)) =>
collectTupledFieldTypes(
Type.of[rest],
acc :+ (validatedLabel(label), Type.of[valueType]),
)
case _ =>
collectTupledFieldTypes(Type.of[rest], acc)
}
traverseTuple(
Type.of[rest],
f(acc, Type.of[(labelType, valueType)]),
)
case _ =>
acc
f(acc, tpe)
}

@tailrec def collectFieldTypesAndTags(
@tailrec def traverse1(
reversed: List[TypeRepr],
acc: Schema,
): Schema = reversed match {
acc: Acc,
): Acc = reversed match {
// base { label: valueType }
// For example
// TypeRepr.of[%{val name: String; val age: Int}]
Expand All @@ -210,45 +205,75 @@ private[record4s] class InternalMacros(using
// "age",
// TypeRepr.of[Int]
// )
case Refinement(base, label, valueType) :: rest =>
collectFieldTypesAndTags(
case (tpr @ Refinement(base, _, _)) :: rest =>
traverse1(
safeDealias(base) :: rest,
acc.copy(fieldTypes =
(validatedLabel(label), valueType.asType) +: acc.fieldTypes,
),
f(acc, tpr.asType),
)

// tpr1 & tpr2
case AndType(tpr1, tpr2) :: rest =>
collectFieldTypesAndTags(
traverse1(
safeDealias(tpr2) :: safeDealias(tpr1) :: rest,
acc,
)

// Tag[T]
case (head @ AppliedType(_, List(tpr))) :: rest if isTag(head) =>
collectFieldTypesAndTags(
case head :: rest if isTag(head) =>
traverse1(
rest,
acc.copy(tags = tpr.asType +: acc.tags),
f(acc, head.asType),
)

// typically `%` in `% { ... }` or
// (tp1, ...)
// tp1 *: ...
case head :: rest =>
collectFieldTypesAndTags(
traverse1(
rest,
acc.copy(fieldTypes =
collectTupledFieldTypes(head.asType, Seq.empty) ++ acc.fieldTypes,
),
traverseTuple(head.asType, acc),
)

// all done
case Nil =>
acc
}

collectFieldTypesAndTags(List(safeDealias(TypeRepr.of[R])), Schema.empty)
traverse1(List(safeDealias(TypeRepr.of[R])), acc)
}

def schemaOfRecord[R: Type]: Schema = {
traverse[R, Schema](
Schema.empty,
(acc: Schema, tpe: Type[?]) => {
tpe match {
case '[(labelType, valueType)] =>
TypeRepr.of[labelType] match {
case ConstantType(StringConstant(label)) =>
acc.copy(fieldTypes =
acc.fieldTypes :+ (validatedLabel(label), Type.of[valueType]),
)
case _ =>
acc
}

case '[tpe] =>
TypeRepr.of[tpe] match {
case Refinement(_, label, valueType) =>
acc.copy(fieldTypes =
(validatedLabel(label), valueType.asType) +: acc.fieldTypes,
)

// Tag[T]
case tpr @ AppliedType(_, List(tag)) if isTag(tpr) =>
acc.copy(tags = tag.asType +: acc.tags)

case _ =>
acc
}
}
},
)
}

def schemaOf[R: Type](
Expand Down

0 comments on commit 8008f4e

Please sign in to comment.