diff --git a/macros/src/main/scala/com/lucidchart/open/relate/macros/RowParserImpl.scala b/macros/src/main/scala/com/lucidchart/open/relate/macros/RowParserImpl.scala index eac3cfe..be66e3a 100644 --- a/macros/src/main/scala/com/lucidchart/open/relate/macros/RowParserImpl.scala +++ b/macros/src/main/scala/com/lucidchart/open/relate/macros/RowParserImpl.scala @@ -105,21 +105,16 @@ class RowParserImpl(val c: Context) { private def generate[A: c.WeakTypeTag](opts: AnnotOpts): Tree = { val tpe = weakTypeTag[A].tpe - val theApply = findApply(tpe) + val fields = findCaseClassFields(tpe) - val params = theApply match { - case Some(symbol) => symbol.paramLists.head - case None => c.abort(c.enclosingPosition, "No apply function found") - } - - val paramNames = params.map(_.name.toString).toSet + val paramNames = fields.map(_._1.toString).toSet opts.remapping.foreach { case (givenCol, tree) => if (!paramNames.contains(givenCol)) { c.abort(tree.pos, s"$givenCol is not a member of $tpe") } } - val input = generateCalls(params.map(CallData.fromSymbol(_, opts))) + val input = generateCalls(fields, opts) val comp = q"${tpe.typeSymbol.companion}" val typeName = tq"${weakTypeTag[A].tpe}" @@ -149,30 +144,25 @@ class RowParserImpl(val c: Context) { } } - case class CallData(name: Literal, tpt: Type, args: List[Type], isOption: Boolean) - object CallData { - def fromSymbol(sym: Symbol, opts: AnnotOpts): CallData = { + private def generateCalls(fields: List[(TermName, Type)], opts: AnnotOpts): List[Tree] = { + fields.map { case (name, ty) => val value = if (opts.snakeCase) { - toSnakeCase(sym.name.toString) - } else if (opts.remapping.contains(sym.name.toString)) { - tupleValueString(opts.remapping(sym.name.toString)) + toSnakeCase(name.toString) + } else if (opts.remapping.contains(name.toString)) { + tupleValueString(opts.remapping(name.toString)) } else { - sym.name.toString + name.toString } - val TypeRef(_, outerType, args) = sym.info - val TypeRef(_, option, _) = typeOf[Option[Any]] + val nameLiteral = Literal(Constant(value)) - CallData(Literal(Constant(value)), sym.info, args, outerType == option) - } - } + val TypeRef(_, outerType, args) = ty + val TypeRef(_, option, _) = typeOf[Option[Any]] - private def generateCalls(callData: List[CallData]): List[Tree] = { - callData.map { cd => - if (cd.isOption) { - q"row.opt[${cd.args.head}](${cd.name})" + if (outerType == option) { + q"row.opt[${args.head}](${nameLiteral})" } else { - q"row[${cd.tpt}](${cd.name})" + q"row[${ty}](${nameLiteral})" } } } @@ -205,85 +195,10 @@ class RowParserImpl(val c: Context) { "([a-z\\d])([A-Z])", "$1_$2" ).toLowerCase - private def findApply(target: Type): Option[MethodSymbol] = { - val companion: Type = target.companion - - val unapplyReturnTypes = getUnapplyReturnTypes(companion) - val applies = getApplies(companion) - findApplyUnapplyMatch(companion, applies, unapplyReturnTypes) - } - - private def getReturnTypes(args: List[Type]): Option[List[Type]] = { - args.head match { - case t @ TypeRef(_, _, Nil) => Some(List(t)) - case t @ TypeRef(_, _, args) => - if (t <:< typeOf[Product]) Some(args) - else Some(List(t)) - case _ => None - } - } - - private def getUnapplyReturnTypes(companion: Type): Option[List[Type]] = { - val unapply = companion.decl(TermName("unapply")) - val unapplySeq = companion.decl(TermName("unapplySeq")) - val hasVarArgs = unapplySeq != NoSymbol - - val effectiveUnapply = Seq(unapply, unapplySeq).find(_ != NoSymbol) match { - case None => c.abort(c.enclosingPosition, "No unapply or unapplySeq function found") - case Some(s) => s.asMethod - } - - effectiveUnapply.returnType match { - case TypeRef(_, _, Nil) => - c.abort(c.enclosingPosition, s"Unapply of $companion has no parameters. Are you using an empty case class?") - None - - case TypeRef(_, _, args) => - args.head match { - case t @ TypeRef(_, _, Nil) => Some(List(t)) - case t @ TypeRef(_, _, args) => - import c.universe.definitions.TupleClass - if (!TupleClass.seq.exists(tupleSym => t.baseType(tupleSym) ne NoType)) Some(List(t)) - else if (t <:< typeOf[Product]) Some(args) - else None - case _ => None - } - case _ => None - } - } - - private def getApplies(companion: Type): List[Symbol] = { - companion.decl(TermName("apply")) match { - case NoSymbol => c.abort(c.enclosingPosition, "No apply function found") - case s => s.asTerm.alternatives - } - } - - private def findApplyUnapplyMatch( - companion: Type, - applies: List[Symbol], - unapplyReturnTypes: Option[List[Type]] - ): Option[MethodSymbol] = { - val unapply = companion.decl(TermName("unapply")) - val unapplySeq = companion.decl(TermName("unapplySeq")) - val hasVarArgs = unapplySeq != NoSymbol - - applies.collectFirst { - case (apply: MethodSymbol) if hasVarArgs && { - val someApplyTypes = apply.paramLists.headOption.map(_.map(_.asTerm.typeSignature)) - val someInitApply = someApplyTypes.map(_.init) - val someApplyLast = someApplyTypes.map(_.last) - val someInitUnapply = unapplyReturnTypes.map(_.init) - val someUnapplyLast = unapplyReturnTypes.map(_.last) - val initsMatch = someInitApply == someInitUnapply - val lastMatch = (for { - lastApply <- someApplyLast - lastUnapply <- someUnapplyLast - } yield lastApply <:< lastUnapply).getOrElse(false) - initsMatch && lastMatch - } => apply - case (apply: MethodSymbol) if apply.paramLists.headOption.map(_.map(_.asTerm.typeSignature)) == unapplyReturnTypes => apply - } + private def findCaseClassFields(ty: Type): List[(TermName, Type)] = { + ty.members.sorted.collect { + case m: MethodSymbol if m.isCaseAccessor => (m.name, m.returnType) + }.toList } private def expand(colLit: Tree, tree: Tree): (String, Tree) = { diff --git a/macros/src/test/scala/com/lucidchart/open/relate/macros/RowParserTest.scala b/macros/src/test/scala/com/lucidchart/open/relate/macros/RowParserTest.scala index 5077e94..59c4dda 100644 --- a/macros/src/test/scala/com/lucidchart/open/relate/macros/RowParserTest.scala +++ b/macros/src/test/scala/com/lucidchart/open/relate/macros/RowParserTest.scala @@ -11,6 +11,38 @@ object Thing { } case class User(firstName: String, lastName: String) +case class Big( + f1: Int, + f2: Option[Int], + f3: Int, + f4: Int, + f5: Int, + f6: Int, + f7: Int, + f8: Int, + f9: Int, + z10: Int, + z11: Int, + z12: Int, + z13: Int, + z14: Int, + z15: Int, + z16: Int, + z17: Int, + z18: Int, + z19: Int, + a20: Int, + a21: Int, + a22: Int, + a23: Int, + a24: Option[Int], + a25: Int +) { + val m1: Int = 0 + def m2: Int = 0 +} + + class RowParserTest extends Specification with Mockito { "RowParser def macros" should { "generate parser" in { @@ -76,6 +108,23 @@ class RowParserTest extends Specification with Mockito { p.parse(row) mustEqual User("gregg", "hernandez") } + "generate parser for a case class > 22 fields" in { + val rs = mock[java.sql.ResultSet] + for (i <- (1 to 9)) { rs.getInt(s"f${i}") returns i } + for (i <- (10 to 19)) { rs.getInt(s"z${i}") returns i } + for (i <- (20 to 25)) { rs.getInt(s"a${i}") returns i } + + val row = SqlRow(rs) + + val p = generateParser[Big] + + p.parse(row) mustEqual(Big( + 1, Some(2), 3, 4, 5, 6, 7, 8, 9, 10, + 11, 12, 13, 14, 15, 16, 17, 18, 19, + 20, 21, 22, 23, Some(24), 25) + ) + } + "fail to compile with non-literals" in { illTyped( """val name = "newName"; generateParser[User](Map("firstName" -> name))""",