From 9da28a396e6c461a0a1f2a20f167632dbe968e73 Mon Sep 17 00:00:00 2001 From: Vitalii Lagutin Date: Thu, 28 Dec 2023 23:29:13 +0200 Subject: [PATCH] Fix multiple IN interpolation (#190) * fix using multiple IN interpolations * fix test:compile against scala v2.12.18 * extends tests to cover more cases of IN interpolation * fix formatting --- .../src/main/scala/zio/jdbc/SqlFragment.scala | 17 ++++---- .../src/main/scala/zio/jdbc/ZConnection.scala | 41 ++++++++++--------- .../test/scala/zio/jdbc/SqlFragmentSpec.scala | 15 +++++++ .../scala/zio/jdbc/ZConnectionPoolSpec.scala | 26 ++++++++++++ 4 files changed, 71 insertions(+), 28 deletions(-) diff --git a/core/src/main/scala/zio/jdbc/SqlFragment.scala b/core/src/main/scala/zio/jdbc/SqlFragment.scala index 9b67801a..ac6e4602 100644 --- a/core/src/main/scala/zio/jdbc/SqlFragment.scala +++ b/core/src/main/scala/zio/jdbc/SqlFragment.scala @@ -113,27 +113,28 @@ sealed trait SqlFragment { self => foreachSegment { syntax => sql.append(syntax.value) } { param => + var size = 0 param.value match { case iterable: Iterable[_] => - iterable.iterator.foreach { item => + iterable.foreach { item => paramsBuilder += item.toString + size += 1 } - sql.append( - Seq.fill(iterable.iterator.size)("?").mkString(",") - ) case array: Array[_] => array.foreach { item => paramsBuilder += item.toString + size += 1 } - sql.append( - Seq.fill(array.length)("?").mkString(",") - ) case _ => - sql.append("?") paramsBuilder += param.value.toString + size += 1 } + sql.append( + if (size == 1) "?" + else Seq.fill(size)("?").mkString(",") + ) } val params = paramsBuilder.result() diff --git a/core/src/main/scala/zio/jdbc/ZConnection.scala b/core/src/main/scala/zio/jdbc/ZConnection.scala index 6128e79f..f740f65d 100644 --- a/core/src/main/scala/zio/jdbc/ZConnection.scala +++ b/core/src/main/scala/zio/jdbc/ZConnection.scala @@ -43,31 +43,32 @@ final class ZConnection(private[jdbc] val underlying: ZConnection.Restorable) ex accessZIO { connection => for { transactionIsolationLevel <- currentTransactionIsolationLevel.get - statement <- ZIO.acquireRelease(ZIO.attempt { - val sb = new StringBuilder() - sql.foreachSegment(syntax => sb.append(syntax.value)) { param => - param.value match { - case iterable: Iterable[_] => - sb.append( - Seq.fill(iterable.iterator.size)("?").mkString(", ") - ) - - case _ => sb.append("?") + statement <- ZIO + .acquireRelease(ZIO.attempt { + val sb = new StringBuilder() + sql.foreachSegment(syntax => sb.append(syntax.value)) { param => + val placeholder = param.value match { + case iterable: Iterable[_] => Seq.fill(iterable.size)("?").mkString(", ") + case _ => "?" + } + sb.append(placeholder) } - } - transactionIsolationLevel.foreach { transactionIsolationLevel => - connection.setTransactionIsolation(transactionIsolationLevel.toInt) - } - connection.prepareStatement( - sb.toString, - if (returnAutoGeneratedKeys) Statement.RETURN_GENERATED_KEYS else Statement.NO_GENERATED_KEYS - ) - })(statement => ZIO.attemptBlocking(statement.close()).ignoreLogged) + transactionIsolationLevel.foreach { transactionIsolationLevel => + connection.setTransactionIsolation(transactionIsolationLevel.toInt) + } + connection.prepareStatement( + sb.result(), + if (returnAutoGeneratedKeys) Statement.RETURN_GENERATED_KEYS else Statement.NO_GENERATED_KEYS + ) + })(statement => ZIO.attemptBlocking(statement.close()).ignoreLogged) _ <- ZIO.attempt { var paramIndex = 1 sql.foreachSegment(_ => ()) { param => param.setter.setValue(statement, paramIndex, param.value) - paramIndex += 1 + paramIndex += (param.value match { + case iterable: Iterable[_] => iterable.size + case _ => 1 + }) } } result <- ZIO diff --git a/core/src/test/scala/zio/jdbc/SqlFragmentSpec.scala b/core/src/test/scala/zio/jdbc/SqlFragmentSpec.scala index be22de42..c875d24a 100644 --- a/core/src/test/scala/zio/jdbc/SqlFragmentSpec.scala +++ b/core/src/test/scala/zio/jdbc/SqlFragmentSpec.scala @@ -159,6 +159,21 @@ object SqlFragmentSpec extends ZIOSpecDefault { assertIn(List(1, 2, 3)) && assertIn(Vector(1, 2, 3)) && assertIn(Set(1, 2, 3)) + } + test("interpolation params are supported multiple collections") { + val chunk = Chunk(1, 2, 3) + val list = List(4, 5) + val vector = Vector(6) + val set = Set(7, 8, 9) + assertTrue( + sql"select name, age from users where (1 = 0 or id in ($chunk) or id in ($list) or id in ($vector) or id in ($set))".toString == + "Sql(select name, age from users where (1 = 0 or id in (?,?,?) or id in (?,?) or id in (?) or id in (?,?,?)), 1, 2, 3, 4, 5, 6, 7, 8, 9)" + ) + } + test("interpolation param is supported empty collections") { + val empty = Chunk.empty[Int] + assertTrue( + sql"select name, age from users where id in ($empty)".toString == + "Sql(select name, age from users where id in ())" + ) } } + test("not in") { diff --git a/core/src/test/scala/zio/jdbc/ZConnectionPoolSpec.scala b/core/src/test/scala/zio/jdbc/ZConnectionPoolSpec.scala index aa5b2f1c..6908e684 100644 --- a/core/src/test/scala/zio/jdbc/ZConnectionPoolSpec.scala +++ b/core/src/test/scala/zio/jdbc/ZConnectionPoolSpec.scala @@ -301,6 +301,32 @@ object ZConnectionPoolSpec extends ZIOSpecDefault { testResult <- asserttions } yield testResult } + + test("select all multiple in") { + val names1 = Vector(sherlockHolmes.name, johnWatson.name) + val names2 = Chunk(johnDoe.name) + val namesToSearch = Chunk.fromIterable(names1) ++ names2 + + for { + _ <- createUsers *> insertSherlock *> insertWatson *> insertJohn + users <- transaction { + sql"select name, age from users where name IN ($names1) OR name in ($names2)" + .query[User] + .selectAll + } + } yield assertTrue(users.map(_.name) == namesToSearch) + } + + test("select all in empty") { + val empty = Chunk.empty[String] + + for { + _ <- createUsers *> insertSherlock + users <- transaction { + sql"select name, age from users where name IN ($empty)" + .query[User] + .selectAll + } + } yield assertTrue(users.isEmpty) + } + test("select stream") { for { _ <- createUsers *> insertSherlock *> insertWatson