diff --git a/modules/core/src/main/scala/doobie/util/put.scala b/modules/core/src/main/scala/doobie/util/put.scala index 2588c8ce9..015a36f26 100644 --- a/modules/core/src/main/scala/doobie/util/put.scala +++ b/modules/core/src/main/scala/doobie/util/put.scala @@ -56,25 +56,37 @@ sealed abstract class Put[A]( ) {} def unsafeSetNonNullable(ps: PreparedStatement, n: Int, a: A): Unit = - if (a == null) sys.error("oops, null") - else put.fi.apply(ps, n, (put.k(a))) + unsafeSetWhatShouldNotBeNull(ps, n, a, expectedNonNullableParam) def unsafeSetNullable(ps: PreparedStatement, n: Int, oa: Option[A]): Unit = oa match { - case Some(a) => unsafeSetNonNullable(ps, n, a) + case Some(a) => unsafeSetWhatShouldNotBeNull(ps, n, a, expectedOptionalParam) case None => unsafeSetNull(ps, n) } + private def unsafeSetWhatShouldNotBeNull(ps: PreparedStatement, n: Int, a: A, message: String): Unit = + if (a == null) sys.error(message) + else put.fi.apply(ps, n, (put.k(a))) + def unsafeUpdateNonNullable(rs: ResultSet, n: Int, a: A): Unit = - if (a == null) sys.error("oops, null") - else update.fi.apply(rs, n, (update.k(a))) + unsafeUpdateWhatShouldNotBeNull(rs, n, a, expectedNonNullableParam) def unsafeUpdateNullable(rs: ResultSet, n: Int, oa: Option[A]): Unit = oa match { - case Some(a) => unsafeUpdateNonNullable(rs, n, a) + case Some(a) => unsafeUpdateWhatShouldNotBeNull(rs, n, a, expectedOptionalParam) case None => rs.updateNull(n) } + private def unsafeUpdateWhatShouldNotBeNull(rs: ResultSet, n: Int, a: A, message: String): Unit = + if (a == null) sys.error(message) + else update.fi.apply(rs, n, (update.k(a))) + + private val expectedNonNullableParam = + "Expected non-nullable param. Use Option to describe nullable values." + + private val expectedOptionalParam = + "Expected optional param but got Some(null)." + override def toString(): String = { s"Put(typeStack=${typeStack.mkString_(",")}, jdbcTargets=${jdbcTargets.mkString_( ",")}, vendorTypeNames=${vendorTypeNames.mkString_(",")})" diff --git a/modules/core/src/test/scala/doobie/util/WriteSuite.scala b/modules/core/src/test/scala/doobie/util/WriteSuite.scala index f90cdf885..9ff8d4883 100644 --- a/modules/core/src/test/scala/doobie/util/WriteSuite.scala +++ b/modules/core/src/test/scala/doobie/util/WriteSuite.scala @@ -122,4 +122,27 @@ class WriteSuite extends munit.FunSuite with WriteSuitePlatform { .unsafeRunSync() } + test("Write should yield correct error when Some(null) inserted") { + interceptMessage[RuntimeException]("Expected optional param but got Some(null).") { + testNullPut(("a", Some(null))) + } + } + + test("Write should yield correct error when null inserted into non-nullable field") { + interceptMessage[RuntimeException]("Expected non-nullable param. Use Option to describe nullable values.") { + testNullPut((null, Some("b"))) + } + } + + private def testNullPut(input: (String, Option[String])): Int = { + import doobie.implicits.* + + (for { + _ <- sql"create temp table t0 (a text, b text null)".update.run + n <- Update[(String, Option[String])]("insert into t0 (a, b) values (?, ?)").run(input) + } yield n) + .transact(xa) + .unsafeRunSync() + } + }