Skip to content

Commit

Permalink
Merge pull request #80 from kevinmehall/macro
Browse files Browse the repository at this point in the history
Fix generateParser macro for case classes with more than 22 fields
  • Loading branch information
tmccombs authored Jun 8, 2023
2 parents 6d082a8 + da482ed commit 2964822
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 104 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -105,21 +105,16 @@ class RowParserImpl(val c: Context) {

private def generate[A: c.WeakTypeTag](opts: AnnotOpts): Tree = {
val tpe = weakTypeTag[A].tpe
val theApply = findApply(tpe)
val fields = findCaseClassFields(tpe)

val params = theApply match {
case Some(symbol) => symbol.paramLists.head
case None => c.abort(c.enclosingPosition, "No apply function found")
}

val paramNames = params.map(_.name.toString).toSet
val paramNames = fields.map(_._1.toString).toSet
opts.remapping.foreach { case (givenCol, tree) =>
if (!paramNames.contains(givenCol)) {
c.abort(tree.pos, s"$givenCol is not a member of $tpe")
}
}

val input = generateCalls(params.map(CallData.fromSymbol(_, opts)))
val input = generateCalls(fields, opts)

val comp = q"${tpe.typeSymbol.companion}"
val typeName = tq"${weakTypeTag[A].tpe}"
Expand Down Expand Up @@ -149,30 +144,25 @@ class RowParserImpl(val c: Context) {
}
}

case class CallData(name: Literal, tpt: Type, args: List[Type], isOption: Boolean)
object CallData {
def fromSymbol(sym: Symbol, opts: AnnotOpts): CallData = {
private def generateCalls(fields: List[(TermName, Type)], opts: AnnotOpts): List[Tree] = {
fields.map { case (name, ty) =>
val value = if (opts.snakeCase) {
toSnakeCase(sym.name.toString)
} else if (opts.remapping.contains(sym.name.toString)) {
tupleValueString(opts.remapping(sym.name.toString))
toSnakeCase(name.toString)
} else if (opts.remapping.contains(name.toString)) {
tupleValueString(opts.remapping(name.toString))
} else {
sym.name.toString
name.toString
}

val TypeRef(_, outerType, args) = sym.info
val TypeRef(_, option, _) = typeOf[Option[Any]]
val nameLiteral = Literal(Constant(value))

CallData(Literal(Constant(value)), sym.info, args, outerType == option)
}
}
val TypeRef(_, outerType, args) = ty
val TypeRef(_, option, _) = typeOf[Option[Any]]

private def generateCalls(callData: List[CallData]): List[Tree] = {
callData.map { cd =>
if (cd.isOption) {
q"row.opt[${cd.args.head}](${cd.name})"
if (outerType == option) {
q"row.opt[${args.head}](${nameLiteral})"
} else {
q"row[${cd.tpt}](${cd.name})"
q"row[${ty}](${nameLiteral})"
}
}
}
Expand Down Expand Up @@ -205,85 +195,10 @@ class RowParserImpl(val c: Context) {
"([a-z\\d])([A-Z])", "$1_$2"
).toLowerCase

private def findApply(target: Type): Option[MethodSymbol] = {
val companion: Type = target.companion

val unapplyReturnTypes = getUnapplyReturnTypes(companion)
val applies = getApplies(companion)
findApplyUnapplyMatch(companion, applies, unapplyReturnTypes)
}

private def getReturnTypes(args: List[Type]): Option[List[Type]] = {
args.head match {
case t @ TypeRef(_, _, Nil) => Some(List(t))
case t @ TypeRef(_, _, args) =>
if (t <:< typeOf[Product]) Some(args)
else Some(List(t))
case _ => None
}
}

private def getUnapplyReturnTypes(companion: Type): Option[List[Type]] = {
val unapply = companion.decl(TermName("unapply"))
val unapplySeq = companion.decl(TermName("unapplySeq"))
val hasVarArgs = unapplySeq != NoSymbol

val effectiveUnapply = Seq(unapply, unapplySeq).find(_ != NoSymbol) match {
case None => c.abort(c.enclosingPosition, "No unapply or unapplySeq function found")
case Some(s) => s.asMethod
}

effectiveUnapply.returnType match {
case TypeRef(_, _, Nil) =>
c.abort(c.enclosingPosition, s"Unapply of $companion has no parameters. Are you using an empty case class?")
None

case TypeRef(_, _, args) =>
args.head match {
case t @ TypeRef(_, _, Nil) => Some(List(t))
case t @ TypeRef(_, _, args) =>
import c.universe.definitions.TupleClass
if (!TupleClass.seq.exists(tupleSym => t.baseType(tupleSym) ne NoType)) Some(List(t))
else if (t <:< typeOf[Product]) Some(args)
else None
case _ => None
}
case _ => None
}
}

private def getApplies(companion: Type): List[Symbol] = {
companion.decl(TermName("apply")) match {
case NoSymbol => c.abort(c.enclosingPosition, "No apply function found")
case s => s.asTerm.alternatives
}
}

private def findApplyUnapplyMatch(
companion: Type,
applies: List[Symbol],
unapplyReturnTypes: Option[List[Type]]
): Option[MethodSymbol] = {
val unapply = companion.decl(TermName("unapply"))
val unapplySeq = companion.decl(TermName("unapplySeq"))
val hasVarArgs = unapplySeq != NoSymbol

applies.collectFirst {
case (apply: MethodSymbol) if hasVarArgs && {
val someApplyTypes = apply.paramLists.headOption.map(_.map(_.asTerm.typeSignature))
val someInitApply = someApplyTypes.map(_.init)
val someApplyLast = someApplyTypes.map(_.last)
val someInitUnapply = unapplyReturnTypes.map(_.init)
val someUnapplyLast = unapplyReturnTypes.map(_.last)
val initsMatch = someInitApply == someInitUnapply
val lastMatch = (for {
lastApply <- someApplyLast
lastUnapply <- someUnapplyLast
} yield lastApply <:< lastUnapply).getOrElse(false)
initsMatch && lastMatch
} => apply
case (apply: MethodSymbol) if apply.paramLists.headOption.map(_.map(_.asTerm.typeSignature)) == unapplyReturnTypes => apply
}
private def findCaseClassFields(ty: Type): List[(TermName, Type)] = {
ty.members.sorted.collect {
case m: MethodSymbol if m.isCaseAccessor => (m.name, m.returnType)
}.toList
}

private def expand(colLit: Tree, tree: Tree): (String, Tree) = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,38 @@ object Thing {
}
case class User(firstName: String, lastName: String)

case class Big(
f1: Int,
f2: Option[Int],
f3: Int,
f4: Int,
f5: Int,
f6: Int,
f7: Int,
f8: Int,
f9: Int,
z10: Int,
z11: Int,
z12: Int,
z13: Int,
z14: Int,
z15: Int,
z16: Int,
z17: Int,
z18: Int,
z19: Int,
a20: Int,
a21: Int,
a22: Int,
a23: Int,
a24: Option[Int],
a25: Int
) {
val m1: Int = 0
def m2: Int = 0
}


class RowParserTest extends Specification with Mockito {
"RowParser def macros" should {
"generate parser" in {
Expand Down Expand Up @@ -76,6 +108,23 @@ class RowParserTest extends Specification with Mockito {
p.parse(row) mustEqual User("gregg", "hernandez")
}

"generate parser for a case class > 22 fields" in {
val rs = mock[java.sql.ResultSet]
for (i <- (1 to 9)) { rs.getInt(s"f${i}") returns i }
for (i <- (10 to 19)) { rs.getInt(s"z${i}") returns i }
for (i <- (20 to 25)) { rs.getInt(s"a${i}") returns i }

val row = SqlRow(rs)

val p = generateParser[Big]

p.parse(row) mustEqual(Big(
1, Some(2), 3, 4, 5, 6, 7, 8, 9, 10,
11, 12, 13, 14, 15, 16, 17, 18, 19,
20, 21, 22, 23, Some(24), 25)
)
}

"fail to compile with non-literals" in {
illTyped(
"""val name = "newName"; generateParser[User](Map("firstName" -> name))""",
Expand Down

0 comments on commit 2964822

Please sign in to comment.