Skip to content

Commit

Permalink
Fix multiple IN interpolation (#190)
Browse files Browse the repository at this point in the history
* fix using multiple IN interpolations

* fix test:compile against scala v2.12.18

* extends tests to cover more cases of IN interpolation

* fix formatting
  • Loading branch information
lvitaly authored Dec 28, 2023
1 parent a56e4f3 commit 9da28a3
Show file tree
Hide file tree
Showing 4 changed files with 71 additions and 28 deletions.
17 changes: 9 additions & 8 deletions core/src/main/scala/zio/jdbc/SqlFragment.scala
Original file line number Diff line number Diff line change
Expand Up @@ -113,27 +113,28 @@ sealed trait SqlFragment { self =>
foreachSegment { syntax =>
sql.append(syntax.value)
} { param =>
var size = 0
param.value match {
case iterable: Iterable[_] =>
iterable.iterator.foreach { item =>
iterable.foreach { item =>
paramsBuilder += item.toString
size += 1
}
sql.append(
Seq.fill(iterable.iterator.size)("?").mkString(",")
)

case array: Array[_] =>
array.foreach { item =>
paramsBuilder += item.toString
size += 1
}
sql.append(
Seq.fill(array.length)("?").mkString(",")
)

case _ =>
sql.append("?")
paramsBuilder += param.value.toString
size += 1
}
sql.append(
if (size == 1) "?"
else Seq.fill(size)("?").mkString(",")
)
}

val params = paramsBuilder.result()
Expand Down
41 changes: 21 additions & 20 deletions core/src/main/scala/zio/jdbc/ZConnection.scala
Original file line number Diff line number Diff line change
Expand Up @@ -43,31 +43,32 @@ final class ZConnection(private[jdbc] val underlying: ZConnection.Restorable) ex
accessZIO { connection =>
for {
transactionIsolationLevel <- currentTransactionIsolationLevel.get
statement <- ZIO.acquireRelease(ZIO.attempt {
val sb = new StringBuilder()
sql.foreachSegment(syntax => sb.append(syntax.value)) { param =>
param.value match {
case iterable: Iterable[_] =>
sb.append(
Seq.fill(iterable.iterator.size)("?").mkString(", ")
)

case _ => sb.append("?")
statement <- ZIO
.acquireRelease(ZIO.attempt {
val sb = new StringBuilder()
sql.foreachSegment(syntax => sb.append(syntax.value)) { param =>
val placeholder = param.value match {
case iterable: Iterable[_] => Seq.fill(iterable.size)("?").mkString(", ")
case _ => "?"
}
sb.append(placeholder)
}
}
transactionIsolationLevel.foreach { transactionIsolationLevel =>
connection.setTransactionIsolation(transactionIsolationLevel.toInt)
}
connection.prepareStatement(
sb.toString,
if (returnAutoGeneratedKeys) Statement.RETURN_GENERATED_KEYS else Statement.NO_GENERATED_KEYS
)
})(statement => ZIO.attemptBlocking(statement.close()).ignoreLogged)
transactionIsolationLevel.foreach { transactionIsolationLevel =>
connection.setTransactionIsolation(transactionIsolationLevel.toInt)
}
connection.prepareStatement(
sb.result(),
if (returnAutoGeneratedKeys) Statement.RETURN_GENERATED_KEYS else Statement.NO_GENERATED_KEYS
)
})(statement => ZIO.attemptBlocking(statement.close()).ignoreLogged)
_ <- ZIO.attempt {
var paramIndex = 1
sql.foreachSegment(_ => ()) { param =>
param.setter.setValue(statement, paramIndex, param.value)
paramIndex += 1
paramIndex += (param.value match {
case iterable: Iterable[_] => iterable.size
case _ => 1
})
}
}
result <- ZIO
Expand Down
15 changes: 15 additions & 0 deletions core/src/test/scala/zio/jdbc/SqlFragmentSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,21 @@ object SqlFragmentSpec extends ZIOSpecDefault {
assertIn(List(1, 2, 3)) &&
assertIn(Vector(1, 2, 3)) &&
assertIn(Set(1, 2, 3))
} + test("interpolation params are supported multiple collections") {
val chunk = Chunk(1, 2, 3)
val list = List(4, 5)
val vector = Vector(6)
val set = Set(7, 8, 9)
assertTrue(
sql"select name, age from users where (1 = 0 or id in ($chunk) or id in ($list) or id in ($vector) or id in ($set))".toString ==
"Sql(select name, age from users where (1 = 0 or id in (?,?,?) or id in (?,?) or id in (?) or id in (?,?,?)), 1, 2, 3, 4, 5, 6, 7, 8, 9)"
)
} + test("interpolation param is supported empty collections") {
val empty = Chunk.empty[Int]
assertTrue(
sql"select name, age from users where id in ($empty)".toString ==
"Sql(select name, age from users where id in ())"
)
}
} +
test("not in") {
Expand Down
26 changes: 26 additions & 0 deletions core/src/test/scala/zio/jdbc/ZConnectionPoolSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,32 @@ object ZConnectionPoolSpec extends ZIOSpecDefault {
testResult <- asserttions
} yield testResult
} +
test("select all multiple in") {
val names1 = Vector(sherlockHolmes.name, johnWatson.name)
val names2 = Chunk(johnDoe.name)
val namesToSearch = Chunk.fromIterable(names1) ++ names2

for {
_ <- createUsers *> insertSherlock *> insertWatson *> insertJohn
users <- transaction {
sql"select name, age from users where name IN ($names1) OR name in ($names2)"
.query[User]
.selectAll
}
} yield assertTrue(users.map(_.name) == namesToSearch)
} +
test("select all in empty") {
val empty = Chunk.empty[String]

for {
_ <- createUsers *> insertSherlock
users <- transaction {
sql"select name, age from users where name IN ($empty)"
.query[User]
.selectAll
}
} yield assertTrue(users.isEmpty)
} +
test("select stream") {
for {
_ <- createUsers *> insertSherlock *> insertWatson
Expand Down

0 comments on commit 9da28a3

Please sign in to comment.