Skip to content

Commit

Permalink
Make ColReaders consistent with the standard parsing functions/behavi…
Browse files Browse the repository at this point in the history
…or on SqlRow
  • Loading branch information
richard-shurtz committed Sep 26, 2023
1 parent 2964822 commit f9bcb05
Show file tree
Hide file tree
Showing 4 changed files with 164 additions and 148 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -44,23 +44,22 @@ case class Big(


class RowParserTest extends Specification with Mockito {

"RowParser def macros" should {
"generate parser" in {
val rs = mock[java.sql.ResultSet]
rs.getString("firstName") returns "hi"
rs.getInt("b") returns 20
val row = SqlRow(rs)
val row = mock[MockableRow]
row.stringOption("firstName") returns Some("hi")
row.intOption("b") returns Some(20)

val p = generateParser[Thing]

p.parse(row) mustEqual(Thing("hi", Some(20)))
}

"generate parser w/snake_case columns" in {
val rs = mock[java.sql.ResultSet]
rs.getString("first_name") returns "gregg"
rs.getInt("b") returns 20
val row = SqlRow(rs)
val row = mock[MockableRow]
row.stringOption("first_name") returns Some("gregg")
row.intOption("b") returns Some(20)

val p = generateSnakeParser[Thing]

Expand All @@ -73,10 +72,9 @@ class RowParserTest extends Specification with Mockito {
"lastName" -> "lname"
))

val rs = mock[java.sql.ResultSet]
rs.getString("fname") returns "gregg"
rs.getString("lname") returns "hernandez"
val row = SqlRow(rs)
val row = mock[MockableRow]
row.stringOption("fname") returns Some("gregg")
row.stringOption("lname") returns Some("hernandez")

p.parse(row) mustEqual User("gregg", "hernandez")
}
Expand All @@ -87,10 +85,9 @@ class RowParserTest extends Specification with Mockito {
("lastName", "lname")
))

val rs = mock[java.sql.ResultSet]
rs.getString("fname") returns "gregg"
rs.getString("lname") returns "hernandez"
val row = SqlRow(rs)
val row = mock[MockableRow]
row.stringOption("fname") returns Some("gregg")
row.stringOption("lname") returns Some("hernandez")

p.parse(row) mustEqual User("gregg", "hernandez")
}
Expand All @@ -100,21 +97,19 @@ class RowParserTest extends Specification with Mockito {
"firstName" -> "fname"
))

val rs = mock[java.sql.ResultSet]
rs.getString("fname") returns "gregg"
rs.getString("lastName") returns "hernandez"
val row = SqlRow(rs)
val row = mock[MockableRow]
row.stringOption("fname") returns Some("gregg")
row.stringOption("lastName") returns Some("hernandez")

p.parse(row) mustEqual User("gregg", "hernandez")
}

"generate parser for a case class > 22 fields" in {
val rs = mock[java.sql.ResultSet]
for (i <- (1 to 9)) { rs.getInt(s"f${i}") returns i }
for (i <- (10 to 19)) { rs.getInt(s"z${i}") returns i }
for (i <- (20 to 25)) { rs.getInt(s"a${i}") returns i }
val row = mock[MockableRow]
for (i <- (1 to 9)) { row.intOption(s"f${i}") returns Some(i) }
for (i <- (10 to 19)) { row.intOption(s"z${i}") returns Some(i) }
for (i <- (20 to 25)) { row.intOption(s"a${i}") returns Some(i) }

val row = SqlRow(rs)

val p = generateParser[Big]

Expand Down Expand Up @@ -154,10 +149,9 @@ class RowParserTest extends Specification with Mockito {
case class SimpleRecord(firstName: String, lastName: Option[String])

"generate parser" in {
val rs = mock[java.sql.ResultSet]
rs.getString("firstName") returns "gregg"
rs.getString("lastName") returns "hernandez"
val row = SqlRow(rs)
val row = mock[MockableRow]
row.stringOption("firstName") returns Some("gregg")
row.stringOption("lastName") returns Some("hernandez")

implicitly[RowParser[SimpleRecord]].parse(row) mustEqual SimpleRecord("gregg", Some("hernandez"))
}
Expand All @@ -170,10 +164,9 @@ class RowParserTest extends Specification with Mockito {
}

"generate parser w/snake_case columns" in {
val rs = mock[java.sql.ResultSet]
rs.getString("first_name") returns "gregg"
rs.getString("last_name") returns "hernandez"
val row = SqlRow(rs)
val row = mock[MockableRow]
row.stringOption("first_name") returns Some("gregg")
row.stringOption("last_name") returns Some("hernandez")

// verify that this still compiles
SnakeRecord.f()
Expand All @@ -185,10 +178,9 @@ class RowParserTest extends Specification with Mockito {
case class RemapRecord(firstName: String, lastName: String)

"remap column names" in {
val rs = mock[java.sql.ResultSet]
rs.getString("fname") returns "gregg"
rs.getString("lname") returns "hernandez"
val row = SqlRow(rs)
val row = mock[MockableRow]
row.stringOption("fname") returns Some("gregg")
row.stringOption("lname") returns Some("hernandez")

implicitly[RowParser[RemapRecord]].parse(row) mustEqual RemapRecord("gregg", "hernandez")
}
Expand All @@ -197,10 +189,9 @@ class RowParserTest extends Specification with Mockito {
case class RemapTRecord(firstName: String, lastName: String)

"remap column names w/normal tuple syntax" in {
val rs = mock[java.sql.ResultSet]
rs.getString("fname") returns "gregg"
rs.getString("lname") returns "hernandez"
val row = SqlRow(rs)
val row = mock[MockableRow]
row.stringOption("fname") returns Some("gregg")
row.stringOption("lname") returns Some("hernandez")

implicitly[RowParser[RemapTRecord]].parse(row) mustEqual RemapTRecord("gregg", "hernandez")
}
Expand All @@ -209,10 +200,9 @@ class RowParserTest extends Specification with Mockito {
case class RemapSomeRecord(firstName: String, lastName: String)

"remap some column names" in {
val rs = mock[java.sql.ResultSet]
rs.getString("fname") returns "gregg"
rs.getString("lastName") returns "hernandez"
val row = SqlRow(rs)
val row = mock[MockableRow]
row.stringOption("fname") returns Some("gregg")
row.stringOption("lastName") returns Some("hernandez")

implicitly[RowParser[RemapSomeRecord]].parse(row) mustEqual RemapSomeRecord("gregg", "hernandez")
}
Expand Down
50 changes: 15 additions & 35 deletions relate/src/main/scala/com/lucidchart/relate/ColReader.scala
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
package com.lucidchart.relate

import java.nio.ByteBuffer
import java.sql.ResultSet
import java.time.Instant
import java.util.{Date, UUID}

Expand Down Expand Up @@ -36,39 +34,21 @@ object ColReader {
def read(col: String, rs: SqlRow): Option[A] = f(col, rs)
}

def option[A](x: A, rs: ResultSet): Option[A] = {
if (rs.wasNull()) {
None
} else {
Some(x)
}
}

private def optReader[A](f: (String, ResultSet) => A): ColReader[A] = ColReader[A] { (col, row) =>
option(f(col, row.resultSet), row.resultSet)
}

implicit val jbigDecimalReader: ColReader[java.math.BigDecimal] = ColReader[java.math.BigDecimal] { (col, row) =>
option(row.resultSet.getBigDecimal(col), row.resultSet)
}

implicit val bigDecimalReader: ColReader[BigDecimal] = jbigDecimalReader.map(BigDecimal(_))

implicit val bigIntReader: ColReader[BigInt] = jbigDecimalReader.map(bd => BigInt(bd.longValue))

implicit val boolReader: ColReader[Boolean] = optReader((col, rs) => rs.getBoolean(col))
implicit val byteArrayReader: ColReader[Array[Byte]] = optReader((col, rs) => rs.getBytes(col))
implicit val byteReader: ColReader[Byte] = optReader((col, rs) => rs.getByte(col))
implicit val dateReader: ColReader[Date] = optReader((col, rs) => rs.getDate(col))
implicit val instantReader: ColReader[Instant] = optReader((col, rs) => rs.getTimestamp(col)).map(_.toInstant)
implicit val doubleReader: ColReader[Double] = optReader((col, rs) => rs.getDouble(col))
implicit val intReader: ColReader[Int] = optReader((col, rs) => rs.getInt(col))
implicit val longReader: ColReader[Long] = optReader((col, rs) => rs.getLong(col))
implicit val shortReader: ColReader[Short] = optReader((col, rs) => rs.getShort(col))
implicit val stringReader: ColReader[String] = optReader((col, rs) => rs.getString(col))
implicit val uuidReader: ColReader[UUID] = ColReader[UUID] { (col, row) =>
row.uuidOption(col)
}
implicit val jbigDecimalReader: ColReader[java.math.BigDecimal] = ColReader { (col, row) => row.javaBigDecimalOption(col)}
implicit val bigDecimalReader: ColReader[BigDecimal] = ColReader { (col, row) => row.bigDecimalOption(col)}
implicit val jBigIntReader: ColReader[java.math.BigInteger] = ColReader { (col, row) => row.javaBigIntegerOption(col)}
implicit val bigIntReader: ColReader[BigInt] = ColReader { (col, row) => row.bigIntOption(col)}
implicit val boolReader: ColReader[Boolean] = ColReader { (col, row) => row.boolOption(col)}
implicit val byteArrayReader: ColReader[Array[Byte]] = ColReader { (col, row) => row.byteArrayOption(col)}
implicit val byteReader: ColReader[Byte] = ColReader { (col, row) => row.byteOption(col)}
implicit val dateReader: ColReader[Date] = ColReader { (col, row) => row.dateOption(col)}
implicit val instantReader: ColReader[Instant] = ColReader { (col, row) => row.instantOption(col)}
implicit val doubleReader: ColReader[Double] = ColReader { (col, row) => row.doubleOption(col)}
implicit val intReader: ColReader[Int] = ColReader { (col, row) => row.intOption(col)}
implicit val longReader: ColReader[Long] = ColReader { (col, row) => row.longOption(col)}
implicit val shortReader: ColReader[Short] = ColReader { (col, row) => row.shortOption(col)}
implicit val stringReader: ColReader[String] = ColReader { (col, row) => row.stringOption(col)}
implicit val uuidReader: ColReader[UUID] = ColReader[UUID] { (col, row) => row.uuidOption(col)}

def enumReader[A <: Enumeration](e: A): ColReader[e.Value] = {
intReader.flatMap(id => ColReader[e.Value] { (_, _) =>
Expand Down
Loading

0 comments on commit f9bcb05

Please sign in to comment.