diff --git a/core/src/main/scala/zio/jdbc/SqlFragment.scala b/core/src/main/scala/zio/jdbc/SqlFragment.scala index 799a44d6..3dcdab4f 100644 --- a/core/src/main/scala/zio/jdbc/SqlFragment.scala +++ b/core/src/main/scala/zio/jdbc/SqlFragment.scala @@ -273,6 +273,7 @@ sealed trait SqlFragment { self => private[jdbc] def foreachSegment(addSyntax: Segment.Syntax => Any)(addParam: Segment.Param => Any): Unit = segments.foreach { + case Segment.Empty => () case syntax: Segment.Syntax => addSyntax(syntax) case param: Segment.Param => addParam(param) case nested: Segment.Nested => nested.sql.foreachSegment(addSyntax)(addParam) @@ -289,10 +290,13 @@ object SqlFragment { sealed trait Segment object Segment { + case object Empty extends Segment final case class Syntax(value: String) extends Segment final case class Param(value: Any, setter: Setter[Any]) extends Segment final case class Nested(sql: SqlFragment) extends Segment + @inline def empty: Segment = Empty + implicit def paramSegment[A](a: A)(implicit setter: Setter[A]): Segment.Param = Segment.Param(a, setter.asInstanceOf[Setter[Any]]) @@ -415,7 +419,7 @@ object SqlFragment { forSqlType((ps, i, value) => ps.setObject(i, value, Types.TIMESTAMP_WITH_TIMEZONE), Types.TIMESTAMP_WITH_TIMEZONE) } - def apply(sql: String): SqlFragment = sql + def apply(sql: String): SqlFragment = SqlFragment(Chunk.single(SqlFragment.Segment.Syntax(sql))) def apply(segments: Chunk[Segment]): SqlFragment = SqlFragment.Append(segments) diff --git a/core/src/main/scala/zio/jdbc/SqlInterpolator.scala b/core/src/main/scala/zio/jdbc/SqlInterpolator.scala index c3b0f5ad..bb5c90bb 100644 --- a/core/src/main/scala/zio/jdbc/SqlInterpolator.scala +++ b/core/src/main/scala/zio/jdbc/SqlInterpolator.scala @@ -29,10 +29,8 @@ final class SqlInterpolator(val context: StringContext) extends AnyVal { while (syntaxIterator.hasNext) { val syntax = syntaxIterator.next() - if (syntax.nonEmpty) { - chunkBuilder += SqlFragment.Segment.Syntax(syntax) - if (paramsIterator.hasNext) chunkBuilder += paramsIterator.next() - } + chunkBuilder += (if (syntax.isEmpty) SqlFragment.Segment.empty else SqlFragment.Segment.Syntax(syntax)) + if (paramsIterator.hasNext) chunkBuilder += paramsIterator.next() } while (paramsIterator.hasNext) chunkBuilder += paramsIterator.next() diff --git a/core/src/main/scala/zio/jdbc/package.scala b/core/src/main/scala/zio/jdbc/package.scala index 3f8898d8..649c38ee 100644 --- a/core/src/main/scala/zio/jdbc/package.scala +++ b/core/src/main/scala/zio/jdbc/package.scala @@ -24,7 +24,7 @@ package object jdbc { /** * Converts a String into a pure SQL expression */ - implicit def stringToSql(s: String): SqlFragment = SqlFragment(Chunk(SqlFragment.Segment.Syntax(s))) + implicit def stringToSql(s: String): SqlFragment = SqlFragment(s) /** * A new transaction, which may be applied to ZIO effects that require a diff --git a/core/src/test/scala/zio/jdbc/SqlFragmentSpec.scala b/core/src/test/scala/zio/jdbc/SqlFragmentSpec.scala index a5a0daa0..a59a33af 100644 --- a/core/src/test/scala/zio/jdbc/SqlFragmentSpec.scala +++ b/core/src/test/scala/zio/jdbc/SqlFragmentSpec.scala @@ -29,10 +29,15 @@ object SqlFragmentSpec extends ZIOSpecDefault { s"Sql(select name, age from users where id = ?, $id)" ) } + - test("ensure no empty Syntax instances") { + test("Empty Segment instances are insignificant") { val age = 42 val name = "sholmes" - assertTrue(sql"select name, age from users where age = $age and name = $name".segments.size == 4) + val sql = sql"select name, age from users where age = $age and name = $name" + assertTrue( + sql.segments.size == 5, + sql.segments.last == SqlFragment.Segment.Empty, + sql.toString == "Sql(select name, age from users where age = ? and name = ?, 42, sholmes)" + ) } + test("interpolate Sql values") { val tableName = sql"table1" @@ -247,6 +252,15 @@ object SqlFragmentSpec extends ZIOSpecDefault { assertTrue( result.toString == "Sql(UPDATE persons)" ) + } + + test("'interpolation <=> ++ operator' equivalence") { + val s1 = sql"${"1"}::varchar" + val s2 = (sql"${"1"}" ++ sql"::varchar") + + assertTrue( + s1.toString == "Sql(?::varchar, 1)", + s1.toString == s2.toString + ) } } + suite("SqlFragment ResultSet tests") {