From 30a167398de3bac335f864f3c81feef2f5410196 Mon Sep 17 00:00:00 2001 From: jules Ivanic Date: Fri, 27 Oct 2023 12:43:33 +0400 Subject: [PATCH] Fix #172 --- core/src/main/scala/zio/jdbc/SqlFragment.scala | 6 +++++- .../main/scala/zio/jdbc/SqlInterpolator.scala | 6 ++---- core/src/main/scala/zio/jdbc/package.scala | 2 +- .../test/scala/zio/jdbc/SqlFragmentSpec.scala | 18 ++++++++++++++++-- 4 files changed, 24 insertions(+), 8 deletions(-) diff --git a/core/src/main/scala/zio/jdbc/SqlFragment.scala b/core/src/main/scala/zio/jdbc/SqlFragment.scala index 3a60d70f..39b618b4 100644 --- a/core/src/main/scala/zio/jdbc/SqlFragment.scala +++ b/core/src/main/scala/zio/jdbc/SqlFragment.scala @@ -272,6 +272,7 @@ sealed trait SqlFragment { self => private[jdbc] def foreachSegment(addSyntax: Segment.Syntax => Any)(addParam: Segment.Param => Any): Unit = segments.foreach { + case Segment.Empty => () case syntax: Segment.Syntax => addSyntax(syntax) case param: Segment.Param => addParam(param) case nested: Segment.Nested => nested.sql.foreachSegment(addSyntax)(addParam) @@ -288,10 +289,13 @@ object SqlFragment { sealed trait Segment object Segment { + case object Empty extends Segment final case class Syntax(value: String) extends Segment final case class Param(value: Any, setter: Setter[Any]) extends Segment final case class Nested(sql: SqlFragment) extends Segment + @inline def empty: Segment = Empty + implicit def paramSegment[A](a: A)(implicit setter: Setter[A]): Segment.Param = Segment.Param(a, setter.asInstanceOf[Setter[Any]]) @@ -385,7 +389,7 @@ object SqlFragment { implicit val instantSetter: Setter[java.time.Instant] = sqlTimestampSetter.contramap(java.sql.Timestamp.from) } - def apply(sql: String): SqlFragment = sql + def apply(sql: String): SqlFragment = SqlFragment(Chunk.single(SqlFragment.Segment.Syntax(sql))) def apply(segments: Chunk[Segment]): SqlFragment = SqlFragment.Append(segments) diff --git a/core/src/main/scala/zio/jdbc/SqlInterpolator.scala b/core/src/main/scala/zio/jdbc/SqlInterpolator.scala index c3b0f5ad..bb5c90bb 100644 --- a/core/src/main/scala/zio/jdbc/SqlInterpolator.scala +++ b/core/src/main/scala/zio/jdbc/SqlInterpolator.scala @@ -29,10 +29,8 @@ final class SqlInterpolator(val context: StringContext) extends AnyVal { while (syntaxIterator.hasNext) { val syntax = syntaxIterator.next() - if (syntax.nonEmpty) { - chunkBuilder += SqlFragment.Segment.Syntax(syntax) - if (paramsIterator.hasNext) chunkBuilder += paramsIterator.next() - } + chunkBuilder += (if (syntax.isEmpty) SqlFragment.Segment.empty else SqlFragment.Segment.Syntax(syntax)) + if (paramsIterator.hasNext) chunkBuilder += paramsIterator.next() } while (paramsIterator.hasNext) chunkBuilder += paramsIterator.next() diff --git a/core/src/main/scala/zio/jdbc/package.scala b/core/src/main/scala/zio/jdbc/package.scala index 3f8898d8..649c38ee 100644 --- a/core/src/main/scala/zio/jdbc/package.scala +++ b/core/src/main/scala/zio/jdbc/package.scala @@ -24,7 +24,7 @@ package object jdbc { /** * Converts a String into a pure SQL expression */ - implicit def stringToSql(s: String): SqlFragment = SqlFragment(Chunk(SqlFragment.Segment.Syntax(s))) + implicit def stringToSql(s: String): SqlFragment = SqlFragment(s) /** * A new transaction, which may be applied to ZIO effects that require a diff --git a/core/src/test/scala/zio/jdbc/SqlFragmentSpec.scala b/core/src/test/scala/zio/jdbc/SqlFragmentSpec.scala index a5a0daa0..4785e625 100644 --- a/core/src/test/scala/zio/jdbc/SqlFragmentSpec.scala +++ b/core/src/test/scala/zio/jdbc/SqlFragmentSpec.scala @@ -29,10 +29,15 @@ object SqlFragmentSpec extends ZIOSpecDefault { s"Sql(select name, age from users where id = ?, $id)" ) } + - test("ensure no empty Syntax instances") { + test("Empty Segment instances are insignificant") { val age = 42 val name = "sholmes" - assertTrue(sql"select name, age from users where age = $age and name = $name".segments.size == 4) + val sql = sql"select name, age from users where age = $age and name = $name" + assertTrue( + sql.segments.size == 5, + sql.segments.last == SqlFragment.Segment.Empty, // Empty Syntax instances are insignificant + sql.toString == "Sql(select name, age from users where age = ? and name = ?, 42, sholmes)" + ) } + test("interpolate Sql values") { val tableName = sql"table1" @@ -247,6 +252,15 @@ object SqlFragmentSpec extends ZIOSpecDefault { assertTrue( result.toString == "Sql(UPDATE persons)" ) + } + + test("'interpolation <=> ++ operator' equivalence") { + val s1 = sql"${"1"}::varchar" + val s2 = (sql"${"1"}" ++ sql"::varchar") + + assertTrue( + s1.toString == "Sql(?::varchar, 1)", + s1.toString == s2.toString + ) } } + suite("SqlFragment ResultSet tests") {