Skip to content

Commit

Permalink
Make ColReaders consistent with the standard parsing functions/behavi…
Browse files Browse the repository at this point in the history
…or on SqlRow
  • Loading branch information
richard-shurtz committed Sep 25, 2023
1 parent 2964822 commit d5db1b3
Show file tree
Hide file tree
Showing 2 changed files with 130 additions and 103 deletions.
50 changes: 15 additions & 35 deletions relate/src/main/scala/com/lucidchart/relate/ColReader.scala
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
package com.lucidchart.relate

import java.nio.ByteBuffer
import java.sql.ResultSet
import java.time.Instant
import java.util.{Date, UUID}

Expand Down Expand Up @@ -36,39 +34,21 @@ object ColReader {
def read(col: String, rs: SqlRow): Option[A] = f(col, rs)
}

def option[A](x: A, rs: ResultSet): Option[A] = {
if (rs.wasNull()) {
None
} else {
Some(x)
}
}

private def optReader[A](f: (String, ResultSet) => A): ColReader[A] = ColReader[A] { (col, row) =>
option(f(col, row.resultSet), row.resultSet)
}

implicit val jbigDecimalReader: ColReader[java.math.BigDecimal] = ColReader[java.math.BigDecimal] { (col, row) =>
option(row.resultSet.getBigDecimal(col), row.resultSet)
}

implicit val bigDecimalReader: ColReader[BigDecimal] = jbigDecimalReader.map(BigDecimal(_))

implicit val bigIntReader: ColReader[BigInt] = jbigDecimalReader.map(bd => BigInt(bd.longValue))

implicit val boolReader: ColReader[Boolean] = optReader((col, rs) => rs.getBoolean(col))
implicit val byteArrayReader: ColReader[Array[Byte]] = optReader((col, rs) => rs.getBytes(col))
implicit val byteReader: ColReader[Byte] = optReader((col, rs) => rs.getByte(col))
implicit val dateReader: ColReader[Date] = optReader((col, rs) => rs.getDate(col))
implicit val instantReader: ColReader[Instant] = optReader((col, rs) => rs.getTimestamp(col)).map(_.toInstant)
implicit val doubleReader: ColReader[Double] = optReader((col, rs) => rs.getDouble(col))
implicit val intReader: ColReader[Int] = optReader((col, rs) => rs.getInt(col))
implicit val longReader: ColReader[Long] = optReader((col, rs) => rs.getLong(col))
implicit val shortReader: ColReader[Short] = optReader((col, rs) => rs.getShort(col))
implicit val stringReader: ColReader[String] = optReader((col, rs) => rs.getString(col))
implicit val uuidReader: ColReader[UUID] = ColReader[UUID] { (col, row) =>
row.uuidOption(col)
}
implicit val jbigDecimalReader: ColReader[java.math.BigDecimal] = ColReader { (col, row) => row.javaBigDecimalOption(col)}
implicit val bigDecimalReader: ColReader[BigDecimal] = ColReader { (col, row) => row.bigDecimalOption(col)}
implicit val jBigIntReader: ColReader[java.math.BigInteger] = ColReader { (col, row) => row.javaBigIntegerOption(col)}
implicit val bigIntReader: ColReader[BigInt] = ColReader { (col, row) => row.bigIntOption(col)}
implicit val boolReader: ColReader[Boolean] = ColReader { (col, row) => row.boolOption(col)}
implicit val byteArrayReader: ColReader[Array[Byte]] = ColReader { (col, row) => row.byteArrayOption(col)}
implicit val byteReader: ColReader[Byte] = ColReader { (col, row) => row.byteOption(col)}
implicit val dateReader: ColReader[Date] = ColReader { (col, row) => row.dateOption(col)}
implicit val instantReader: ColReader[Instant] = ColReader { (col, row) => row.instantOption(col)}
implicit val doubleReader: ColReader[Double] = ColReader { (col, row) => row.doubleOption(col)}
implicit val intReader: ColReader[Int] = ColReader { (col, row) => row.intOption(col)}
implicit val longReader: ColReader[Long] = ColReader { (col, row) => row.longOption(col)}
implicit val shortReader: ColReader[Short] = ColReader { (col, row) => row.shortOption(col)}
implicit val stringReader: ColReader[String] = ColReader { (col, row) => row.stringOption(col)}
implicit val uuidReader: ColReader[UUID] = ColReader[UUID] { (col, row) => row.uuidOption(col)}

def enumReader[A <: Enumeration](e: A): ColReader[e.Value] = {
intReader.flatMap(id => ColReader[e.Value] { (_, _) =>
Expand Down
183 changes: 115 additions & 68 deletions relate/src/test/scala/ColReaderTest.scala
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
package com.lucidchart.relate

import java.util.{Date, UUID}
import java.time.Instant
import com.lucidchart.relate.RecordA.mock
import org.mockito.Mockito.when
import org.specs2.mock.Mockito
import org.specs2.mutable._

import java.time.Instant
import java.util.{Date, UUID}

case class RecordA(
bd: BigDecimal,
bigInt: BigInt,
jBigDecimal: java.math.BigDecimal,
jBigInt: java.math.BigInteger,
bool: Boolean,
ba: Array[Byte],
byte: Byte,
Expand All @@ -21,11 +27,22 @@ case class RecordA(
thing: Things.Value
)

/**
* Ensures that the real apply/opt methods are called (whereas `mock[SqlRow]` would 'null' out that method
*/
class MockableRow extends SqlRow(null) {
final override def apply[A: ColReader](col: String): A = super.apply(col)
final override def opt[A: ColReader](col: String): Option[A] = super.opt(col)
}

object RecordA extends Mockito {
implicit val reader = new RowParser[RecordA] {
def parse(row: SqlRow): RecordA = {
RecordA(
row[BigDecimal]("bd"),
row[BigInt]("bi"),
row[java.math.BigDecimal]("jbd"),
row[java.math.BigInteger]("jbi"),
row[Boolean]("bool"),
row[Array[Byte]]("ba"),
row[Byte]("byte"),
Expand All @@ -42,30 +59,37 @@ object RecordA extends Mockito {
}
}

val timeMillis: Long = 1576179411000l
val timeMillis: Long = 1576179411000L
val uuid: UUID = UUID.fromString("01020304-0506-0708-090a-0b0c0d0e0f10")

val mockRow = {
val rs = mock[java.sql.ResultSet]
rs.getBigDecimal("bd") returns new java.math.BigDecimal(10)
rs.getBoolean("bool") returns true
rs.getBytes("ba") returns Array[Byte](1,2,3)
rs.getByte("byte") returns (1: Byte)
rs.getDate("date") returns (new java.sql.Date(10000))
rs.getTimestamp("instant") returns (new java.sql.Timestamp(timeMillis))
rs.getDouble("double") returns 1.1
rs.getInt("int") returns 10
rs.getLong("long") returns 100L
rs.getShort("short") returns (5: Short)
rs.getString("str") returns "hello"
rs.getObject("uuid") returns Array[Byte](1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16)
rs.getInt("thing") returns 1
SqlRow(rs)
val row = mock[MockableRow]
row.bigDecimalOption("bd") returns Some(BigDecimal(10))
row.bigIntOption("bi") returns Some(BigInt(10))
row.javaBigDecimalOption("jbd") returns Some(new java.math.BigDecimal(10))
row.javaBigIntegerOption("jbi") returns Some(java.math.BigInteger.valueOf(10))
row.boolOption("bool") returns Some(true)
row.byteArrayOption("ba") returns Some(Array[Byte](1,2,3))
row.byteOption("byte") returns Some((1: Byte))
row.dateOption("date") returns Some((new Date(timeMillis)))
row.instantOption("instant") returns Some(Instant.ofEpochMilli(timeMillis))
row.doubleOption("double") returns Some(1.1)
row.intOption("int") returns Some(10)
row.longOption("long") returns Some(100L)
row.shortOption("short") returns Some((5: Short))
row.stringOption("str") returns Some("hello")
row.uuidOption("uuid") returns Some(uuid)
row.intOption("thing") returns Some(1)
row
}

}

case class RecordB(
bd: Option[BigDecimal],
bigInt: Option[BigInt],
jBigDecimal: Option[java.math.BigDecimal],
jBigInt: Option[java.math.BigInteger],
bool: Option[Boolean],
ba: Option[Array[Byte]],
byte: Option[Byte],
Expand All @@ -85,6 +109,9 @@ object RecordB extends Mockito {
def parse(row: SqlRow): RecordB = {
RecordB(
row.opt[BigDecimal]("bd"),
row.opt[BigInt]("bi"),
row.opt[java.math.BigDecimal]("jbd"),
row.opt[java.math.BigInteger]("jbi"),
row.opt[Boolean]("bool"),
row.opt[Array[Byte]]("ba"),
row.opt[Byte]("byte"),
Expand All @@ -102,14 +129,24 @@ object RecordB extends Mockito {
}

val mockRow = {
val rs = mock[java.sql.ResultSet]
rs.wasNull() returns true
rs.getBigDecimal("bd") returns null
rs.getBytes("ba") returns null
rs.getDate("date") returns null
rs.getString("str") returns null
rs.getBytes("uuid") returns null
SqlRow(rs)
val row = mock[MockableRow]
row.bigDecimalOption("bd") returns None
row.bigIntOption("bi") returns None
row.javaBigDecimalOption("jbd") returns None
row.javaBigIntegerOption("jbi") returns None
row.boolOption("bool") returns None
row.byteArrayOption("ba") returns None
row.byteOption("byte") returns None
row.dateOption("date") returns None
row.instantOption("instant") returns None
row.doubleOption("double") returns None
row.intOption("int") returns None
row.longOption("long") returns None
row.shortOption("short") returns None
row.stringOption("str") returns None
row.uuidOption("uuid") returns None
row.intOption("thing") returns None
row
}

}
Expand All @@ -123,30 +160,34 @@ object Things extends Enumeration {

class ColReaderTest extends Specification with Mockito {
val mockedInstant = Instant.EPOCH.plusMillis(RecordA.timeMillis)

"ColReader" should {
"parse a present values" in {
val row = RecordA.mockRow
val parsed = RecordA.reader.parse(row)

// Arrays use reference equality so we have to check this
// independantly of all the other values
// independently of all the other values
val bytes = parsed.ba
bytes === Array[Byte](1,2,3)

parsed.copy(ba = null) mustEqual RecordA(
BigDecimal(10),
true,
null,
1,
new Date(10000),
mockedInstant,
1.1,
10,
100,
5,
"hello",
UUID.fromString("01020304-0506-0708-090a-0b0c0d0e0f10"),
Things.One
bd = BigDecimal(10),
bigInt = BigInt(10),
jBigDecimal = new java.math.BigDecimal(10),
jBigInt = java.math.BigInteger.valueOf(10),
bool = true,
ba = null,
byte = 1,
date = new Date(mockedInstant.toEpochMilli),
instant = mockedInstant,
double = 1.1,
int = 10,
long = 100,
short = 5,
str = "hello",
uuid = RecordA.uuid,
thing = Things.One
)
}

Expand All @@ -155,19 +196,22 @@ class ColReaderTest extends Specification with Mockito {
val parsed = RecordB.reader.parse(row)

parsed mustEqual RecordB(
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None,
None
bd = None,
bigInt = None,
jBigDecimal = None,
jBigInt = None,
bool = None,
ba = None,
byte = None,
date = None,
instant = None,
double = None,
int = None,
long = None,
short = None,
str = None,
uuid = None,
thing = None
)
}

Expand All @@ -177,23 +221,26 @@ class ColReaderTest extends Specification with Mockito {

// Arrays use reference equality so we have to check this
// independantly of all the other values
val bytes = parsed.ba.get
val bytes: Array[Byte] = parsed.ba.get
bytes === Array[Byte](1,2,3)

parsed.copy(ba = null) mustEqual RecordB(
Some(BigDecimal(10)),
Some(true),
null,
Some(1),
Some(new Date(10000)),
Some(mockedInstant),
Some(1.1),
Some(10),
Some(100),
Some(5),
Some("hello"),
Some(UUID.fromString("01020304-0506-0708-090a-0b0c0d0e0f10")),
Some(Things.One)
parsed.copy(ba = None) mustEqual RecordB(
bd = Some(BigDecimal(10)),
bigInt = Some(BigInt(10)),
jBigDecimal = Some(new java.math.BigDecimal(10)),
jBigInt = Some(java.math.BigInteger.valueOf(10)),
bool = Some(true),
ba = None,
byte = Some(1),
date = Some(new Date(mockedInstant.toEpochMilli)),
instant = Some(mockedInstant),
double = Some(1.1),
int = Some(10),
long = Some(100),
short = Some(5),
str = Some("hello"),
uuid = Some(UUID.fromString("01020304-0506-0708-090a-0b0c0d0e0f10")),
thing = Some(Things.One)
)
}
}
Expand Down

0 comments on commit d5db1b3

Please sign in to comment.