diff --git a/relate/src/main/scala/com/lucidchart/relate/ColReader.scala b/relate/src/main/scala/com/lucidchart/relate/ColReader.scala index f48f755..af4d48d 100644 --- a/relate/src/main/scala/com/lucidchart/relate/ColReader.scala +++ b/relate/src/main/scala/com/lucidchart/relate/ColReader.scala @@ -60,12 +60,8 @@ object ColReader { implicit val longReader: ColReader[Long] = optReader((col, rs) => rs.getLong(col)) implicit val shortReader: ColReader[Short] = optReader((col, rs) => rs.getShort(col)) implicit val stringReader: ColReader[String] = optReader((col, rs) => rs.getString(col)) - implicit val uuidReader: ColReader[UUID] = byteArrayReader.map { bytes => - require(bytes.length == 16) - val bb = ByteBuffer.wrap(bytes) - val high = bb.getLong - val low = bb.getLong - new UUID(high, low) + implicit val uuidReader: ColReader[UUID] = ColReader[UUID] { (col, row) => + row.uuidOption(col) } def enumReader[A <: Enumeration](e: A): ColReader[e.Value] = { diff --git a/relate/src/main/scala/com/lucidchart/relate/SqlRow.scala b/relate/src/main/scala/com/lucidchart/relate/SqlRow.scala index 0982277..a173e03 100644 --- a/relate/src/main/scala/com/lucidchart/relate/SqlRow.scala +++ b/relate/src/main/scala/com/lucidchart/relate/SqlRow.scala @@ -173,13 +173,23 @@ class SqlRow(val resultSet: java.sql.ResultSet) extends ResultSetWrapper { def uuid(column: String): UUID = uuidOption(column).get def uuidOption(column: String): Option[UUID] = { - byteArrayOption(column).map { bytes => - require(bytes.length == 16) + extractOption(column) { + case u: UUID => u + case b => { + val bytes = b match { + case x: Array[Byte] => x + case x: Blob => x.getBytes(0, x.length.toInt) + case x: Clob => x.getSubString(1, x.length.asInstanceOf[Int]).getBytes + case x: String => x.toCharArray.map(_.toByte) + } + + require(bytes.length == 16) - val bb = ByteBuffer.wrap(bytes) - val high = bb.getLong - val low = bb.getLong - new UUID(high, low) + val bb = ByteBuffer.wrap(bytes) + val high = bb.getLong + val low = bb.getLong + new UUID(high, low) + } } } diff --git a/relate/src/test/scala/ColReaderTest.scala b/relate/src/test/scala/ColReaderTest.scala index f0a10e0..e53a422 100644 --- a/relate/src/test/scala/ColReaderTest.scala +++ b/relate/src/test/scala/ColReaderTest.scala @@ -51,7 +51,7 @@ object RecordA extends Mockito { rs.getLong("long") returns 100L rs.getShort("short") returns (5: Short) rs.getString("str") returns "hello" - rs.getBytes("uuid") returns Array[Byte](1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16) + rs.getObject("uuid") returns Array[Byte](1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16) rs.getInt("thing") returns 1 SqlRow(rs) } @@ -185,4 +185,22 @@ class ColReaderTest extends Specification with Mockito { ) } } + + "uuidReader" should { + "parse a byte array" in { + val rs = mock[java.sql.ResultSet] + val row = SqlRow(rs) + rs.getObject("col") returns Array[Byte]('0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f') + + ColReader.uuidReader.read("col", row) mustEqual Some(new UUID(3472611983179986487L, 4051376414998685030L)) + } + + "parse a uuid" in { + val rs = mock[java.sql.ResultSet] + val row = SqlRow(rs) + rs.getObject("col") returns new UUID(3472611983179986487L, 4051376414998685030L) + + ColReader.uuidReader.read("col", row) mustEqual Some(new UUID(3472611983179986487L, 4051376414998685030L)) + } + } } diff --git a/relate/src/test/scala/SqlResultSpec.scala b/relate/src/test/scala/SqlResultSpec.scala index 71a7612..6c9a01e 100644 --- a/relate/src/test/scala/SqlResultSpec.scala +++ b/relate/src/test/scala/SqlResultSpec.scala @@ -874,7 +874,7 @@ class SqlResultSpec extends Specification with Mockito { } "uuid" should { - "return the correct value" in { + "return the correct value when stored as a byte array" in { val (rs, row, _) = getMocks val res = Array[Byte]('0', '1', '2', '3', '4', '5', '6', '7', '8', '9', 'a', 'b', 'c', 'd', 'e', 'f') @@ -882,6 +882,15 @@ class SqlResultSpec extends Specification with Mockito { row.uuid("uuid") equals new UUID(3472611983179986487L, 4051376414998685030L) row.uuidOption("uuid") must beSome(new UUID(3472611983179986487L, 4051376414998685030L)) } + + "return the correct value when stored as UUID" in { + val (rs, row, _) = getMocks + + val res = new UUID(3472611983179986487L, 4051376414998685030L) + rs.getObject("uuid") returns res + row.uuid("uuid") equals res + row.uuidOption("uuid") must beSome(res) + } } "uuidFromString" should {