Skip to content

Commit

Permalink
zio#439: Add support for var-arg db functions
Browse files Browse the repository at this point in the history
  • Loading branch information
jgoday committed May 4, 2021
1 parent 79c8045 commit 51b11c9
Show file tree
Hide file tree
Showing 6 changed files with 63 additions and 1 deletion.
13 changes: 12 additions & 1 deletion core/jvm/src/main/scala/zio/sql/expr.scala
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,11 @@ trait ExprModule extends NewtypesModule with FeaturesModule with OpsModule {
] {
def typeTag: TypeTag[Z] = implicitly[TypeTag[Z]]
}

sealed case class FunctionCallN[Z: TypeTag](param: Seq[Expr[_, _, _]], function: FunctionDefN[Z])
extends InvariantExpr[Any, Any, Z] {
def typeTag: TypeTag[Z] = implicitly[TypeTag[Z]]
}
}

sealed case class AggregationDef[-A, +B](name: FunctionName) { self =>
Expand Down Expand Up @@ -375,8 +380,12 @@ trait ExprModule extends NewtypesModule with FeaturesModule with OpsModule {
}
}

object FunctionDef {
sealed case class FunctionDefN[+B](name: FunctionName) { self =>
def apply[B1 >: B](param1: Expr[_, _, _]*)(implicit typeTag: TypeTag[B1]): Expr[Any, Any, B1] =
Expr.FunctionCallN(param1, self: FunctionDefN[B1])
}

object FunctionDef {
//math functions
val Abs = FunctionDef[Double, Double](FunctionName("abs"))
val Acos = FunctionDef[Double, Double](FunctionName("acos"))
Expand Down Expand Up @@ -418,6 +427,8 @@ trait ExprModule extends NewtypesModule with FeaturesModule with OpsModule {

// date functions
val CurrentTimestamp = FunctionDef[Nothing, Instant](FunctionName("current_timestamp"))

def variadicFunc[R](name: String) = FunctionDefN[R](FunctionName(name))
}

sealed trait Set[F, -A] {
Expand Down
8 changes: 8 additions & 0 deletions mysql/src/main/scala/zio/sql/mysql/MysqlModule.scala
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,14 @@ trait MysqlModule extends Jdbc { self =>
render(",")
renderExpr(param7)
render(")")
case Expr.FunctionCallN(params, function) =>
render(function.name.name)
render("(")
for { (p, i) <- params.zipWithIndex } {
if (i > 0) render(",")
renderExpr(p)
}
render(")")
}

private def renderLit[A, B](lit: self.Expr.Literal[_])(implicit render: Renderer): Unit = {
Expand Down
8 changes: 8 additions & 0 deletions oracle/src/main/scala/zio/sql/oracle/OracleModule.scala
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,14 @@ trait OracleModule extends Jdbc { self =>
builder.append(",")
buildExpr(param7, builder)
val _ = builder.append(")")
case Expr.FunctionCallN(params, function) =>
builder.append(function.name.name)
builder.append("(")
for { (p, i) <- params.zipWithIndex } {
if (i > 0) builder.append(",")
buildExpr(p, builder)
}
val _ = builder.append(")")
}

def buildReadString(read: self.Read[_], builder: StringBuilder): Unit =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -443,6 +443,14 @@ trait PostgresModule extends Jdbc { self =>
render(",")
renderExpr(param7)
render(")")
case Expr.FunctionCallN(params, function) =>
render(function.name.name)
render("(")
for { (p, i) <- params.zipWithIndex } {
if (i > 0) render(",")
renderExpr(p)
}
render(")")
}

private[zio] def renderReadImpl(read: self.Read[_])(implicit render: Renderer): Unit =
Expand Down
19 changes: 19 additions & 0 deletions postgres/src/test/scala/zio/sql/postgresql/VarArgSelectSpec.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
package zio.sql.postgresql

import zio.test.Assertion._
import zio.test._

object VarArgSelectSpec extends DefaultRunnableSpec with PostgresModule with ShopSchema {
import Customers._
import FunctionDef._

def spec = suite("Var args calls")(
test("works with simple select") {
val customConcat = variadicFunc[String]("concat")
val query = select(customConcat("Name: ", fName, lName)) from customers
assert(renderRead(query)) {
equalTo("SELECT concat('Name: ',customers.first_name,customers.last_name) FROM customers")
}
}
)
}
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,14 @@ trait SqlServerModule extends Jdbc { self =>
builder.append(",")
buildExpr(param7)
val _ = builder.append(")")
case Expr.FunctionCallN(params, function) =>
builder.append(function.name.name)
builder.append("(")
for { (p, i) <- params.zipWithIndex } {
if (i > 0) builder.append(",")
buildExpr(p)
}
val _ = builder.append(")")
}

def buildReadString(read: self.Read[_]): Unit =
Expand Down

0 comments on commit 51b11c9

Please sign in to comment.