diff --git a/relate/src/main/scala/com/lucidchart/relate/ColReader.scala b/relate/src/main/scala/com/lucidchart/relate/ColReader.scala index 4276c4c..5745c4f 100644 --- a/relate/src/main/scala/com/lucidchart/relate/ColReader.scala +++ b/relate/src/main/scala/com/lucidchart/relate/ColReader.scala @@ -1,7 +1,5 @@ package com.lucidchart.relate -import java.nio.ByteBuffer -import java.sql.ResultSet import java.time.Instant import java.util.{Date, UUID} @@ -36,39 +34,21 @@ object ColReader { def read(col: String, rs: SqlRow): Option[A] = f(col, rs) } - def option[A](x: A, rs: ResultSet): Option[A] = { - if (rs.wasNull()) { - None - } else { - Some(x) - } - } - - private def optReader[A](f: (String, ResultSet) => A): ColReader[A] = ColReader[A] { (col, row) => - option(f(col, row.resultSet), row.resultSet) - } - - implicit val jbigDecimalReader: ColReader[java.math.BigDecimal] = ColReader[java.math.BigDecimal] { (col, row) => - option(row.resultSet.getBigDecimal(col), row.resultSet) - } - - implicit val bigDecimalReader: ColReader[BigDecimal] = jbigDecimalReader.map(BigDecimal(_)) - - implicit val bigIntReader: ColReader[BigInt] = jbigDecimalReader.map(bd => BigInt(bd.longValue)) - - implicit val boolReader: ColReader[Boolean] = optReader((col, rs) => rs.getBoolean(col)) - implicit val byteArrayReader: ColReader[Array[Byte]] = optReader((col, rs) => rs.getBytes(col)) - implicit val byteReader: ColReader[Byte] = optReader((col, rs) => rs.getByte(col)) - implicit val dateReader: ColReader[Date] = optReader((col, rs) => rs.getDate(col)) - implicit val instantReader: ColReader[Instant] = optReader((col, rs) => rs.getTimestamp(col)).map(_.toInstant) - implicit val doubleReader: ColReader[Double] = optReader((col, rs) => rs.getDouble(col)) - implicit val intReader: ColReader[Int] = optReader((col, rs) => rs.getInt(col)) - implicit val longReader: ColReader[Long] = optReader((col, rs) => rs.getLong(col)) - implicit val shortReader: ColReader[Short] = optReader((col, rs) => rs.getShort(col)) - implicit val stringReader: ColReader[String] = optReader((col, rs) => rs.getString(col)) - implicit val uuidReader: ColReader[UUID] = ColReader[UUID] { (col, row) => - row.uuidOption(col) - } + implicit val jbigDecimalReader: ColReader[java.math.BigDecimal] = ColReader { (col, row) => row.javaBigDecimalOption(col)} + implicit val bigDecimalReader: ColReader[BigDecimal] = ColReader { (col, row) => row.bigDecimalOption(col)} + implicit val jBigIntReader: ColReader[java.math.BigInteger] = ColReader { (col, row) => row.javaBigIntegerOption(col)} + implicit val bigIntReader: ColReader[BigInt] = ColReader { (col, row) => row.bigIntOption(col)} + implicit val boolReader: ColReader[Boolean] = ColReader { (col, row) => row.boolOption(col)} + implicit val byteArrayReader: ColReader[Array[Byte]] = ColReader { (col, row) => row.byteArrayOption(col)} + implicit val byteReader: ColReader[Byte] = ColReader { (col, row) => row.byteOption(col)} + implicit val dateReader: ColReader[Date] = ColReader { (col, row) => row.dateOption(col)} + implicit val instantReader: ColReader[Instant] = ColReader { (col, row) => row.instantOption(col)} + implicit val doubleReader: ColReader[Double] = ColReader { (col, row) => row.doubleOption(col)} + implicit val intReader: ColReader[Int] = ColReader { (col, row) => row.intOption(col)} + implicit val longReader: ColReader[Long] = ColReader { (col, row) => row.longOption(col)} + implicit val shortReader: ColReader[Short] = ColReader { (col, row) => row.shortOption(col)} + implicit val stringReader: ColReader[String] = ColReader { (col, row) => row.stringOption(col)} + implicit val uuidReader: ColReader[UUID] = ColReader[UUID] { (col, row) => row.uuidOption(col)} def enumReader[A <: Enumeration](e: A): ColReader[e.Value] = { intReader.flatMap(id => ColReader[e.Value] { (_, _) => diff --git a/relate/src/test/scala/ColReaderTest.scala b/relate/src/test/scala/ColReaderTest.scala index 4a39dab..437c0f3 100644 --- a/relate/src/test/scala/ColReaderTest.scala +++ b/relate/src/test/scala/ColReaderTest.scala @@ -1,12 +1,18 @@ package com.lucidchart.relate -import java.util.{Date, UUID} -import java.time.Instant +import com.lucidchart.relate.RecordA.mock +import org.mockito.Mockito.when import org.specs2.mock.Mockito import org.specs2.mutable._ +import java.time.Instant +import java.util.{Date, UUID} + case class RecordA( bd: BigDecimal, + bigInt: BigInt, + jBigDecimal: java.math.BigDecimal, + jBigInt: java.math.BigInteger, bool: Boolean, ba: Array[Byte], byte: Byte, @@ -21,11 +27,21 @@ case class RecordA( thing: Things.Value ) +/** + * Ensures that the real apply method is called (whereas `mock[SqlRow]` would 'null' out that method + */ +class MockableRow extends SqlRow(null) { + final override def apply[A: ColReader](col: String): A = super.apply(col) +} + object RecordA extends Mockito { implicit val reader = new RowParser[RecordA] { def parse(row: SqlRow): RecordA = { RecordA( row[BigDecimal]("bd"), + row[BigInt]("bi"), + row[java.math.BigDecimal]("jbd"), + row[java.math.BigInteger]("jbi"), row[Boolean]("bool"), row[Array[Byte]]("ba"), row[Byte]("byte"), @@ -43,29 +59,36 @@ object RecordA extends Mockito { } val timeMillis: Long = 1576179411000l + val uuid: UUID = UUID.fromString("01020304-0506-0708-090a-0b0c0d0e0f10") val mockRow = { - val rs = mock[java.sql.ResultSet] - rs.getBigDecimal("bd") returns new java.math.BigDecimal(10) - rs.getBoolean("bool") returns true - rs.getBytes("ba") returns Array[Byte](1,2,3) - rs.getByte("byte") returns (1: Byte) - rs.getDate("date") returns (new java.sql.Date(10000)) - rs.getTimestamp("instant") returns (new java.sql.Timestamp(timeMillis)) - rs.getDouble("double") returns 1.1 - rs.getInt("int") returns 10 - rs.getLong("long") returns 100L - rs.getShort("short") returns (5: Short) - rs.getString("str") returns "hello" - rs.getObject("uuid") returns Array[Byte](1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16) - rs.getInt("thing") returns 1 - SqlRow(rs) + val row = mock[MockableRow] + row.bigDecimalOption("bd") returns Some(BigDecimal(10)) + row.bigIntOption("bi") returns Some(BigInt(10)) + row.javaBigDecimalOption("jbd") returns Some(new java.math.BigDecimal(10)) + row.javaBigIntegerOption("jbi") returns Some(java.math.BigInteger.valueOf(10)) + row.boolOption("bool") returns Some(true) + row.byteArrayOption("ba") returns Some(Array[Byte](1,2,3)) + row.byteOption("byte") returns Some((1: Byte)) + row.dateOption("date") returns Some((new Date(timeMillis))) + row.instantOption("instant") returns Some(Instant.ofEpochMilli(timeMillis)) + row.doubleOption("double") returns Some(1.1) + row.intOption("int") returns Some(10) + row.longOption("long") returns Some(100l) + row.shortOption("short") returns Some((5: Short)) + row.stringOption("str") returns Some("hello") + row.uuidOption("uuid") returns Some(uuid) + row.intOption("thing") returns Some(1) + row } } case class RecordB( bd: Option[BigDecimal], + bigInt: Option[BigInt], + jBigDecimal: Option[java.math.BigDecimal], + jBigInt: Option[java.math.BigInteger], bool: Option[Boolean], ba: Option[Array[Byte]], byte: Option[Byte], @@ -85,6 +108,9 @@ object RecordB extends Mockito { def parse(row: SqlRow): RecordB = { RecordB( row.opt[BigDecimal]("bd"), + row.opt[BigInt]("bi"), + row.opt[java.math.BigDecimal]("jbd"), + row.opt[java.math.BigInteger]("jbi"), row.opt[Boolean]("bool"), row.opt[Array[Byte]]("ba"), row.opt[Byte]("byte"), @@ -102,14 +128,24 @@ object RecordB extends Mockito { } val mockRow = { - val rs = mock[java.sql.ResultSet] - rs.wasNull() returns true - rs.getBigDecimal("bd") returns null - rs.getBytes("ba") returns null - rs.getDate("date") returns null - rs.getString("str") returns null - rs.getBytes("uuid") returns null - SqlRow(rs) + val row = mock[MockableRow] + row.bigDecimalOption("bd") returns None + row.bigIntOption("bi") returns None + row.javaBigDecimalOption("jbd") returns None + row.javaBigIntegerOption("jbi") returns None + row.boolOption("bool") returns None + row.byteArrayOption("ba") returns None + row.byteOption("byte") returns None + row.dateOption("date") returns None + row.instantOption("instant") returns None + row.doubleOption("double") returns None + row.intOption("int") returns None + row.longOption("long") returns None + row.shortOption("short") returns None + row.stringOption("str") returns None + row.uuidOption("uuid") returns None + row.intOption("thing") returns None + row } } @@ -123,30 +159,34 @@ object Things extends Enumeration { class ColReaderTest extends Specification with Mockito { val mockedInstant = Instant.EPOCH.plusMillis(RecordA.timeMillis) + "ColReader" should { "parse a present values" in { val row = RecordA.mockRow val parsed = RecordA.reader.parse(row) // Arrays use reference equality so we have to check this - // independantly of all the other values + // independently of all the other values val bytes = parsed.ba bytes === Array[Byte](1,2,3) parsed.copy(ba = null) mustEqual RecordA( - BigDecimal(10), - true, - null, - 1, - new Date(10000), - mockedInstant, - 1.1, - 10, - 100, - 5, - "hello", - UUID.fromString("01020304-0506-0708-090a-0b0c0d0e0f10"), - Things.One + bd = BigDecimal(10), + bigInt = BigInt(10), + jBigDecimal = new java.math.BigDecimal(10), + jBigInt = java.math.BigInteger.valueOf(10), + bool = true, + ba = null, + byte = 1, + date = new Date(mockedInstant.toEpochMilli), + instant = mockedInstant, + double = 1.1, + int = 10, + long = 100, + short = 5, + str = "hello", + uuid = RecordA.uuid, + thing = Things.One ) } @@ -155,19 +195,22 @@ class ColReaderTest extends Specification with Mockito { val parsed = RecordB.reader.parse(row) parsed mustEqual RecordB( - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None, - None + bd = None, + bigInt = None, + jBigDecimal = None, + jBigInt = None, + bool = None, + ba = None, + byte = None, + date = None, + instant = None, + double = None, + int = None, + long = None, + short = None, + str = None, + uuid = None, + thing = None ) } @@ -177,23 +220,26 @@ class ColReaderTest extends Specification with Mockito { // Arrays use reference equality so we have to check this // independantly of all the other values - val bytes = parsed.ba.get + val bytes = parsed.ba bytes === Array[Byte](1,2,3) - parsed.copy(ba = null) mustEqual RecordB( - Some(BigDecimal(10)), - Some(true), - null, - Some(1), - Some(new Date(10000)), - Some(mockedInstant), - Some(1.1), - Some(10), - Some(100), - Some(5), - Some("hello"), - Some(UUID.fromString("01020304-0506-0708-090a-0b0c0d0e0f10")), - Some(Things.One) + parsed.copy(ba = None) mustEqual RecordB( + bd = Some(BigDecimal(10)), + bigInt = Some(BigInt(10)), + jBigDecimal = Some(new java.math.BigDecimal(10)), + jBigInt = Some(java.math.BigInteger.valueOf(10)), + bool = Some(true), + ba = None, + byte = Some(1), + date = Some(new Date(mockedInstant.toEpochMilli)), + instant = Some(mockedInstant), + double = Some(1.1), + int = Some(10), + long = Some(100), + short = Some(5), + str = Some("hello"), + uuid = Some(UUID.fromString("01020304-0506-0708-090a-0b0c0d0e0f10")), + thing = Some(Things.One) ) } }