Skip to content

Commit

Permalink
Merge pull request #93 from lucidsoftware/cwoodfield-java-set-iteration
Browse files Browse the repository at this point in the history
Fix interpolation of iterables with inconsistent iteration order
  • Loading branch information
coreywoodfield authored Nov 20, 2024
2 parents 5ef8dbc + 383d813 commit 959e783
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ class InterpolatedQuery(protected val parsedQuery: String, protected val params:

protected def applyParams(stmt: PreparedStatement) = parameterize(stmt, 1)

def appendPlaceholders(stringBuilder: StringBuilder) = stringBuilder ++= parsedQuery
def placeholder = parsedQuery

def withTimeout(seconds: Int): InterpolatedQuery = new InterpolatedQuery(parsedQuery, params) {
override protected def normalStatement(implicit conn: Connection) = new BaseStatement(conn)
Expand All @@ -37,13 +37,8 @@ class InterpolatedQuery(protected val parsedQuery: String, protected val params:
object InterpolatedQuery {

def fromParts(parts: Seq[String], params: Seq[Parameter]) = {
val stringBuilder = new StringBuilder()
parts.zip(params).foreach { case (part, param) =>
stringBuilder ++= part
param.appendPlaceholders(stringBuilder)
}
stringBuilder ++= parts.last
new InterpolatedQuery(stringBuilder.toString(), params)
val query = StringContext.standardInterpolator(identity, params.map(_.placeholder), parts)
new InterpolatedQuery(query, params)
}

}
30 changes: 9 additions & 21 deletions relate/src/main/scala/com/lucidchart/relate/Parameters.scala
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ import scala.language.implicitConversions
*/

trait Parameter {
def appendPlaceholders(stringBuilder: StringBuilder)
def placeholder: String
def parameterize(statement: PreparedStatement, i: Int): Int
}

Expand Down Expand Up @@ -616,7 +616,7 @@ object Parameter {
trait SingleParameter extends Parameter {
protected[this] def set(statement: PreparedStatement, i: Int)

def appendPlaceholders(stringBuilder: StringBuilder) = stringBuilder.append("?")
def placeholder = "?"
def parameterize(statement: PreparedStatement, i: Int) = {
set(statement, i)
i + 1
Expand All @@ -632,29 +632,17 @@ trait MultipleParameter extends Parameter {
}
}

class TupleParameter(val params: Iterable[SingleParameter]) extends MultipleParameter {
def appendPlaceholders(stringBuilder: StringBuilder) =
params.zipWithIndex.foreach { case (param, index) =>
if (0 < index) {
stringBuilder.append(",")
}
param.appendPlaceholders(stringBuilder)
}
class TupleParameter(_params: Iterable[SingleParameter]) extends MultipleParameter {
// get a `Seq` to make sure placeholder and parameterize get the same ordering
override protected val params = _params.toSeq
def placeholder = params.iterator.map(_.placeholder).mkString(",")
}

object TupleParameter {
def apply(params: SingleParameter*) = new TupleParameter(params)
}

class TuplesParameter(val params: Iterable[TupleParameter]) extends MultipleParameter {
def appendPlaceholders(stringBuilder: StringBuilder) = {
if (params.nonEmpty) {
params.foreach { param =>
stringBuilder.append("(")
param.appendPlaceholders(stringBuilder)
stringBuilder.append("),")
}
stringBuilder.setLength(stringBuilder.length - 1)
}
}
class TuplesParameter(_params: Iterable[TupleParameter]) extends MultipleParameter {
override protected val params = _params.toSeq
def placeholder = if (params.isEmpty) "" else params.map(_.placeholder).mkString("(", "),(", ")")
}
26 changes: 25 additions & 1 deletion relate/src/test/scala/ParameterizationTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package com.lucidchart.relate

import java.sql.PreparedStatement
import org.specs2.mutable._
import scala.jdk.CollectionConverters._

class ParameterizationTest extends Specification {
"parameter conversions" should {
Expand All @@ -17,14 +18,37 @@ class ParameterizationTest extends Specification {
val querySql = sql"INSERT INTO myTable (foo) VALUES ($longArrayParam)"
querySql.toString mustEqual "INSERT INTO myTable (foo) VALUES (?,?,?)"
}

"interpolate HashSets properly" in {
val hashSet: Set[Int] = scala.collection.immutable.HashSet(1, 2, 3)
val querySql = sql"SELECT * FROM myTable WHERE id IN ($hashSet)"
querySql.toString mustEqual "SELECT * FROM myTable WHERE id IN (?,?,?)"
}

"provide placeholders and parameters in the same order" in {
val setParams = scala.collection.mutable.Map.empty[Int, Int]
case class Param(int: Int) extends SingleParameter {
protected[this] def set(statement: PreparedStatement, i: Int) = setParams(i) = int
override def placeholder = int.toString
}

val paramsSet: Parameter = Set(Param(1), Param(2), Param(3))
paramsSet.parameterize(null, 1)
val query = sql"SELECT * FROM myTable WHERE id IN ($paramsSet)".toString

setParams must haveSize(3)
val order = setParams.toSeq.sortBy(_._1).map(_._2).mkString(",")
// the order of the "placeholders" should match the order of the parameters
query mustEqual s"SELECT * FROM myTable WHERE id IN ($order)"
}
}

"tuple paramater" should {
"use sub-parameter placeholders" in {
class CustomParameter(value: Int) extends SingleParameter {
protected[this] def set(statement: PreparedStatement, i: Int) =
implicitly[Parameterizable[Int]].set(statement, i, value)
override def appendPlaceholders(stringBuilder: StringBuilder) = stringBuilder.append("?::smallint")
override def placeholder = "?::smallint"
}
val querySql = sql"INSERT INTO myTable (foo, bar) VALUES (${(1, new CustomParameter(1))})"
querySql.toString mustEqual "INSERT INTO myTable (foo, bar) VALUES (?,?::smallint)"
Expand Down

0 comments on commit 959e783

Please sign in to comment.