diff --git a/macros/src/test/scala/com/lucidchart/open/relate/macros/RowParserTest.scala b/macros/src/test/scala/com/lucidchart/open/relate/macros/RowParserTest.scala index 59c4dda..dc7f8a0 100644 --- a/macros/src/test/scala/com/lucidchart/open/relate/macros/RowParserTest.scala +++ b/macros/src/test/scala/com/lucidchart/open/relate/macros/RowParserTest.scala @@ -44,12 +44,16 @@ case class Big( class RowParserTest extends Specification with Mockito { + class MockableRow extends SqlRow(null) { + final override def apply[A: ColReader](col: String): A = super.apply(col) + final override def opt[A: ColReader](col: String): Option[A] = super.opt(col) + } + "RowParser def macros" should { "generate parser" in { - val rs = mock[java.sql.ResultSet] - rs.getString("firstName") returns "hi" - rs.getInt("b") returns 20 - val row = SqlRow(rs) + val row = mock[MockableRow] + row.stringOption("firstName") returns Some("hi") + row.intOption("b") returns Some(20) val p = generateParser[Thing] @@ -57,10 +61,9 @@ class RowParserTest extends Specification with Mockito { } "generate parser w/snake_case columns" in { - val rs = mock[java.sql.ResultSet] - rs.getString("first_name") returns "gregg" - rs.getInt("b") returns 20 - val row = SqlRow(rs) + val row = mock[MockableRow] + row.stringOption("first_name") returns Some("gregg") + row.intOption("b") returns Some(20) val p = generateSnakeParser[Thing] @@ -73,10 +76,9 @@ class RowParserTest extends Specification with Mockito { "lastName" -> "lname" )) - val rs = mock[java.sql.ResultSet] - rs.getString("fname") returns "gregg" - rs.getString("lname") returns "hernandez" - val row = SqlRow(rs) + val row = mock[MockableRow] + row.stringOption("fname") returns Some("gregg") + row.stringOption("lname") returns Some("hernandez") p.parse(row) mustEqual User("gregg", "hernandez") } @@ -87,10 +89,9 @@ class RowParserTest extends Specification with Mockito { ("lastName", "lname") )) - val rs = mock[java.sql.ResultSet] - rs.getString("fname") returns "gregg" - rs.getString("lname") returns "hernandez" - val row = SqlRow(rs) + val row = mock[MockableRow] + row.stringOption("fname") returns Some("gregg") + row.stringOption("lname") returns Some("hernandez") p.parse(row) mustEqual User("gregg", "hernandez") } @@ -100,21 +101,19 @@ class RowParserTest extends Specification with Mockito { "firstName" -> "fname" )) - val rs = mock[java.sql.ResultSet] - rs.getString("fname") returns "gregg" - rs.getString("lastName") returns "hernandez" - val row = SqlRow(rs) + val row = mock[MockableRow] + row.stringOption("fname") returns Some("gregg") + row.stringOption("lastName") returns Some("hernandez") p.parse(row) mustEqual User("gregg", "hernandez") } "generate parser for a case class > 22 fields" in { - val rs = mock[java.sql.ResultSet] - for (i <- (1 to 9)) { rs.getInt(s"f${i}") returns i } - for (i <- (10 to 19)) { rs.getInt(s"z${i}") returns i } - for (i <- (20 to 25)) { rs.getInt(s"a${i}") returns i } + val row = mock[MockableRow] + for (i <- (1 to 9)) { row.intOption(s"f${i}") returns Some(i) } + for (i <- (10 to 19)) { row.intOption(s"z${i}") returns Some(i) } + for (i <- (20 to 25)) { row.intOption(s"a${i}") returns Some(i) } - val row = SqlRow(rs) val p = generateParser[Big] @@ -154,10 +153,9 @@ class RowParserTest extends Specification with Mockito { case class SimpleRecord(firstName: String, lastName: Option[String]) "generate parser" in { - val rs = mock[java.sql.ResultSet] - rs.getString("firstName") returns "gregg" - rs.getString("lastName") returns "hernandez" - val row = SqlRow(rs) + val row = mock[MockableRow] + row.stringOption("firstName") returns Some("gregg") + row.stringOption("lastName") returns Some("hernandez") implicitly[RowParser[SimpleRecord]].parse(row) mustEqual SimpleRecord("gregg", Some("hernandez")) } @@ -170,10 +168,9 @@ class RowParserTest extends Specification with Mockito { } "generate parser w/snake_case columns" in { - val rs = mock[java.sql.ResultSet] - rs.getString("first_name") returns "gregg" - rs.getString("last_name") returns "hernandez" - val row = SqlRow(rs) + val row = mock[MockableRow] + row.stringOption("first_name") returns Some("gregg") + row.stringOption("last_name") returns Some("hernandez") // verify that this still compiles SnakeRecord.f() @@ -185,10 +182,9 @@ class RowParserTest extends Specification with Mockito { case class RemapRecord(firstName: String, lastName: String) "remap column names" in { - val rs = mock[java.sql.ResultSet] - rs.getString("fname") returns "gregg" - rs.getString("lname") returns "hernandez" - val row = SqlRow(rs) + val row = mock[MockableRow] + row.stringOption("fname") returns Some("gregg") + row.stringOption("lname") returns Some("hernandez") implicitly[RowParser[RemapRecord]].parse(row) mustEqual RemapRecord("gregg", "hernandez") } @@ -197,10 +193,9 @@ class RowParserTest extends Specification with Mockito { case class RemapTRecord(firstName: String, lastName: String) "remap column names w/normal tuple syntax" in { - val rs = mock[java.sql.ResultSet] - rs.getString("fname") returns "gregg" - rs.getString("lname") returns "hernandez" - val row = SqlRow(rs) + val row = mock[MockableRow] + row.stringOption("fname") returns Some("gregg") + row.stringOption("lname") returns Some("hernandez") implicitly[RowParser[RemapTRecord]].parse(row) mustEqual RemapTRecord("gregg", "hernandez") } @@ -209,10 +204,9 @@ class RowParserTest extends Specification with Mockito { case class RemapSomeRecord(firstName: String, lastName: String) "remap some column names" in { - val rs = mock[java.sql.ResultSet] - rs.getString("fname") returns "gregg" - rs.getString("lastName") returns "hernandez" - val row = SqlRow(rs) + val row = mock[MockableRow] + row.stringOption("fname") returns Some("gregg") + row.stringOption("lastName") returns Some("hernandez") implicitly[RowParser[RemapSomeRecord]].parse(row) mustEqual RemapSomeRecord("gregg", "hernandez") } diff --git a/relate/src/main/scala/com/lucidchart/relate/ColReader.scala b/relate/src/main/scala/com/lucidchart/relate/ColReader.scala index 4276c4c..5745c4f 100644 --- a/relate/src/main/scala/com/lucidchart/relate/ColReader.scala +++ b/relate/src/main/scala/com/lucidchart/relate/ColReader.scala @@ -1,7 +1,5 @@ package com.lucidchart.relate -import java.nio.ByteBuffer -import java.sql.ResultSet import java.time.Instant import java.util.{Date, UUID} @@ -36,39 +34,21 @@ object ColReader { def read(col: String, rs: SqlRow): Option[A] = f(col, rs) } - def option[A](x: A, rs: ResultSet): Option[A] = { - if (rs.wasNull()) { - None - } else { - Some(x) - } - } - - private def optReader[A](f: (String, ResultSet) => A): ColReader[A] = ColReader[A] { (col, row) => - option(f(col, row.resultSet), row.resultSet) - } - - implicit val jbigDecimalReader: ColReader[java.math.BigDecimal] = ColReader[java.math.BigDecimal] { (col, row) => - option(row.resultSet.getBigDecimal(col), row.resultSet) - } - - implicit val bigDecimalReader: ColReader[BigDecimal] = jbigDecimalReader.map(BigDecimal(_)) - - implicit val bigIntReader: ColReader[BigInt] = jbigDecimalReader.map(bd => BigInt(bd.longValue)) - - implicit val boolReader: ColReader[Boolean] = optReader((col, rs) => rs.getBoolean(col)) - implicit val byteArrayReader: ColReader[Array[Byte]] = optReader((col, rs) => rs.getBytes(col)) - implicit val byteReader: ColReader[Byte] = optReader((col, rs) => rs.getByte(col)) - implicit val dateReader: ColReader[Date] = optReader((col, rs) => rs.getDate(col)) - implicit val instantReader: ColReader[Instant] = optReader((col, rs) => rs.getTimestamp(col)).map(_.toInstant) - implicit val doubleReader: ColReader[Double] = optReader((col, rs) => rs.getDouble(col)) - implicit val intReader: ColReader[Int] = optReader((col, rs) => rs.getInt(col)) - implicit val longReader: ColReader[Long] = optReader((col, rs) => rs.getLong(col)) - implicit val shortReader: ColReader[Short] = optReader((col, rs) => rs.getShort(col)) - implicit val stringReader: ColReader[String] = optReader((col, rs) => rs.getString(col)) - implicit val uuidReader: ColReader[UUID] = ColReader[UUID] { (col, row) => - row.uuidOption(col) - } + implicit val jbigDecimalReader: ColReader[java.math.BigDecimal] = ColReader { (col, row) => row.javaBigDecimalOption(col)} + implicit val bigDecimalReader: ColReader[BigDecimal] = ColReader { (col, row) => row.bigDecimalOption(col)} + implicit val jBigIntReader: ColReader[java.math.BigInteger] = ColReader { (col, row) => row.javaBigIntegerOption(col)} + implicit val bigIntReader: ColReader[BigInt] = ColReader { (col, row) => row.bigIntOption(col)} + implicit val boolReader: ColReader[Boolean] = ColReader { (col, row) => row.boolOption(col)} + implicit val byteArrayReader: ColReader[Array[Byte]] = ColReader { (col, row) => row.byteArrayOption(col)} + implicit val byteReader: ColReader[Byte] = ColReader { (col, row) => row.byteOption(col)} + implicit val dateReader: ColReader[Date] = ColReader { (col, row) => row.dateOption(col)} + implicit val instantReader: ColReader[Instant] = ColReader { (col, row) => row.instantOption(col)} + implicit val doubleReader: ColReader[Double] = ColReader { (col, row) => row.doubleOption(col)} + implicit val intReader: ColReader[Int] = ColReader { (col, row) => row.intOption(col)} + implicit val longReader: ColReader[Long] = ColReader { (col, row) => row.longOption(col)} + implicit val shortReader: ColReader[Short] = ColReader { (col, row) => row.shortOption(col)} + implicit val stringReader: ColReader[String] = ColReader { (col, row) => row.stringOption(col)} + implicit val uuidReader: ColReader[UUID] = ColReader[UUID] { (col, row) => row.uuidOption(col)} def enumReader[A <: Enumeration](e: A): ColReader[e.Value] = { intReader.flatMap(id => ColReader[e.Value] { (_, _) => diff --git a/relate/src/test/scala/ColReaderTest.scala b/relate/src/test/scala/ColReaderTest.scala index 4a39dab..e08e3d4 100644 --- a/relate/src/test/scala/ColReaderTest.scala +++ b/relate/src/test/scala/ColReaderTest.scala @@ -1,12 +1,16 @@ package com.lucidchart.relate -import java.util.{Date, UUID} -import java.time.Instant import org.specs2.mock.Mockito import org.specs2.mutable._ +import java.time.Instant +import java.util.{Date, UUID} + case class RecordA( bd: BigDecimal, + bigInt: BigInt, + jBigDecimal: java.math.BigDecimal, + jBigInt: java.math.BigInteger, bool: Boolean, ba: Array[Byte], byte: Byte, @@ -21,11 +25,22 @@ case class RecordA( thing: Things.Value ) +/** + * Ensures that the real apply/opt methods are called (whereas `mock[SqlRow]` would 'null' out that method + */ +class MockableRow extends SqlRow(null) { + final override def apply[A: ColReader](col: String): A = super.apply(col) + final override def opt[A: ColReader](col: String): Option[A] = super.opt(col) +} + object RecordA extends Mockito { implicit val reader = new RowParser[RecordA] { def parse(row: SqlRow): RecordA = { RecordA( row[BigDecimal]("bd"), + row[BigInt]("bi"), + row[java.math.BigDecimal]("jbd"), + row[java.math.BigInteger]("jbi"), row[Boolean]("bool"), row[Array[Byte]]("ba"), row[Byte]("byte"), @@ -42,30 +57,37 @@ object RecordA extends Mockito { } } - val timeMillis: Long = 1576179411000l + val timeMillis: Long = 1576179411000L + val uuid: UUID = UUID.fromString("01020304-0506-0708-090a-0b0c0d0e0f10") val mockRow = { - val rs = mock[java.sql.ResultSet] - rs.getBigDecimal("bd") returns new java.math.BigDecimal(10) - rs.getBoolean("bool") returns true - rs.getBytes("ba") returns Array[Byte](1,2,3) - rs.getByte("byte") returns (1: Byte) - rs.getDate("date") returns (new java.sql.Date(10000)) - rs.getTimestamp("instant") returns (new java.sql.Timestamp(timeMillis)) - rs.getDouble("double") returns 1.1 - rs.getInt("int") returns 10 - rs.getLong("long") returns 100L - rs.getShort("short") returns (5: Short) - rs.getString("str") returns "hello" - rs.getObject("uuid") returns Array[Byte](1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16) - rs.getInt("thing") returns 1 - SqlRow(rs) + val row = mock[MockableRow] + row.bigDecimalOption("bd") returns Some(BigDecimal(10)) + row.bigIntOption("bi") returns Some(BigInt(10)) + row.javaBigDecimalOption("jbd") returns Some(new java.math.BigDecimal(10)) + row.javaBigIntegerOption("jbi") returns Some(java.math.BigInteger.valueOf(10)) + row.boolOption("bool") returns Some(true) + row.byteArrayOption("ba") returns Some(Array[Byte](1,2,3)) + row.byteOption("byte") returns Some((1: Byte)) + row.dateOption("date") returns Some((new Date(timeMillis))) + row.instantOption("instant") returns Some(Instant.ofEpochMilli(timeMillis)) + row.doubleOption("double") returns Some(1.1) + row.intOption("int") returns Some(10) + row.longOption("long") returns Some(100L) + row.shortOption("short") returns Some((5: Short)) + row.stringOption("str") returns Some("hello") + row.uuidOption("uuid") returns Some(uuid) + row.intOption("thing") returns Some(1) + row } } case class RecordB( bd: Option[BigDecimal], + bigInt: Option[BigInt], + jBigDecimal: Option[java.math.BigDecimal], + jBigInt: Option[java.math.BigInteger], bool: Option[Boolean], ba: Option[Array[Byte]], byte: Option[Byte], @@ -85,6 +107,9 @@ object RecordB extends Mockito { def parse(row: SqlRow): RecordB = { RecordB( row.opt[BigDecimal]("bd"), + row.opt[BigInt]("bi"), + row.opt[java.math.BigDecimal]("jbd"), + row.opt[java.math.BigInteger]("jbi"), row.opt[Boolean]("bool"), row.opt[Array[Byte]]("ba"), row.opt[Byte]("byte"), @@ -102,14 +127,24 @@ object RecordB extends Mockito { } val mockRow = { - val rs = mock[java.sql.ResultSet] - rs.wasNull() returns true - rs.getBigDecimal("bd") returns null - rs.getBytes("ba") returns null - rs.getDate("date") returns null - rs.getString("str") returns null - rs.getBytes("uuid") returns null - SqlRow(rs) + val row = mock[MockableRow] + row.bigDecimalOption("bd") returns None + row.bigIntOption("bi") returns None + row.javaBigDecimalOption("jbd") returns None + row.javaBigIntegerOption("jbi") returns None + row.boolOption("bool") returns None + row.byteArrayOption("ba") returns None + row.byteOption("byte") returns None + row.dateOption("date") returns None + row.instantOption("instant") returns None + row.doubleOption("double") returns None + row.intOption("int") returns None + row.longOption("long") returns None + row.shortOption("short") returns None + row.stringOption("str") returns None + row.uuidOption("uuid") returns None + row.intOption("thing") returns None + row } } @@ -123,30 +158,34 @@ object Things extends Enumeration { class ColReaderTest extends Specification with Mockito { val mockedInstant = Instant.EPOCH.plusMillis(RecordA.timeMillis) + "ColReader" should { "parse a present values" in { val row = RecordA.mockRow val parsed = RecordA.reader.parse(row) // Arrays use reference equality so we have to check this - // independantly of all the other values + // independently of all the other values val bytes = parsed.ba bytes === Array[Byte](1,2,3) parsed.copy(ba = null) mustEqual RecordA( - BigDecimal(10), - true, - null, - 1, - new Date(10000), - mockedInstant, - 1.1, - 10, - 100, - 5, - "hello", - UUID.fromString("01020304-0506-0708-090a-0b0c0d0e0f10"), - Things.One + bd = BigDecimal(10), + bigInt = BigInt(10), + jBigDecimal = new java.math.BigDecimal(10), + jBigInt = java.math.BigInteger.valueOf(10), + bool = true, + ba = null, + byte = 1, + date = new Date(mockedInstant.toEpochMilli), + instant = mockedInstant, + double = 1.1, + int = 10, + long = 100, + short = 5, + str = "hello", + uuid = RecordA.uuid, + thing = Things.One ) } @@ -155,19 +194,22 @@ class ColReaderTest extends Specification with Mockito { val parsed = RecordB.reader.parse(row) parsed mustEqual RecordB( - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None + bd = None, + bigInt = None, + jBigDecimal = None, + jBigInt = None, + bool = None, + ba = None, + byte = None, + date = None, + instant = None, + double = None, + int = None, + long = None, + short = None, + str = None, + uuid = None, + thing = None ) } @@ -177,23 +219,26 @@ class ColReaderTest extends Specification with Mockito { // Arrays use reference equality so we have to check this // independantly of all the other values - val bytes = parsed.ba.get + val bytes: Array[Byte] = parsed.ba.get bytes === Array[Byte](1,2,3) - parsed.copy(ba = null) mustEqual RecordB( - Some(BigDecimal(10)), - Some(true), - null, - Some(1), - Some(new Date(10000)), - Some(mockedInstant), - Some(1.1), - Some(10), - Some(100), - Some(5), - Some("hello"), - Some(UUID.fromString("01020304-0506-0708-090a-0b0c0d0e0f10")), - Some(Things.One) + parsed.copy(ba = None) mustEqual RecordB( + bd = Some(BigDecimal(10)), + bigInt = Some(BigInt(10)), + jBigDecimal = Some(new java.math.BigDecimal(10)), + jBigInt = Some(java.math.BigInteger.valueOf(10)), + bool = Some(true), + ba = None, + byte = Some(1), + date = Some(new Date(mockedInstant.toEpochMilli)), + instant = Some(mockedInstant), + double = Some(1.1), + int = Some(10), + long = Some(100), + short = Some(5), + str = Some("hello"), + uuid = Some(UUID.fromString("01020304-0506-0708-090a-0b0c0d0e0f10")), + thing = Some(Things.One) ) } }