Skip to content

Commit

Permalink
Fix #172 (#173)
Browse files Browse the repository at this point in the history
* Fix #172

* scalafmt

* clean
  • Loading branch information
guizmaii authored Oct 27, 2023
1 parent 6205129 commit c925ab4
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 8 deletions.
6 changes: 5 additions & 1 deletion core/src/main/scala/zio/jdbc/SqlFragment.scala
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,7 @@ sealed trait SqlFragment { self =>

private[jdbc] def foreachSegment(addSyntax: Segment.Syntax => Any)(addParam: Segment.Param => Any): Unit =
segments.foreach {
case Segment.Empty => ()
case syntax: Segment.Syntax => addSyntax(syntax)
case param: Segment.Param => addParam(param)
case nested: Segment.Nested => nested.sql.foreachSegment(addSyntax)(addParam)
Expand All @@ -289,10 +290,13 @@ object SqlFragment {

sealed trait Segment
object Segment {
case object Empty extends Segment
final case class Syntax(value: String) extends Segment
final case class Param(value: Any, setter: Setter[Any]) extends Segment
final case class Nested(sql: SqlFragment) extends Segment

@inline def empty: Segment = Empty

implicit def paramSegment[A](a: A)(implicit setter: Setter[A]): Segment.Param =
Segment.Param(a, setter.asInstanceOf[Setter[Any]])

Expand Down Expand Up @@ -415,7 +419,7 @@ object SqlFragment {
forSqlType((ps, i, value) => ps.setObject(i, value, Types.TIMESTAMP_WITH_TIMEZONE), Types.TIMESTAMP_WITH_TIMEZONE)
}

def apply(sql: String): SqlFragment = sql
def apply(sql: String): SqlFragment = SqlFragment(Chunk.single(SqlFragment.Segment.Syntax(sql)))

def apply(segments: Chunk[Segment]): SqlFragment =
SqlFragment.Append(segments)
Expand Down
6 changes: 2 additions & 4 deletions core/src/main/scala/zio/jdbc/SqlInterpolator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,8 @@ final class SqlInterpolator(val context: StringContext) extends AnyVal {

while (syntaxIterator.hasNext) {
val syntax = syntaxIterator.next()
if (syntax.nonEmpty) {
chunkBuilder += SqlFragment.Segment.Syntax(syntax)
if (paramsIterator.hasNext) chunkBuilder += paramsIterator.next()
}
chunkBuilder += (if (syntax.isEmpty) SqlFragment.Segment.empty else SqlFragment.Segment.Syntax(syntax))
if (paramsIterator.hasNext) chunkBuilder += paramsIterator.next()
}
while (paramsIterator.hasNext)
chunkBuilder += paramsIterator.next()
Expand Down
2 changes: 1 addition & 1 deletion core/src/main/scala/zio/jdbc/package.scala
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ package object jdbc {
/**
* Converts a String into a pure SQL expression
*/
implicit def stringToSql(s: String): SqlFragment = SqlFragment(Chunk(SqlFragment.Segment.Syntax(s)))
implicit def stringToSql(s: String): SqlFragment = SqlFragment(s)

/**
* A new transaction, which may be applied to ZIO effects that require a
Expand Down
18 changes: 16 additions & 2 deletions core/src/test/scala/zio/jdbc/SqlFragmentSpec.scala
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,15 @@ object SqlFragmentSpec extends ZIOSpecDefault {
s"Sql(select name, age from users where id = ?, $id)"
)
} +
test("ensure no empty Syntax instances") {
test("Empty Segment instances are insignificant") {
val age = 42
val name = "sholmes"
assertTrue(sql"select name, age from users where age = $age and name = $name".segments.size == 4)
val sql = sql"select name, age from users where age = $age and name = $name"
assertTrue(
sql.segments.size == 5,
sql.segments.last == SqlFragment.Segment.Empty,
sql.toString == "Sql(select name, age from users where age = ? and name = ?, 42, sholmes)"
)
} +
test("interpolate Sql values") {
val tableName = sql"table1"
Expand Down Expand Up @@ -247,6 +252,15 @@ object SqlFragmentSpec extends ZIOSpecDefault {
assertTrue(
result.toString == "Sql(UPDATE persons)"
)
} +
test("'interpolation <=> ++ operator' equivalence") {
val s1 = sql"${"1"}::varchar"
val s2 = (sql"${"1"}" ++ sql"::varchar")

assertTrue(
s1.toString == "Sql(?::varchar, 1)",
s1.toString == s2.toString
)
}
} +
suite("SqlFragment ResultSet tests") {
Expand Down

0 comments on commit c925ab4

Please sign in to comment.