Skip to content

Commit

Permalink
Make null param error more informative (#809)
Browse files Browse the repository at this point in the history
  • Loading branch information
ulfryk committed Nov 19, 2024
1 parent b8ce779 commit af96924
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 6 deletions.
24 changes: 18 additions & 6 deletions modules/core/src/main/scala/doobie/util/put.scala
Original file line number Diff line number Diff line change
Expand Up @@ -56,25 +56,37 @@ sealed abstract class Put[A](
) {}

def unsafeSetNonNullable(ps: PreparedStatement, n: Int, a: A): Unit =
if (a == null) sys.error("oops, null")
else put.fi.apply(ps, n, (put.k(a)))
unsafeSetWhatShouldNotBeNull(ps, n, a, expectedNonNullableParam)

def unsafeSetNullable(ps: PreparedStatement, n: Int, oa: Option[A]): Unit =
oa match {
case Some(a) => unsafeSetNonNullable(ps, n, a)
case Some(a) => unsafeSetWhatShouldNotBeNull(ps, n, a, expectedOptionalParam)
case None => unsafeSetNull(ps, n)
}

private def unsafeSetWhatShouldNotBeNull(ps: PreparedStatement, n: Int, a: A, message: String): Unit =
if (a == null) sys.error(message)
else put.fi.apply(ps, n, (put.k(a)))

def unsafeUpdateNonNullable(rs: ResultSet, n: Int, a: A): Unit =
if (a == null) sys.error("oops, null")
else update.fi.apply(rs, n, (update.k(a)))
unsafeUpdateWhatShouldNotBeNull(rs, n, a, expectedNonNullableParam)

def unsafeUpdateNullable(rs: ResultSet, n: Int, oa: Option[A]): Unit =
oa match {
case Some(a) => unsafeUpdateNonNullable(rs, n, a)
case Some(a) => unsafeUpdateWhatShouldNotBeNull(rs, n, a, expectedOptionalParam)
case None => rs.updateNull(n)
}

private def unsafeUpdateWhatShouldNotBeNull(rs: ResultSet, n: Int, a: A, message: String): Unit =
if (a == null) sys.error(message)
else update.fi.apply(rs, n, (update.k(a)))

private val expectedNonNullableParam =
"Expected non-nullable param. Use Option to describe nullable values."

private val expectedOptionalParam =
"Expected optional param but got Some(null)."

override def toString(): String = {
s"Put(typeStack=${typeStack.mkString_(",")}, jdbcTargets=${jdbcTargets.mkString_(
",")}, vendorTypeNames=${vendorTypeNames.mkString_(",")})"
Expand Down
23 changes: 23 additions & 0 deletions modules/core/src/test/scala/doobie/util/WriteSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -122,4 +122,27 @@ class WriteSuite extends munit.FunSuite with WriteSuitePlatform {
.unsafeRunSync()
}

test("Write should yield correct error when Some(null) inserted") {
interceptMessage[RuntimeException]("Expected optional param but got Some(null).") {
testNullPut(("a", Some(null)))
}
}

test("Write should yield correct error when null inserted into non-nullable field") {
interceptMessage[RuntimeException]("Expected non-nullable param. Use Option to describe nullable values.") {
testNullPut((null, Some("b")))
}
}

private def testNullPut(input: (String, Option[String])): Int = {
import doobie.implicits.*

(for {
_ <- sql"create temp table t0 (a text, b text null)".update.run
n <- Update[(String, Option[String])]("insert into t0 (a, b) values (?, ?)").run(input)
} yield n)
.transact(xa)
.unsafeRunSync()
}

}

0 comments on commit af96924

Please sign in to comment.