Skip to content

Commit

Permalink
Fix generateParser for case classes with more than 22 fields
Browse files Browse the repository at this point in the history
Instead of relying on the apply and unapply methods, list the case class
fields directly.
  • Loading branch information
kevinmehall committed Jun 6, 2023
1 parent 6d082a8 commit da482ed
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 104 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -105,21 +105,16 @@ class RowParserImpl(val c: Context) {

private def generate[A: c.WeakTypeTag](opts: AnnotOpts): Tree = {
val tpe = weakTypeTag[A].tpe
val theApply = findApply(tpe)
val fields = findCaseClassFields(tpe)

val params = theApply match {
case Some(symbol) => symbol.paramLists.head
case None => c.abort(c.enclosingPosition, "No apply function found")
}

val paramNames = params.map(_.name.toString).toSet
val paramNames = fields.map(_._1.toString).toSet
opts.remapping.foreach { case (givenCol, tree) =>
if (!paramNames.contains(givenCol)) {
c.abort(tree.pos, s"$givenCol is not a member of $tpe")
}
}

val input = generateCalls(params.map(CallData.fromSymbol(_, opts)))
val input = generateCalls(fields, opts)

val comp = q"${tpe.typeSymbol.companion}"
val typeName = tq"${weakTypeTag[A].tpe}"
Expand Down Expand Up @@ -149,30 +144,25 @@ class RowParserImpl(val c: Context) {
}
}

case class CallData(name: Literal, tpt: Type, args: List[Type], isOption: Boolean)
object CallData {
def fromSymbol(sym: Symbol, opts: AnnotOpts): CallData = {
private def generateCalls(fields: List[(TermName, Type)], opts: AnnotOpts): List[Tree] = {
fields.map { case (name, ty) =>
val value = if (opts.snakeCase) {
toSnakeCase(sym.name.toString)
} else if (opts.remapping.contains(sym.name.toString)) {
tupleValueString(opts.remapping(sym.name.toString))
toSnakeCase(name.toString)
} else if (opts.remapping.contains(name.toString)) {
tupleValueString(opts.remapping(name.toString))
} else {
sym.name.toString
name.toString
}

val TypeRef(_, outerType, args) = sym.info
val TypeRef(_, option, _) = typeOf[Option[Any]]
val nameLiteral = Literal(Constant(value))

CallData(Literal(Constant(value)), sym.info, args, outerType == option)
}
}
val TypeRef(_, outerType, args) = ty
val TypeRef(_, option, _) = typeOf[Option[Any]]

private def generateCalls(callData: List[CallData]): List[Tree] = {
callData.map { cd =>
if (cd.isOption) {
q"row.opt[${cd.args.head}](${cd.name})"
if (outerType == option) {
q"row.opt[${args.head}](${nameLiteral})"
} else {
q"row[${cd.tpt}](${cd.name})"
q"row[${ty}](${nameLiteral})"
}
}
}
Expand Down Expand Up @@ -205,85 +195,10 @@ class RowParserImpl(val c: Context) {
"([a-z\\d])([A-Z])", "$1_$2"
).toLowerCase

private def findApply(target: Type): Option[MethodSymbol] = {
val companion: Type = target.companion

val unapplyReturnTypes = getUnapplyReturnTypes(companion)
val applies = getApplies(companion)
findApplyUnapplyMatch(companion, applies, unapplyReturnTypes)
}

private def getReturnTypes(args: List[Type]): Option[List[Type]] = {
args.head match {
case t @ TypeRef(_, _, Nil) => Some(List(t))
case t @ TypeRef(_, _, args) =>
if (t <:< typeOf[Product]) Some(args)
else Some(List(t))
case _ => None
}
}

private def getUnapplyReturnTypes(companion: Type): Option[List[Type]] = {
val unapply = companion.decl(TermName("unapply"))
val unapplySeq = companion.decl(TermName("unapplySeq"))
val hasVarArgs = unapplySeq != NoSymbol

val effectiveUnapply = Seq(unapply, unapplySeq).find(_ != NoSymbol) match {
case None => c.abort(c.enclosingPosition, "No unapply or unapplySeq function found")
case Some(s) => s.asMethod
}

effectiveUnapply.returnType match {
case TypeRef(_, _, Nil) =>
c.abort(c.enclosingPosition, s"Unapply of $companion has no parameters. Are you using an empty case class?")
None

case TypeRef(_, _, args) =>
args.head match {
case t @ TypeRef(_, _, Nil) => Some(List(t))
case t @ TypeRef(_, _, args) =>
import c.universe.definitions.TupleClass
if (!TupleClass.seq.exists(tupleSym => t.baseType(tupleSym) ne NoType)) Some(List(t))
else if (t <:< typeOf[Product]) Some(args)
else None
case _ => None
}
case _ => None
}
}

private def getApplies(companion: Type): List[Symbol] = {
companion.decl(TermName("apply")) match {
case NoSymbol => c.abort(c.enclosingPosition, "No apply function found")
case s => s.asTerm.alternatives
}
}

private def findApplyUnapplyMatch(
companion: Type,
applies: List[Symbol],
unapplyReturnTypes: Option[List[Type]]
): Option[MethodSymbol] = {
val unapply = companion.decl(TermName("unapply"))
val unapplySeq = companion.decl(TermName("unapplySeq"))
val hasVarArgs = unapplySeq != NoSymbol

applies.collectFirst {
case (apply: MethodSymbol) if hasVarArgs && {
val someApplyTypes = apply.paramLists.headOption.map(_.map(_.asTerm.typeSignature))
val someInitApply = someApplyTypes.map(_.init)
val someApplyLast = someApplyTypes.map(_.last)
val someInitUnapply = unapplyReturnTypes.map(_.init)
val someUnapplyLast = unapplyReturnTypes.map(_.last)
val initsMatch = someInitApply == someInitUnapply
val lastMatch = (for {
lastApply <- someApplyLast
lastUnapply <- someUnapplyLast
} yield lastApply <:< lastUnapply).getOrElse(false)
initsMatch && lastMatch
} => apply
case (apply: MethodSymbol) if apply.paramLists.headOption.map(_.map(_.asTerm.typeSignature)) == unapplyReturnTypes => apply
}
private def findCaseClassFields(ty: Type): List[(TermName, Type)] = {
ty.members.sorted.collect {
case m: MethodSymbol if m.isCaseAccessor => (m.name, m.returnType)
}.toList
}

private def expand(colLit: Tree, tree: Tree): (String, Tree) = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,38 @@ object Thing {
}
case class User(firstName: String, lastName: String)

case class Big(
f1: Int,
f2: Option[Int],
f3: Int,
f4: Int,
f5: Int,
f6: Int,
f7: Int,
f8: Int,
f9: Int,
z10: Int,
z11: Int,
z12: Int,
z13: Int,
z14: Int,
z15: Int,
z16: Int,
z17: Int,
z18: Int,
z19: Int,
a20: Int,
a21: Int,
a22: Int,
a23: Int,
a24: Option[Int],
a25: Int
) {
val m1: Int = 0
def m2: Int = 0
}


class RowParserTest extends Specification with Mockito {
"RowParser def macros" should {
"generate parser" in {
Expand Down Expand Up @@ -76,6 +108,23 @@ class RowParserTest extends Specification with Mockito {
p.parse(row) mustEqual User("gregg", "hernandez")
}

"generate parser for a case class > 22 fields" in {
val rs = mock[java.sql.ResultSet]
for (i <- (1 to 9)) { rs.getInt(s"f${i}") returns i }
for (i <- (10 to 19)) { rs.getInt(s"z${i}") returns i }
for (i <- (20 to 25)) { rs.getInt(s"a${i}") returns i }

val row = SqlRow(rs)

val p = generateParser[Big]

p.parse(row) mustEqual(Big(
1, Some(2), 3, 4, 5, 6, 7, 8, 9, 10,
11, 12, 13, 14, 15, 16, 17, 18, 19,
20, 21, 22, 23, Some(24), 25)
)
}

"fail to compile with non-literals" in {
illTyped(
"""val name = "newName"; generateParser[User](Map("firstName" -> name))""",
Expand Down

0 comments on commit da482ed

Please sign in to comment.