diff --git a/relate/src/main/scala/com/lucidchart/relate/InterpolatedQuery.scala b/relate/src/main/scala/com/lucidchart/relate/InterpolatedQuery.scala index 4d07dd2..1130278 100644 --- a/relate/src/main/scala/com/lucidchart/relate/InterpolatedQuery.scala +++ b/relate/src/main/scala/com/lucidchart/relate/InterpolatedQuery.scala @@ -10,7 +10,7 @@ class InterpolatedQuery(protected val parsedQuery: String, protected val params: protected def applyParams(stmt: PreparedStatement) = parameterize(stmt, 1) - def appendPlaceholders(stringBuilder: StringBuilder) = stringBuilder ++= parsedQuery + def placeholder = parsedQuery def withTimeout(seconds: Int): InterpolatedQuery = new InterpolatedQuery(parsedQuery, params) { override protected def normalStatement(implicit conn: Connection) = new BaseStatement(conn) @@ -37,13 +37,8 @@ class InterpolatedQuery(protected val parsedQuery: String, protected val params: object InterpolatedQuery { def fromParts(parts: Seq[String], params: Seq[Parameter]) = { - val stringBuilder = new StringBuilder() - parts.zip(params).foreach { case (part, param) => - stringBuilder ++= part - param.appendPlaceholders(stringBuilder) - } - stringBuilder ++= parts.last - new InterpolatedQuery(stringBuilder.toString(), params) + val query = StringContext.standardInterpolator(identity, params.map(_.placeholder), parts) + new InterpolatedQuery(query, params) } } diff --git a/relate/src/main/scala/com/lucidchart/relate/Parameters.scala b/relate/src/main/scala/com/lucidchart/relate/Parameters.scala index b9ca037..cc45908 100644 --- a/relate/src/main/scala/com/lucidchart/relate/Parameters.scala +++ b/relate/src/main/scala/com/lucidchart/relate/Parameters.scala @@ -31,7 +31,7 @@ import scala.language.implicitConversions */ trait Parameter { - def appendPlaceholders(stringBuilder: StringBuilder) + def placeholder: String def parameterize(statement: PreparedStatement, i: Int): Int } @@ -616,7 +616,7 @@ object Parameter { trait SingleParameter extends Parameter { protected[this] def set(statement: PreparedStatement, i: Int) - def appendPlaceholders(stringBuilder: StringBuilder) = stringBuilder.append("?") + def placeholder = "?" def parameterize(statement: PreparedStatement, i: Int) = { set(statement, i) i + 1 @@ -633,15 +633,7 @@ trait MultipleParameter extends Parameter { } class TupleParameter(val params: Iterable[SingleParameter]) extends MultipleParameter { - def appendPlaceholders(stringBuilder: StringBuilder) = - // if we don't use the iterator, we won't necessarily get a consistent iteration order: the element with index 0 - // according to zipWithIndex might not be the first element handled by the foreach - params.iterator.zipWithIndex.foreach { case (param, index) => - if (0 < index) { - stringBuilder.append(",") - } - param.appendPlaceholders(stringBuilder) - } + def placeholder = params.iterator.map(_.placeholder).mkString(",") } object TupleParameter { @@ -649,14 +641,5 @@ object TupleParameter { } class TuplesParameter(val params: Iterable[TupleParameter]) extends MultipleParameter { - def appendPlaceholders(stringBuilder: StringBuilder) = { - if (params.nonEmpty) { - params.foreach { param => - stringBuilder.append("(") - param.appendPlaceholders(stringBuilder) - stringBuilder.append("),") - } - stringBuilder.setLength(stringBuilder.length - 1) - } - } + def placeholder = if (params.isEmpty) "" else params.map(_.placeholder).mkString("(", "),(", ")") } diff --git a/relate/src/test/scala/ParameterizationTest.scala b/relate/src/test/scala/ParameterizationTest.scala index 44730f2..ada03db 100644 --- a/relate/src/test/scala/ParameterizationTest.scala +++ b/relate/src/test/scala/ParameterizationTest.scala @@ -20,11 +20,7 @@ class ParameterizationTest extends Specification { } "interpolate HashSets properly" in { - // note that HashSets don't iterate consistently: zipWithIndex and head get different "first" elements - // (also that zipWithIndex on a HashSet returns another HashSet) val hashSet: Set[Int] = scala.collection.immutable.HashSet(1, 2, 3) - hashSet.zipWithIndex.head._2 mustNotEqual 0 - // even so, we should interpolate it correctly val querySql = sql"SELECT * FROM myTable WHERE id IN ($hashSet)" querySql.toString mustEqual "SELECT * FROM myTable WHERE id IN (?,?,?)" } @@ -35,7 +31,7 @@ class ParameterizationTest extends Specification { class CustomParameter(value: Int) extends SingleParameter { protected[this] def set(statement: PreparedStatement, i: Int) = implicitly[Parameterizable[Int]].set(statement, i, value) - override def appendPlaceholders(stringBuilder: StringBuilder) = stringBuilder.append("?::smallint") + override def placeholder = "?::smallint" } val querySql = sql"INSERT INTO myTable (foo, bar) VALUES (${(1, new CustomParameter(1))})" querySql.toString mustEqual "INSERT INTO myTable (foo, bar) VALUES (?,?::smallint)"