diff --git a/relate/src/main/scala/com/lucidchart/relate/InterpolatedQuery.scala b/relate/src/main/scala/com/lucidchart/relate/InterpolatedQuery.scala index 4d07dd2..1130278 100644 --- a/relate/src/main/scala/com/lucidchart/relate/InterpolatedQuery.scala +++ b/relate/src/main/scala/com/lucidchart/relate/InterpolatedQuery.scala @@ -10,7 +10,7 @@ class InterpolatedQuery(protected val parsedQuery: String, protected val params: protected def applyParams(stmt: PreparedStatement) = parameterize(stmt, 1) - def appendPlaceholders(stringBuilder: StringBuilder) = stringBuilder ++= parsedQuery + def placeholder = parsedQuery def withTimeout(seconds: Int): InterpolatedQuery = new InterpolatedQuery(parsedQuery, params) { override protected def normalStatement(implicit conn: Connection) = new BaseStatement(conn) @@ -37,13 +37,8 @@ class InterpolatedQuery(protected val parsedQuery: String, protected val params: object InterpolatedQuery { def fromParts(parts: Seq[String], params: Seq[Parameter]) = { - val stringBuilder = new StringBuilder() - parts.zip(params).foreach { case (part, param) => - stringBuilder ++= part - param.appendPlaceholders(stringBuilder) - } - stringBuilder ++= parts.last - new InterpolatedQuery(stringBuilder.toString(), params) + val query = StringContext.standardInterpolator(identity, params.map(_.placeholder), parts) + new InterpolatedQuery(query, params) } } diff --git a/relate/src/main/scala/com/lucidchart/relate/Parameters.scala b/relate/src/main/scala/com/lucidchart/relate/Parameters.scala index 45fbac1..80f879d 100644 --- a/relate/src/main/scala/com/lucidchart/relate/Parameters.scala +++ b/relate/src/main/scala/com/lucidchart/relate/Parameters.scala @@ -31,7 +31,7 @@ import scala.language.implicitConversions */ trait Parameter { - def appendPlaceholders(stringBuilder: StringBuilder) + def placeholder: String def parameterize(statement: PreparedStatement, i: Int): Int } @@ -616,7 +616,7 @@ object Parameter { trait SingleParameter extends Parameter { protected[this] def set(statement: PreparedStatement, i: Int) - def appendPlaceholders(stringBuilder: StringBuilder) = stringBuilder.append("?") + def placeholder = "?" def parameterize(statement: PreparedStatement, i: Int) = { set(statement, i) i + 1 @@ -632,29 +632,17 @@ trait MultipleParameter extends Parameter { } } -class TupleParameter(val params: Iterable[SingleParameter]) extends MultipleParameter { - def appendPlaceholders(stringBuilder: StringBuilder) = - params.zipWithIndex.foreach { case (param, index) => - if (0 < index) { - stringBuilder.append(",") - } - param.appendPlaceholders(stringBuilder) - } +class TupleParameter(_params: Iterable[SingleParameter]) extends MultipleParameter { + // get a `Seq` to make sure placeholder and parameterize get the same ordering + override protected val params = _params.toSeq + def placeholder = params.iterator.map(_.placeholder).mkString(",") } object TupleParameter { def apply(params: SingleParameter*) = new TupleParameter(params) } -class TuplesParameter(val params: Iterable[TupleParameter]) extends MultipleParameter { - def appendPlaceholders(stringBuilder: StringBuilder) = { - if (params.nonEmpty) { - params.foreach { param => - stringBuilder.append("(") - param.appendPlaceholders(stringBuilder) - stringBuilder.append("),") - } - stringBuilder.setLength(stringBuilder.length - 1) - } - } +class TuplesParameter(_params: Iterable[TupleParameter]) extends MultipleParameter { + override protected val params = _params.toSeq + def placeholder = if (params.isEmpty) "" else params.map(_.placeholder).mkString("(", "),(", ")") } diff --git a/relate/src/test/scala/ParameterizationTest.scala b/relate/src/test/scala/ParameterizationTest.scala index 1b88bc2..11b2b27 100644 --- a/relate/src/test/scala/ParameterizationTest.scala +++ b/relate/src/test/scala/ParameterizationTest.scala @@ -2,6 +2,7 @@ package com.lucidchart.relate import java.sql.PreparedStatement import org.specs2.mutable._ +import scala.jdk.CollectionConverters._ class ParameterizationTest extends Specification { "parameter conversions" should { @@ -17,6 +18,29 @@ class ParameterizationTest extends Specification { val querySql = sql"INSERT INTO myTable (foo) VALUES ($longArrayParam)" querySql.toString mustEqual "INSERT INTO myTable (foo) VALUES (?,?,?)" } + + "interpolate HashSets properly" in { + val hashSet: Set[Int] = scala.collection.immutable.HashSet(1, 2, 3) + val querySql = sql"SELECT * FROM myTable WHERE id IN ($hashSet)" + querySql.toString mustEqual "SELECT * FROM myTable WHERE id IN (?,?,?)" + } + + "provide placeholders and parameters in the same order" in { + val setParams = scala.collection.mutable.Map.empty[Int, Int] + case class Param(int: Int) extends SingleParameter { + protected[this] def set(statement: PreparedStatement, i: Int) = setParams(i) = int + override def placeholder = int.toString + } + + val paramsSet: Parameter = Set(Param(1), Param(2), Param(3)) + paramsSet.parameterize(null, 1) + val query = sql"SELECT * FROM myTable WHERE id IN ($paramsSet)".toString + + setParams must haveSize(3) + val order = setParams.toSeq.sortBy(_._1).map(_._2).mkString(",") + // the order of the "placeholders" should match the order of the parameters + query mustEqual s"SELECT * FROM myTable WHERE id IN ($order)" + } } "tuple paramater" should { @@ -24,7 +48,7 @@ class ParameterizationTest extends Specification { class CustomParameter(value: Int) extends SingleParameter { protected[this] def set(statement: PreparedStatement, i: Int) = implicitly[Parameterizable[Int]].set(statement, i, value) - override def appendPlaceholders(stringBuilder: StringBuilder) = stringBuilder.append("?::smallint") + override def placeholder = "?::smallint" } val querySql = sql"INSERT INTO myTable (foo, bar) VALUES (${(1, new CustomParameter(1))})" querySql.toString mustEqual "INSERT INTO myTable (foo, bar) VALUES (?,?::smallint)"