From 9d9d91b8b5c3b563cb55aa6ab8a64b1804918fc2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=8BAndrzej=20Ressel?= Date: Sun, 7 May 2023 14:15:05 +0200 Subject: [PATCH 1/3] Support custom JdbcEncoder when creating sql segment --- core/src/main/scala/zio/jdbc/JdbcEncoder.scala | 3 +++ core/src/main/scala/zio/jdbc/SqlFragment.scala | 4 ---- core/src/main/scala/zio/jdbc/package.scala | 17 ++++++++++++++++- .../test/scala/zio/jdbc/SqlFragmentSpec.scala | 17 +++++++++++++++++ 4 files changed, 36 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/zio/jdbc/JdbcEncoder.scala b/core/src/main/scala/zio/jdbc/JdbcEncoder.scala index 13a18a05..81eb0a07 100644 --- a/core/src/main/scala/zio/jdbc/JdbcEncoder.scala +++ b/core/src/main/scala/zio/jdbc/JdbcEncoder.scala @@ -21,6 +21,9 @@ import zio.schema.{ Schema, StandardType } /** * A type class that describes the ability to convert a value of type `A` into * a fragment of SQL. This is useful for forming SQL insert statements. + * + * NOTE: Users should be careful when creating custom JdbcEncoders for already existing types. You should either + * also create implicit Setter instance or use AnyVal wrapper (see `Custom JdbcEncoder` test). */ trait JdbcEncoder[-A] { def encode(value: A): SqlFragment diff --git a/core/src/main/scala/zio/jdbc/SqlFragment.scala b/core/src/main/scala/zio/jdbc/SqlFragment.scala index d09ab43b..50ab0225 100644 --- a/core/src/main/scala/zio/jdbc/SqlFragment.scala +++ b/core/src/main/scala/zio/jdbc/SqlFragment.scala @@ -208,10 +208,6 @@ object SqlFragment { final case class Param(value: Any, setter: Setter[Any]) extends Segment final case class Nested(sql: SqlFragment) extends Segment - implicit def paramSegment[A](a: A)(implicit setter: Setter[A]): Segment.Param = - Segment.Param(a, setter.asInstanceOf[Setter[Any]]) - - implicit def nestedSqlSegment[A](sql: SqlFragment): Segment.Nested = Segment.Nested(sql) } trait Setter[A] { self => diff --git a/core/src/main/scala/zio/jdbc/package.scala b/core/src/main/scala/zio/jdbc/package.scala index 03daba92..7bb4912d 100644 --- a/core/src/main/scala/zio/jdbc/package.scala +++ b/core/src/main/scala/zio/jdbc/package.scala @@ -15,9 +15,12 @@ */ package zio +import zio.jdbc.SqlFragment.{ Segment, Setter } +import zio.jdbc.{ JdbcEncoder, SqlFragment } + import scala.language.implicitConversions -package object jdbc { +package object jdbc extends LowPriorityImplicits1 { implicit def sqlInterpolator(sc: StringContext): SqlInterpolator = new SqlInterpolator(sc) @@ -34,3 +37,15 @@ package object jdbc { ZLayer(ZIO.serviceWith[ZConnectionPool](_.transaction)).flatten } +trait LowPriorityImplicits1 extends LowPriorityImplicits2 { + + implicit def paramSegment[A](a: A)(implicit setter: Setter[A]): Segment.Param = + Segment.Param(a, setter.asInstanceOf[Setter[Any]]) + + implicit def nestedSqlSegment[A](sql: SqlFragment): Segment.Nested = Segment.Nested(sql) +} + +trait LowPriorityImplicits2 { + implicit def segmentFromJdbcEncoder[A](obj: A)(implicit encoder: JdbcEncoder[A]): Segment = + Segment.Nested(encoder.encode(obj)) +} diff --git a/core/src/test/scala/zio/jdbc/SqlFragmentSpec.scala b/core/src/test/scala/zio/jdbc/SqlFragmentSpec.scala index b8502a9d..1f0f2b80 100644 --- a/core/src/test/scala/zio/jdbc/SqlFragmentSpec.scala +++ b/core/src/test/scala/zio/jdbc/SqlFragmentSpec.scala @@ -6,6 +6,7 @@ import zio.test.Assertion._ import zio.test._ import java.sql.SQLException +import java.util.Base64 final case class Person(name: String, age: Int) final case class UserLogin(username: String, password: String) @@ -222,9 +223,25 @@ object SqlFragmentSpec extends ZIOSpecDefault { assertTrue( result.toString == "Sql(UPDATE persons)" ) + } + + test("Custom JdbcEncoder") { + + implicit val byteArrayjdbcEncoder: JdbcEncoder[Base64Array] = + value => s"FROM_BASE64('${Base64.getEncoder.encodeToString(value.arr)}')" + + val bytes = Base64Array(Array[Byte](1, 2, 3)) + + val result = sql"UPDATE foo SET bytes = $bytes" + + assertTrue( + result.toString == "Sql(UPDATE foo SET bytes = FROM_BASE64('AQID'))" + ) } } + + case class Base64Array(arr: Array[Byte]) extends AnyVal + } object Models { From 88bd154fd9d82f6ddc0ec3c49a415872a89804c4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=8BAndrzej=20Ressel?= Date: Tue, 9 May 2023 00:05:11 +0200 Subject: [PATCH 2/3] Support custom JDBCEncoder when creating sql segment Combine Setter into JdbcEncoder --- .../src/main/scala/zio/jdbc/JdbcEncoder.scala | 168 ++++++++++++------ .../src/main/scala/zio/jdbc/SqlFragment.scala | 34 ++-- .../src/main/scala/zio/jdbc/ZConnection.scala | 2 +- core/src/main/scala/zio/jdbc/package.scala | 17 +- .../test/scala/zio/jdbc/SqlFragmentSpec.scala | 11 +- .../scala/zio/jdbc/ZConnectionPoolSpec.scala | 8 +- 6 files changed, 149 insertions(+), 91 deletions(-) diff --git a/core/src/main/scala/zio/jdbc/JdbcEncoder.scala b/core/src/main/scala/zio/jdbc/JdbcEncoder.scala index 81eb0a07..1f193a24 100644 --- a/core/src/main/scala/zio/jdbc/JdbcEncoder.scala +++ b/core/src/main/scala/zio/jdbc/JdbcEncoder.scala @@ -16,73 +16,123 @@ package zio.jdbc import zio.Chunk +import zio.jdbc.SqlFragment.{ Segment, Setter } import zio.schema.{ Schema, StandardType } +import java.sql.PreparedStatement + /** * A type class that describes the ability to convert a value of type `A` into * a fragment of SQL. This is useful for forming SQL insert statements. - * - * NOTE: Users should be careful when creating custom JdbcEncoders for already existing types. You should either - * also create implicit Setter instance or use AnyVal wrapper (see `Custom JdbcEncoder` test). */ -trait JdbcEncoder[-A] { +trait JdbcEncoder[A] { def encode(value: A): SqlFragment + val setter: Option[Setter[A]] + + final def contramap[B](f: B => A): JdbcEncoder[B] = { + val that = this + new JdbcEncoder[B] { + override def encode(value: B): SqlFragment = that.encode(f(value)) - final def contramap[B](f: B => A): JdbcEncoder[B] = value => encode(f(value)) + override val setter: Option[Setter[B]] = that.setter.map(_.contramap(f)) + } + } } object JdbcEncoder extends JdbcEncoder0LowPriorityImplicits { def apply[A]()(implicit encoder: JdbcEncoder[A]): JdbcEncoder[A] = encoder - implicit val intEncoder: JdbcEncoder[Int] = value => sql"$value" - implicit val longEncoder: JdbcEncoder[Long] = value => sql"$value" - implicit val doubleEncoder: JdbcEncoder[Double] = value => sql"$value" - implicit val charEncoder: JdbcEncoder[Char] = value => sql"$value" - implicit val stringEncoder: JdbcEncoder[String] = value => sql"$value" - implicit val booleanEncoder: JdbcEncoder[Boolean] = value => sql"$value" - implicit val bigIntEncoder: JdbcEncoder[java.math.BigInteger] = value => sql"$value" - implicit val bigDecimalEncoder: JdbcEncoder[java.math.BigDecimal] = value => sql"$value" - implicit val bigDecimalEncoderScala: JdbcEncoder[scala.math.BigDecimal] = value => sql"$value" - implicit val shortEncoder: JdbcEncoder[Short] = value => sql"$value" - implicit val floatEncoder: JdbcEncoder[Float] = value => sql"$value" - implicit val byteEncoder: JdbcEncoder[Byte] = value => sql"$value" - implicit val byteArrayEncoder: JdbcEncoder[Array[Byte]] = value => sql"$value" - implicit val byteChunkEncoder: JdbcEncoder[Chunk[Byte]] = value => sql"$value" - implicit val blobEncoder: JdbcEncoder[java.sql.Blob] = value => sql"$value" - implicit val uuidEncoder: JdbcEncoder[java.util.UUID] = value => sql"$value" - - implicit def singleParamEncoder[A: SqlFragment.Setter]: JdbcEncoder[A] = value => sql"$value" - - // TODO: review for cases like Option of a tuple - def optionEncoder[A](implicit encoder: JdbcEncoder[A]): JdbcEncoder[Option[A]] = - value => value.fold(SqlFragment.nullLiteral)(encoder.encode) + /** + * Use caution when using this method. Returning interpolating string may result in SQL injection attacks + */ + def apply[A](onEncode: A => SqlFragment): JdbcEncoder[A] = new JdbcEncoder[A] { + override def encode(value: A): SqlFragment = onEncode(value) + + override val setter: Option[Setter[A]] = None + } + + def apply[A]( + onEncode: A => SqlFragment, + onValue: (PreparedStatement, Int, A) => Unit, + onNull: (PreparedStatement, Int) => Unit + ): JdbcEncoder[A] = new JdbcEncoder[A] { + override def encode(value: A): SqlFragment = onEncode(value) + + override val setter: Option[Setter[A]] = Some(new Setter[A] { + override def unsafeSetValue(ps: PreparedStatement, index: Int, value: A): Unit = onValue(ps, index, value) + + override def unsafeSetNull(ps: PreparedStatement, index: Int): Unit = onNull(ps, index) + }) + } + + private def withSetter[A](implicit setter: Setter[A]): JdbcEncoder[A] = + apply( + value => SqlFragment(Chunk.apply(Segment.Param(value, setter.asInstanceOf[Setter[Any]]))), + setter.unsafeSetValue, + setter.unsafeSetNull + ) + + implicit val intEncoder: JdbcEncoder[Int] = withSetter + implicit val longEncoder: JdbcEncoder[Long] = withSetter + implicit val doubleEncoder: JdbcEncoder[Double] = withSetter + implicit val charEncoder: JdbcEncoder[Char] = withSetter + implicit val stringEncoder: JdbcEncoder[String] = withSetter + implicit val booleanEncoder: JdbcEncoder[Boolean] = withSetter + implicit val bigIntEncoder: JdbcEncoder[java.math.BigInteger] = withSetter + implicit val bigDecimalEncoder: JdbcEncoder[java.math.BigDecimal] = withSetter + implicit val bigDecimalEncoderScala: JdbcEncoder[scala.math.BigDecimal] = withSetter + implicit val shortEncoder: JdbcEncoder[Short] = withSetter + implicit val floatEncoder: JdbcEncoder[Float] = withSetter + implicit val byteEncoder: JdbcEncoder[Byte] = withSetter + implicit val byteArrayEncoder: JdbcEncoder[Array[Byte]] = withSetter + implicit val byteChunkEncoder: JdbcEncoder[Chunk[Byte]] = withSetter + implicit val blobEncoder: JdbcEncoder[java.sql.Blob] = withSetter + implicit val uuidEncoder: JdbcEncoder[java.util.UUID] = withSetter + + private def optionParamSetter[A](implicit setter: Setter[A]): Setter[Option[A]] = + Setter( + (ps, i, value) => + value match { + case Some(value) => setter.unsafeSetValue(ps, i, value) + case None => setter.unsafeSetNull(ps, i) + }, + (ps, i) => setter.unsafeSetNull(ps, i) + ) + implicit def optionEncoder[A](implicit encoder: JdbcEncoder[A]): JdbcEncoder[Option[A]] = new JdbcEncoder[Option[A]] { + override def encode(value: Option[A]): SqlFragment = value.fold(SqlFragment.nullLiteral)(encoder.encode) + + override val setter: Option[Setter[Option[A]]] = encoder.setter.map(optionParamSetter(_)) + } implicit def tuple2Encoder[A: JdbcEncoder, B: JdbcEncoder]: JdbcEncoder[(A, B)] = - tuple => JdbcEncoder[A]().encode(tuple._1) ++ SqlFragment.comma ++ JdbcEncoder[B]().encode(tuple._2) + JdbcEncoder(tuple => JdbcEncoder[A]().encode(tuple._1) ++ SqlFragment.comma ++ JdbcEncoder[B]().encode(tuple._2)) implicit def tuple3Encoder[A: JdbcEncoder, B: JdbcEncoder, C: JdbcEncoder]: JdbcEncoder[(A, B, C)] = - tuple => + JdbcEncoder(tuple => JdbcEncoder[A]().encode(tuple._1) ++ SqlFragment.comma ++ JdbcEncoder[B]().encode( tuple._2 ) ++ SqlFragment.comma ++ JdbcEncoder[C]().encode(tuple._3) + ) implicit def tuple4Encoder[A: JdbcEncoder, B: JdbcEncoder, C: JdbcEncoder, D: JdbcEncoder] : JdbcEncoder[(A, B, C, D)] = - tuple => + JdbcEncoder(tuple => JdbcEncoder[A]().encode(tuple._1) ++ SqlFragment.comma ++ JdbcEncoder[B]().encode( tuple._2 ) ++ SqlFragment.comma ++ JdbcEncoder[C]().encode(tuple._3) ++ SqlFragment.comma ++ JdbcEncoder[D]().encode( tuple._4 ) + ) implicit def tuple5Encoder[A: JdbcEncoder, B: JdbcEncoder, C: JdbcEncoder, D: JdbcEncoder, E: JdbcEncoder] : JdbcEncoder[(A, B, C, D, E)] = - tuple => + JdbcEncoder(tuple => JdbcEncoder[A]().encode(tuple._1) ++ SqlFragment.comma ++ JdbcEncoder[B]().encode( tuple._2 ) ++ SqlFragment.comma ++ JdbcEncoder[C]().encode(tuple._3) ++ SqlFragment.comma ++ JdbcEncoder[D]().encode( tuple._4 ) ++ SqlFragment.comma ++ JdbcEncoder[E]().encode(tuple._5) + ) implicit def tuple6Encoder[ A: JdbcEncoder, @@ -92,7 +142,7 @@ object JdbcEncoder extends JdbcEncoder0LowPriorityImplicits { E: JdbcEncoder, F: JdbcEncoder ]: JdbcEncoder[(A, B, C, D, E, F)] = - tuple => + JdbcEncoder(tuple => JdbcEncoder[A]().encode(tuple._1) ++ SqlFragment.comma ++ JdbcEncoder[B]().encode( tuple._2 ) ++ SqlFragment.comma ++ JdbcEncoder[C]().encode(tuple._3) ++ SqlFragment.comma ++ JdbcEncoder[D]().encode( @@ -100,6 +150,7 @@ object JdbcEncoder extends JdbcEncoder0LowPriorityImplicits { ) ++ SqlFragment.comma ++ JdbcEncoder[E]().encode(tuple._5) ++ SqlFragment.comma ++ JdbcEncoder[F]().encode( tuple._6 ) + ) implicit def tuple7Encoder[ A: JdbcEncoder, @@ -110,7 +161,7 @@ object JdbcEncoder extends JdbcEncoder0LowPriorityImplicits { F: JdbcEncoder, G: JdbcEncoder ]: JdbcEncoder[(A, B, C, D, E, F, G)] = - tuple => + JdbcEncoder(tuple => JdbcEncoder[A]().encode(tuple._1) ++ SqlFragment.comma ++ JdbcEncoder[B]().encode( tuple._2 ) ++ SqlFragment.comma ++ JdbcEncoder[C]().encode(tuple._3) ++ SqlFragment.comma ++ JdbcEncoder[D]().encode( @@ -118,6 +169,7 @@ object JdbcEncoder extends JdbcEncoder0LowPriorityImplicits { ) ++ SqlFragment.comma ++ JdbcEncoder[E]().encode(tuple._5) ++ SqlFragment.comma ++ JdbcEncoder[F]().encode( tuple._6 ) ++ SqlFragment.comma ++ JdbcEncoder[G]().encode(tuple._7) + ) implicit def tuple8Encoder[ A: JdbcEncoder, @@ -129,7 +181,7 @@ object JdbcEncoder extends JdbcEncoder0LowPriorityImplicits { G: JdbcEncoder, H: JdbcEncoder ]: JdbcEncoder[(A, B, C, D, E, F, G, H)] = - tuple => + JdbcEncoder(tuple => JdbcEncoder[A]().encode(tuple._1) ++ SqlFragment.comma ++ JdbcEncoder[B]().encode( tuple._2 ) ++ SqlFragment.comma ++ JdbcEncoder[C]().encode(tuple._3) ++ SqlFragment.comma ++ JdbcEncoder[D]().encode( @@ -139,6 +191,7 @@ object JdbcEncoder extends JdbcEncoder0LowPriorityImplicits { ) ++ SqlFragment.comma ++ JdbcEncoder[G]().encode(tuple._7) ++ SqlFragment.comma ++ JdbcEncoder[H]().encode( tuple._8 ) + ) implicit def tuple9Encoder[ A: JdbcEncoder, @@ -151,7 +204,7 @@ object JdbcEncoder extends JdbcEncoder0LowPriorityImplicits { H: JdbcEncoder, I: JdbcEncoder ]: JdbcEncoder[(A, B, C, D, E, F, G, H, I)] = - tuple => + JdbcEncoder(tuple => JdbcEncoder[A]().encode(tuple._1) ++ SqlFragment.comma ++ JdbcEncoder[B]().encode( tuple._2 ) ++ SqlFragment.comma ++ JdbcEncoder[C]().encode(tuple._3) ++ SqlFragment.comma ++ JdbcEncoder[D]().encode( @@ -161,6 +214,7 @@ object JdbcEncoder extends JdbcEncoder0LowPriorityImplicits { ) ++ SqlFragment.comma ++ JdbcEncoder[G]().encode(tuple._7) ++ SqlFragment.comma ++ JdbcEncoder[H]().encode( tuple._8 ) ++ SqlFragment.comma ++ JdbcEncoder[I]().encode(tuple._9) + ) implicit def tuple10Encoder[ A: JdbcEncoder, @@ -174,7 +228,7 @@ object JdbcEncoder extends JdbcEncoder0LowPriorityImplicits { I: JdbcEncoder, J: JdbcEncoder ]: JdbcEncoder[(A, B, C, D, E, F, G, H, I, J)] = - tuple => + JdbcEncoder(tuple => JdbcEncoder[A]().encode(tuple._1) ++ SqlFragment.comma ++ JdbcEncoder[B]().encode( tuple._2 ) ++ SqlFragment.comma ++ JdbcEncoder[C]().encode(tuple._3) ++ SqlFragment.comma ++ JdbcEncoder[D]().encode( @@ -186,6 +240,7 @@ object JdbcEncoder extends JdbcEncoder0LowPriorityImplicits { ) ++ SqlFragment.comma ++ JdbcEncoder[I]().encode(tuple._9) ++ SqlFragment.comma ++ JdbcEncoder[J]().encode( tuple._10 ) + ) implicit def tuple11Encoder[ A: JdbcEncoder, @@ -200,7 +255,7 @@ object JdbcEncoder extends JdbcEncoder0LowPriorityImplicits { J: JdbcEncoder, K: JdbcEncoder ]: JdbcEncoder[(A, B, C, D, E, F, G, H, I, J, K)] = - tuple => + JdbcEncoder(tuple => JdbcEncoder[A]().encode(tuple._1) ++ SqlFragment.comma ++ JdbcEncoder[B]().encode( tuple._2 ) ++ SqlFragment.comma ++ JdbcEncoder[C]().encode(tuple._3) ++ SqlFragment.comma ++ JdbcEncoder[D]().encode( @@ -212,6 +267,7 @@ object JdbcEncoder extends JdbcEncoder0LowPriorityImplicits { ) ++ SqlFragment.comma ++ JdbcEncoder[I]().encode(tuple._9) ++ SqlFragment.comma ++ JdbcEncoder[J]().encode( tuple._10 ) ++ SqlFragment.comma ++ JdbcEncoder[K]().encode(tuple._11) + ) implicit def tuple12Encoder[ A: JdbcEncoder, @@ -227,7 +283,7 @@ object JdbcEncoder extends JdbcEncoder0LowPriorityImplicits { K: JdbcEncoder, L: JdbcEncoder ]: JdbcEncoder[(A, B, C, D, E, F, G, H, I, J, K, L)] = - tuple => + JdbcEncoder(tuple => JdbcEncoder[A]().encode(tuple._1) ++ SqlFragment.comma ++ JdbcEncoder[B]().encode( tuple._2 ) ++ SqlFragment.comma ++ JdbcEncoder[C]().encode(tuple._3) ++ SqlFragment.comma ++ JdbcEncoder[D]().encode( @@ -241,6 +297,7 @@ object JdbcEncoder extends JdbcEncoder0LowPriorityImplicits { ) ++ SqlFragment.comma ++ JdbcEncoder[K]().encode(tuple._11) ++ SqlFragment.comma ++ JdbcEncoder[L]().encode( tuple._12 ) + ) implicit def tuple13Encoder[ A: JdbcEncoder, @@ -257,7 +314,7 @@ object JdbcEncoder extends JdbcEncoder0LowPriorityImplicits { L: JdbcEncoder, M: JdbcEncoder ]: JdbcEncoder[(A, B, C, D, E, F, G, H, I, J, K, L, M)] = - tuple => + JdbcEncoder(tuple => JdbcEncoder[A]().encode(tuple._1) ++ SqlFragment.comma ++ JdbcEncoder[B]().encode( tuple._2 ) ++ SqlFragment.comma ++ JdbcEncoder[C]().encode(tuple._3) ++ SqlFragment.comma ++ JdbcEncoder[D]().encode( @@ -271,6 +328,7 @@ object JdbcEncoder extends JdbcEncoder0LowPriorityImplicits { ) ++ SqlFragment.comma ++ JdbcEncoder[K]().encode(tuple._11) ++ SqlFragment.comma ++ JdbcEncoder[L]().encode( tuple._12 ) ++ SqlFragment.comma ++ JdbcEncoder[M]().encode(tuple._13) + ) implicit def tuple14Encoder[ A: JdbcEncoder, @@ -288,7 +346,7 @@ object JdbcEncoder extends JdbcEncoder0LowPriorityImplicits { M: JdbcEncoder, N: JdbcEncoder ]: JdbcEncoder[(A, B, C, D, E, F, G, H, I, J, K, L, M, N)] = - tuple => + JdbcEncoder(tuple => JdbcEncoder[A]().encode(tuple._1) ++ SqlFragment.comma ++ JdbcEncoder[B]().encode( tuple._2 ) ++ SqlFragment.comma ++ JdbcEncoder[C]().encode(tuple._3) ++ SqlFragment.comma ++ JdbcEncoder[D]().encode( @@ -304,6 +362,7 @@ object JdbcEncoder extends JdbcEncoder0LowPriorityImplicits { ) ++ SqlFragment.comma ++ JdbcEncoder[M]().encode(tuple._13) ++ SqlFragment.comma ++ JdbcEncoder[N]().encode( tuple._14 ) + ) implicit def tuple15Encoder[ A: JdbcEncoder, @@ -322,7 +381,7 @@ object JdbcEncoder extends JdbcEncoder0LowPriorityImplicits { N: JdbcEncoder, O: JdbcEncoder ]: JdbcEncoder[(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O)] = - tuple => + JdbcEncoder(tuple => JdbcEncoder[A]().encode(tuple._1) ++ SqlFragment.comma ++ JdbcEncoder[B]().encode( tuple._2 ) ++ SqlFragment.comma ++ JdbcEncoder[C]().encode(tuple._3) ++ SqlFragment.comma ++ JdbcEncoder[D]().encode( @@ -338,6 +397,7 @@ object JdbcEncoder extends JdbcEncoder0LowPriorityImplicits { ) ++ SqlFragment.comma ++ JdbcEncoder[M]().encode(tuple._13) ++ SqlFragment.comma ++ JdbcEncoder[N]().encode( tuple._14 ) ++ SqlFragment.comma ++ JdbcEncoder[O]().encode(tuple._15) + ) implicit def tuple16Encoder[ A: JdbcEncoder, @@ -357,7 +417,7 @@ object JdbcEncoder extends JdbcEncoder0LowPriorityImplicits { O: JdbcEncoder, P: JdbcEncoder ]: JdbcEncoder[(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P)] = - tuple => + JdbcEncoder(tuple => JdbcEncoder[A]().encode(tuple._1) ++ SqlFragment.comma ++ JdbcEncoder[B]().encode( tuple._2 ) ++ SqlFragment.comma ++ JdbcEncoder[C]().encode(tuple._3) ++ SqlFragment.comma ++ JdbcEncoder[D]().encode( @@ -375,6 +435,7 @@ object JdbcEncoder extends JdbcEncoder0LowPriorityImplicits { ) ++ SqlFragment.comma ++ JdbcEncoder[O]().encode(tuple._15) ++ SqlFragment.comma ++ JdbcEncoder[P]().encode( tuple._16 ) + ) implicit def tuple17Encoder[ A: JdbcEncoder, @@ -395,7 +456,7 @@ object JdbcEncoder extends JdbcEncoder0LowPriorityImplicits { P: JdbcEncoder, Q: JdbcEncoder ]: JdbcEncoder[(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P, Q)] = - tuple => + JdbcEncoder(tuple => JdbcEncoder[A]().encode(tuple._1) ++ SqlFragment.comma ++ JdbcEncoder[B]().encode( tuple._2 ) ++ SqlFragment.comma ++ JdbcEncoder[C]().encode(tuple._3) ++ SqlFragment.comma ++ JdbcEncoder[D]().encode( @@ -413,6 +474,7 @@ object JdbcEncoder extends JdbcEncoder0LowPriorityImplicits { ) ++ SqlFragment.comma ++ JdbcEncoder[O]().encode(tuple._15) ++ SqlFragment.comma ++ JdbcEncoder[P]().encode( tuple._16 ) ++ SqlFragment.comma ++ JdbcEncoder[Q]().encode(tuple._17) + ) implicit def tuple18Encoder[ A: JdbcEncoder, @@ -434,7 +496,7 @@ object JdbcEncoder extends JdbcEncoder0LowPriorityImplicits { Q: JdbcEncoder, R: JdbcEncoder ]: JdbcEncoder[(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P, Q, R)] = - tuple => + JdbcEncoder(tuple => JdbcEncoder[A]().encode(tuple._1) ++ SqlFragment.comma ++ JdbcEncoder[B]().encode( tuple._2 ) ++ SqlFragment.comma ++ JdbcEncoder[C]().encode(tuple._3) ++ SqlFragment.comma ++ JdbcEncoder[D]().encode( @@ -454,6 +516,7 @@ object JdbcEncoder extends JdbcEncoder0LowPriorityImplicits { ) ++ SqlFragment.comma ++ JdbcEncoder[Q]().encode(tuple._17) ++ SqlFragment.comma ++ JdbcEncoder[R]().encode( tuple._18 ) + ) implicit def tuple19Encoder[ A: JdbcEncoder, @@ -476,7 +539,7 @@ object JdbcEncoder extends JdbcEncoder0LowPriorityImplicits { R: JdbcEncoder, S: JdbcEncoder ]: JdbcEncoder[(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P, Q, R, S)] = - tuple => + JdbcEncoder(tuple => JdbcEncoder[A]().encode(tuple._1) ++ SqlFragment.comma ++ JdbcEncoder[B]().encode( tuple._2 ) ++ SqlFragment.comma ++ JdbcEncoder[C]().encode(tuple._3) ++ SqlFragment.comma ++ JdbcEncoder[D]().encode( @@ -496,6 +559,7 @@ object JdbcEncoder extends JdbcEncoder0LowPriorityImplicits { ) ++ SqlFragment.comma ++ JdbcEncoder[Q]().encode(tuple._17) ++ SqlFragment.comma ++ JdbcEncoder[R]().encode( tuple._18 ) ++ SqlFragment.comma ++ JdbcEncoder[S]().encode(tuple._19) + ) implicit def tuple20Encoder[ A: JdbcEncoder, @@ -519,7 +583,7 @@ object JdbcEncoder extends JdbcEncoder0LowPriorityImplicits { S: JdbcEncoder, T: JdbcEncoder ]: JdbcEncoder[(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P, Q, R, S, T)] = - tuple => + JdbcEncoder(tuple => JdbcEncoder[A]().encode(tuple._1) ++ SqlFragment.comma ++ JdbcEncoder[B]().encode( tuple._2 ) ++ SqlFragment.comma ++ JdbcEncoder[C]().encode(tuple._3) ++ SqlFragment.comma ++ JdbcEncoder[D]().encode( @@ -541,6 +605,7 @@ object JdbcEncoder extends JdbcEncoder0LowPriorityImplicits { ) ++ SqlFragment.comma ++ JdbcEncoder[S]().encode(tuple._19) ++ SqlFragment.comma ++ JdbcEncoder[T]().encode( tuple._20 ) + ) implicit def tuple21Encoder[ A: JdbcEncoder, @@ -565,7 +630,7 @@ object JdbcEncoder extends JdbcEncoder0LowPriorityImplicits { T: JdbcEncoder, U: JdbcEncoder ]: JdbcEncoder[(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P, Q, R, S, T, U)] = - tuple => + JdbcEncoder(tuple => JdbcEncoder[A]().encode(tuple._1) ++ SqlFragment.comma ++ JdbcEncoder[B]().encode( tuple._2 ) ++ SqlFragment.comma ++ JdbcEncoder[C]().encode(tuple._3) ++ SqlFragment.comma ++ JdbcEncoder[D]().encode( @@ -587,6 +652,7 @@ object JdbcEncoder extends JdbcEncoder0LowPriorityImplicits { ) ++ SqlFragment.comma ++ JdbcEncoder[S]().encode(tuple._19) ++ SqlFragment.comma ++ JdbcEncoder[T]().encode( tuple._20 ) ++ SqlFragment.comma ++ JdbcEncoder[U]().encode(tuple._21) + ) implicit def tuple22Encoder[ A: JdbcEncoder, @@ -612,7 +678,7 @@ object JdbcEncoder extends JdbcEncoder0LowPriorityImplicits { U: JdbcEncoder, V: JdbcEncoder ]: JdbcEncoder[(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P, Q, R, S, T, U, V)] = - tuple => + JdbcEncoder(tuple => JdbcEncoder[A]().encode(tuple._1) ++ SqlFragment.comma ++ JdbcEncoder[B]().encode( tuple._2 ) ++ SqlFragment.comma ++ JdbcEncoder[C]().encode(tuple._3) ++ SqlFragment.comma ++ JdbcEncoder[D]().encode( @@ -636,6 +702,7 @@ object JdbcEncoder extends JdbcEncoder0LowPriorityImplicits { ) ++ SqlFragment.comma ++ JdbcEncoder[U]().encode(tuple._21) ++ SqlFragment.comma ++ JdbcEncoder[V]().encode( tuple._22 ) + ) } trait JdbcEncoder0LowPriorityImplicits { self => @@ -697,10 +764,11 @@ trait JdbcEncoder0LowPriorityImplicits { self => throw JdbcEncoderError(s"Failed to encode schema ${schema}", new IllegalArgumentException) } - private[jdbc] def caseClassEncoder[A](fields: Chunk[Schema.Field[A, _]]): JdbcEncoder[A] = { (a: A) => + private[jdbc] def caseClassEncoder[A](fields: Chunk[Schema.Field[A, _]]): JdbcEncoder[A] = JdbcEncoder(a => fields.map { f => val encoder = self.fromSchema(f.schema.asInstanceOf[Schema[Any]]) encoder.encode(f.get(a)) }.reduce(_ ++ SqlFragment.comma ++ _) - } + ) + } diff --git a/core/src/main/scala/zio/jdbc/SqlFragment.scala b/core/src/main/scala/zio/jdbc/SqlFragment.scala index 50ab0225..7d6e5140 100644 --- a/core/src/main/scala/zio/jdbc/SqlFragment.scala +++ b/core/src/main/scala/zio/jdbc/SqlFragment.scala @@ -208,14 +208,22 @@ object SqlFragment { final case class Param(value: Any, setter: Setter[Any]) extends Segment final case class Nested(sql: SqlFragment) extends Segment + implicit def jdbcEncoderSegment[A](obj: A)(implicit encoder: JdbcEncoder[A]): Segment = + encoder.setter match { + case Some(value) => Segment.Param(obj, value.asInstanceOf[Setter[Any]]) + case None => encoder.encode(obj) + } + + implicit def nestedSqlSegment[A](sql: SqlFragment): Segment.Nested = Segment.Nested(sql) + } - trait Setter[A] { self => - def setValue(ps: PreparedStatement, index: Int, value: A): Unit - def setNull(ps: PreparedStatement, index: Int): Unit + trait Setter[-A] { self => + def unsafeSetValue(ps: PreparedStatement, index: Int, value: A): Unit + def unsafeSetNull(ps: PreparedStatement, index: Int): Unit final def contramap[B](f: B => A): Setter[B] = - Setter((ps, i, value) => self.setValue(ps, i, f(value)), (ps, i) => self.setNull(ps, i)) + Setter((ps, i, value) => self.unsafeSetValue(ps, i, f(value)), (ps, i) => self.unsafeSetNull(ps, i)) } object Setter { @@ -223,28 +231,28 @@ object SqlFragment { def apply[A](onValue: (PreparedStatement, Int, A) => Unit, onNull: (PreparedStatement, Int) => Unit): Setter[A] = new Setter[A] { - def setValue(ps: PreparedStatement, index: Int, value: A): Unit = onValue(ps, index, value) - def setNull(ps: PreparedStatement, index: Int): Unit = onNull(ps, index) + def unsafeSetValue(ps: PreparedStatement, index: Int, value: A): Unit = onValue(ps, index, value) + def unsafeSetNull(ps: PreparedStatement, index: Int): Unit = onNull(ps, index) } def forSqlType[A](onValue: (PreparedStatement, Int, A) => Unit, sqlType: Int): Setter[A] = new Setter[A] { - def setValue(ps: PreparedStatement, index: Int, value: A): Unit = onValue(ps, index, value) - def setNull(ps: PreparedStatement, index: Int): Unit = ps.setNull(index, sqlType) + def unsafeSetValue(ps: PreparedStatement, index: Int, value: A): Unit = onValue(ps, index, value) + def unsafeSetNull(ps: PreparedStatement, index: Int): Unit = ps.setNull(index, sqlType) } def other[A](onValue: (PreparedStatement, Int, A) => Unit, sqlType: String): Setter[A] = new Setter[A] { - def setValue(ps: PreparedStatement, index: Int, value: A): Unit = onValue(ps, index, value) - def setNull(ps: PreparedStatement, index: Int): Unit = ps.setNull(index, Types.OTHER, sqlType) + def unsafeSetValue(ps: PreparedStatement, index: Int, value: A): Unit = onValue(ps, index, value) + def unsafeSetNull(ps: PreparedStatement, index: Int): Unit = ps.setNull(index, Types.OTHER, sqlType) } implicit def optionParamSetter[A](implicit setter: Setter[A]): Setter[Option[A]] = Setter( (ps, i, value) => value match { - case Some(value) => setter.setValue(ps, i, value) - case None => setter.setNull(ps, i) + case Some(value) => setter.unsafeSetValue(ps, i, value) + case None => setter.unsafeSetNull(ps, i) }, - (ps, i) => setter.setNull(ps, i) + (ps, i) => setter.unsafeSetNull(ps, i) ) implicit val intSetter: Setter[Int] = forSqlType((ps, i, value) => ps.setInt(i, value), Types.INTEGER) diff --git a/core/src/main/scala/zio/jdbc/ZConnection.scala b/core/src/main/scala/zio/jdbc/ZConnection.scala index e2ae6d8f..4e702557 100644 --- a/core/src/main/scala/zio/jdbc/ZConnection.scala +++ b/core/src/main/scala/zio/jdbc/ZConnection.scala @@ -46,7 +46,7 @@ final class ZConnection(private[jdbc] val connection: Connection) extends AnyVal _ <- ZIO.attempt { var paramIndex = 1 sql.foreachSegment(_ => ()) { param => - param.setter.setValue(statement, paramIndex, param.value) + param.setter.unsafeSetValue(statement, paramIndex, param.value) paramIndex += 1 } } diff --git a/core/src/main/scala/zio/jdbc/package.scala b/core/src/main/scala/zio/jdbc/package.scala index 7bb4912d..03daba92 100644 --- a/core/src/main/scala/zio/jdbc/package.scala +++ b/core/src/main/scala/zio/jdbc/package.scala @@ -15,12 +15,9 @@ */ package zio -import zio.jdbc.SqlFragment.{ Segment, Setter } -import zio.jdbc.{ JdbcEncoder, SqlFragment } - import scala.language.implicitConversions -package object jdbc extends LowPriorityImplicits1 { +package object jdbc { implicit def sqlInterpolator(sc: StringContext): SqlInterpolator = new SqlInterpolator(sc) @@ -37,15 +34,3 @@ package object jdbc extends LowPriorityImplicits1 { ZLayer(ZIO.serviceWith[ZConnectionPool](_.transaction)).flatten } -trait LowPriorityImplicits1 extends LowPriorityImplicits2 { - - implicit def paramSegment[A](a: A)(implicit setter: Setter[A]): Segment.Param = - Segment.Param(a, setter.asInstanceOf[Setter[Any]]) - - implicit def nestedSqlSegment[A](sql: SqlFragment): Segment.Nested = Segment.Nested(sql) -} - -trait LowPriorityImplicits2 { - implicit def segmentFromJdbcEncoder[A](obj: A)(implicit encoder: JdbcEncoder[A]): Segment = - Segment.Nested(encoder.encode(obj)) -} diff --git a/core/src/test/scala/zio/jdbc/SqlFragmentSpec.scala b/core/src/test/scala/zio/jdbc/SqlFragmentSpec.scala index 1f0f2b80..008f0604 100644 --- a/core/src/test/scala/zio/jdbc/SqlFragmentSpec.scala +++ b/core/src/test/scala/zio/jdbc/SqlFragmentSpec.scala @@ -46,13 +46,10 @@ object SqlFragmentSpec extends ZIOSpecDefault { } + test("type safe interpolation") { final case class Foo(value: String) - implicit val fooParamSetter: SqlFragment.Setter[Foo] = SqlFragment.Setter[String]().contramap(_.toString) + implicit val fooParamSetter: JdbcEncoder[Foo] = JdbcEncoder[String]().contramap(_.toString) val testSql = sql"${Foo("test")}" - assertTrue(testSql.segments.collect { case SqlFragment.Segment.Param(_, setter) => - setter - }.head eq fooParamSetter) && assertTrue(testSql.toString == "Sql(?, Foo(test))") } + suite(" SqlFragment.ParamSetter instances") { // TODO figure out how to test at PrepareStatement level @@ -226,10 +223,10 @@ object SqlFragmentSpec extends ZIOSpecDefault { } + test("Custom JdbcEncoder") { - implicit val byteArrayjdbcEncoder: JdbcEncoder[Base64Array] = - value => s"FROM_BASE64('${Base64.getEncoder.encodeToString(value.arr)}')" + implicit val byteArrayjdbcEncoder: JdbcEncoder[Array[Byte]] = + JdbcEncoder(value => s"FROM_BASE64('${Base64.getEncoder.encodeToString(value)}')") - val bytes = Base64Array(Array[Byte](1, 2, 3)) + val bytes = Array[Byte](1, 2, 3) val result = sql"UPDATE foo SET bytes = $bytes" diff --git a/core/src/test/scala/zio/jdbc/ZConnectionPoolSpec.scala b/core/src/test/scala/zio/jdbc/ZConnectionPoolSpec.scala index 8cb20fdd..7c063bf1 100644 --- a/core/src/test/scala/zio/jdbc/ZConnectionPoolSpec.scala +++ b/core/src/test/scala/zio/jdbc/ZConnectionPoolSpec.scala @@ -117,10 +117,10 @@ object ZConnectionPoolSpec extends ZIOSpecDefault { implicit val jdbcDecoder: JdbcDecoder[User] = JdbcDecoder[(String, Int)]().map[User](t => User(t._1, t._2)) - implicit val jdbcEncoder: JdbcEncoder[User] = (value: User) => { + implicit val jdbcEncoder: JdbcEncoder[User] = JdbcEncoder { value => val name = value.name val age = value.age - sql"""${name}""" ++ ", " ++ s"${age}" + sql"""$name""" ++ ", " ++ s"$age" } } @@ -130,10 +130,10 @@ object ZConnectionPoolSpec extends ZIOSpecDefault { implicit val jdbcDecoder: JdbcDecoder[UserNoId] = JdbcDecoder[(String, Int)]().map[UserNoId](t => UserNoId(t._1, t._2)) - implicit val jdbcEncoder: JdbcEncoder[UserNoId] = (value: UserNoId) => { + implicit val jdbcEncoder: JdbcEncoder[UserNoId] = JdbcEncoder { value => val name = value.name val age = value.age - sql"""${name}""" ++ ", " ++ s"${age}" + sql"""$name""" ++ ", " ++ s"$age" } } From ea090eafbc37bd10c9de4aa06aa0a35ea8e882b7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=8BAndrzej=20Ressel?= Date: Sun, 14 May 2023 20:09:38 +0200 Subject: [PATCH 3/3] Combine JDBCEncoder and Setter --- .../src/main/scala/zio/jdbc/JdbcEncoder.scala | 871 +++++------------- .../src/main/scala/zio/jdbc/SqlFragment.scala | 132 +-- .../src/main/scala/zio/jdbc/ZConnection.scala | 17 +- .../test/scala/zio/jdbc/SqlFragmentSpec.scala | 53 +- .../scala/zio/jdbc/ZConnectionPoolSpec.scala | 58 +- 5 files changed, 323 insertions(+), 808 deletions(-) diff --git a/core/src/main/scala/zio/jdbc/JdbcEncoder.scala b/core/src/main/scala/zio/jdbc/JdbcEncoder.scala index 1f193a24..e0cd9df3 100644 --- a/core/src/main/scala/zio/jdbc/JdbcEncoder.scala +++ b/core/src/main/scala/zio/jdbc/JdbcEncoder.scala @@ -16,693 +16,271 @@ package zio.jdbc import zio.Chunk -import zio.jdbc.SqlFragment.{ Segment, Setter } +import zio.jdbc.SqlFragment.{ Segment, values } import zio.schema.{ Schema, StandardType } -import java.sql.PreparedStatement +import java.sql.{ PreparedStatement, Types } /** * A type class that describes the ability to convert a value of type `A` into * a fragment of SQL. This is useful for forming SQL insert statements. */ -trait JdbcEncoder[A] { +trait JdbcEncoder[-A] { self => def encode(value: A): SqlFragment - val setter: Option[Setter[A]] - final def contramap[B](f: B => A): JdbcEncoder[B] = { - val that = this + /** + * Returns first index after encoder + */ + def unsafeSetValue(ps: PreparedStatement, index: Int, value: A): Int + + /** + * Returns first index after encoder + */ + def unsafeSetNull(ps: PreparedStatement, index: Int): Int + + final def contramap[B](f: B => A): JdbcEncoder[B] = new JdbcEncoder[B] { - override def encode(value: B): SqlFragment = that.encode(f(value)) + override def encode(value: B): SqlFragment = + self.encode(f(value)) + + override def unsafeSetValue(ps: PreparedStatement, index: Int, value: B): Int = + self.unsafeSetValue(ps, index, f(value)) + + override def unsafeSetNull(ps: PreparedStatement, index: Int): Int = self.unsafeSetNull(ps, index) - override val setter: Option[Setter[B]] = that.setter.map(_.contramap(f)) + override def sql(a: B): String = "?" } - } + + def sql(value: A): String + def prettyValuePrinter(value: A): String = value.toString } object JdbcEncoder extends JdbcEncoder0LowPriorityImplicits { - def apply[A]()(implicit encoder: JdbcEncoder[A]): JdbcEncoder[A] = encoder /** - * Use caution when using this method. Returning interpolating string may result in SQL injection attacks + * Restrictions: + * - Placeholder must contain only one "?" */ - def apply[A](onEncode: A => SqlFragment): JdbcEncoder[A] = new JdbcEncoder[A] { - override def encode(value: A): SqlFragment = onEncode(value) + def single[A](placeholder: String, valueToString: A => String): JdbcEncoder[A] = { + val questionMarkCount = placeholder.count(_ == '?') + if (questionMarkCount == 0) { + throw new IllegalArgumentException("Placeholder must contain one '?'") + } + if (questionMarkCount > 1) { + throw new IllegalArgumentException("Placeholder can contain only one '?'") + } + new JdbcEncoder[A] { + + override def encode(value: A): SqlFragment = + SqlFragment(Chunk(SqlFragment.Segment.Syntax.apply(placeholder))) + + override def unsafeSetValue(ps: PreparedStatement, index: Int, value: A): Int = { + ps.setString(index, valueToString(value)) + index + 1 + } - override val setter: Option[Setter[A]] = None + override def unsafeSetNull(ps: PreparedStatement, index: Int): Int = { + ps.setNull(index, Types.VARCHAR) + index + 1 + } + + override def sql(value: A): String = placeholder + + override def prettyValuePrinter(value: A): String = valueToString(value) + } } + def apply[A]()(implicit encoder: JdbcEncoder[A]): JdbcEncoder[A] = encoder + def apply[A]( onEncode: A => SqlFragment, + onSql: => String, onValue: (PreparedStatement, Int, A) => Unit, onNull: (PreparedStatement, Int) => Unit - ): JdbcEncoder[A] = new JdbcEncoder[A] { - override def encode(value: A): SqlFragment = onEncode(value) + ): JdbcEncoder[A] = + new JdbcEncoder[A] { - override val setter: Option[Setter[A]] = Some(new Setter[A] { - override def unsafeSetValue(ps: PreparedStatement, index: Int, value: A): Unit = onValue(ps, index, value) + override def encode(value: A): SqlFragment = onEncode(value) - override def unsafeSetNull(ps: PreparedStatement, index: Int): Unit = onNull(ps, index) - }) - } + override def unsafeSetValue(ps: PreparedStatement, index: Int, value: A): Int = { + onValue(ps, index, value) + index + 1 + } - private def withSetter[A](implicit setter: Setter[A]): JdbcEncoder[A] = - apply( - value => SqlFragment(Chunk.apply(Segment.Param(value, setter.asInstanceOf[Setter[Any]]))), - setter.unsafeSetValue, - setter.unsafeSetNull - ) + override def unsafeSetNull(ps: PreparedStatement, index: Int): Int = { + onNull(ps, index) + index + 1 + } - implicit val intEncoder: JdbcEncoder[Int] = withSetter - implicit val longEncoder: JdbcEncoder[Long] = withSetter - implicit val doubleEncoder: JdbcEncoder[Double] = withSetter - implicit val charEncoder: JdbcEncoder[Char] = withSetter - implicit val stringEncoder: JdbcEncoder[String] = withSetter - implicit val booleanEncoder: JdbcEncoder[Boolean] = withSetter - implicit val bigIntEncoder: JdbcEncoder[java.math.BigInteger] = withSetter - implicit val bigDecimalEncoder: JdbcEncoder[java.math.BigDecimal] = withSetter - implicit val bigDecimalEncoderScala: JdbcEncoder[scala.math.BigDecimal] = withSetter - implicit val shortEncoder: JdbcEncoder[Short] = withSetter - implicit val floatEncoder: JdbcEncoder[Float] = withSetter - implicit val byteEncoder: JdbcEncoder[Byte] = withSetter - implicit val byteArrayEncoder: JdbcEncoder[Array[Byte]] = withSetter - implicit val byteChunkEncoder: JdbcEncoder[Chunk[Byte]] = withSetter - implicit val blobEncoder: JdbcEncoder[java.sql.Blob] = withSetter - implicit val uuidEncoder: JdbcEncoder[java.util.UUID] = withSetter - - private def optionParamSetter[A](implicit setter: Setter[A]): Setter[Option[A]] = - Setter( - (ps, i, value) => - value match { - case Some(value) => setter.unsafeSetValue(ps, i, value) - case None => setter.unsafeSetNull(ps, i) - }, - (ps, i) => setter.unsafeSetNull(ps, i) - ) - implicit def optionEncoder[A](implicit encoder: JdbcEncoder[A]): JdbcEncoder[Option[A]] = new JdbcEncoder[Option[A]] { - override def encode(value: Option[A]): SqlFragment = value.fold(SqlFragment.nullLiteral)(encoder.encode) + override def sql(a: A): String = onSql + } - override val setter: Option[Setter[Option[A]]] = encoder.setter.map(optionParamSetter(_)) - } + private def forSqlType[A](onValue: (PreparedStatement, Int, A) => Unit, sqlType: Int): JdbcEncoder[A] = + new JdbcEncoder[A] { - implicit def tuple2Encoder[A: JdbcEncoder, B: JdbcEncoder]: JdbcEncoder[(A, B)] = - JdbcEncoder(tuple => JdbcEncoder[A]().encode(tuple._1) ++ SqlFragment.comma ++ JdbcEncoder[B]().encode(tuple._2)) + override def encode(value: A): SqlFragment = SqlFragment( + Chunk.apply(Segment.Param(value, this.asInstanceOf[JdbcEncoder[Any]])) + ) - implicit def tuple3Encoder[A: JdbcEncoder, B: JdbcEncoder, C: JdbcEncoder]: JdbcEncoder[(A, B, C)] = - JdbcEncoder(tuple => - JdbcEncoder[A]().encode(tuple._1) ++ SqlFragment.comma ++ JdbcEncoder[B]().encode( - tuple._2 - ) ++ SqlFragment.comma ++ JdbcEncoder[C]().encode(tuple._3) - ) + def unsafeSetValue(ps: PreparedStatement, index: Int, value: A): Int = { + onValue(ps, index, value) + index + 1 + } - implicit def tuple4Encoder[A: JdbcEncoder, B: JdbcEncoder, C: JdbcEncoder, D: JdbcEncoder] - : JdbcEncoder[(A, B, C, D)] = - JdbcEncoder(tuple => - JdbcEncoder[A]().encode(tuple._1) ++ SqlFragment.comma ++ JdbcEncoder[B]().encode( - tuple._2 - ) ++ SqlFragment.comma ++ JdbcEncoder[C]().encode(tuple._3) ++ SqlFragment.comma ++ JdbcEncoder[D]().encode( - tuple._4 - ) - ) + def unsafeSetNull(ps: PreparedStatement, index: Int): Int = { + ps.setNull(index, sqlType) + index + 1 + } - implicit def tuple5Encoder[A: JdbcEncoder, B: JdbcEncoder, C: JdbcEncoder, D: JdbcEncoder, E: JdbcEncoder] - : JdbcEncoder[(A, B, C, D, E)] = - JdbcEncoder(tuple => - JdbcEncoder[A]().encode(tuple._1) ++ SqlFragment.comma ++ JdbcEncoder[B]().encode( - tuple._2 - ) ++ SqlFragment.comma ++ JdbcEncoder[C]().encode(tuple._3) ++ SqlFragment.comma ++ JdbcEncoder[D]().encode( - tuple._4 - ) ++ SqlFragment.comma ++ JdbcEncoder[E]().encode(tuple._5) - ) + override def sql(a: A): String = "?" + } + + private def forIterableSqlType[A, I]( + iterator: I => Iterator[A], + sqlType: Int + )(implicit encoder: JdbcEncoder[A]): JdbcEncoder[I] = + new JdbcEncoder[I] { - implicit def tuple6Encoder[ - A: JdbcEncoder, - B: JdbcEncoder, - C: JdbcEncoder, - D: JdbcEncoder, - E: JdbcEncoder, - F: JdbcEncoder - ]: JdbcEncoder[(A, B, C, D, E, F)] = - JdbcEncoder(tuple => - JdbcEncoder[A]().encode(tuple._1) ++ SqlFragment.comma ++ JdbcEncoder[B]().encode( - tuple._2 - ) ++ SqlFragment.comma ++ JdbcEncoder[C]().encode(tuple._3) ++ SqlFragment.comma ++ JdbcEncoder[D]().encode( - tuple._4 - ) ++ SqlFragment.comma ++ JdbcEncoder[E]().encode(tuple._5) ++ SqlFragment.comma ++ JdbcEncoder[F]().encode( - tuple._6 + override def encode(value: I): SqlFragment = SqlFragment( + Chunk.apply(Segment.Param(value, this.asInstanceOf[JdbcEncoder[Any]])) ) - ) - implicit def tuple7Encoder[ - A: JdbcEncoder, - B: JdbcEncoder, - C: JdbcEncoder, - D: JdbcEncoder, - E: JdbcEncoder, - F: JdbcEncoder, - G: JdbcEncoder - ]: JdbcEncoder[(A, B, C, D, E, F, G)] = - JdbcEncoder(tuple => - JdbcEncoder[A]().encode(tuple._1) ++ SqlFragment.comma ++ JdbcEncoder[B]().encode( - tuple._2 - ) ++ SqlFragment.comma ++ JdbcEncoder[C]().encode(tuple._3) ++ SqlFragment.comma ++ JdbcEncoder[D]().encode( - tuple._4 - ) ++ SqlFragment.comma ++ JdbcEncoder[E]().encode(tuple._5) ++ SqlFragment.comma ++ JdbcEncoder[F]().encode( - tuple._6 - ) ++ SqlFragment.comma ++ JdbcEncoder[G]().encode(tuple._7) - ) + def unsafeSetValue(ps: PreparedStatement, index: Int, value: I): Int = + iterator(value) + .foldLeft(index) { case (i, value) => encoder.unsafeSetValue(ps, i, value) } - implicit def tuple8Encoder[ - A: JdbcEncoder, - B: JdbcEncoder, - C: JdbcEncoder, - D: JdbcEncoder, - E: JdbcEncoder, - F: JdbcEncoder, - G: JdbcEncoder, - H: JdbcEncoder - ]: JdbcEncoder[(A, B, C, D, E, F, G, H)] = - JdbcEncoder(tuple => - JdbcEncoder[A]().encode(tuple._1) ++ SqlFragment.comma ++ JdbcEncoder[B]().encode( - tuple._2 - ) ++ SqlFragment.comma ++ JdbcEncoder[C]().encode(tuple._3) ++ SqlFragment.comma ++ JdbcEncoder[D]().encode( - tuple._4 - ) ++ SqlFragment.comma ++ JdbcEncoder[E]().encode(tuple._5) ++ SqlFragment.comma ++ JdbcEncoder[F]().encode( - tuple._6 - ) ++ SqlFragment.comma ++ JdbcEncoder[G]().encode(tuple._7) ++ SqlFragment.comma ++ JdbcEncoder[H]().encode( - tuple._8 - ) - ) + def unsafeSetNull(ps: PreparedStatement, index: Int): Int = { + ps.setNull(index, sqlType) + index + 1 + } - implicit def tuple9Encoder[ - A: JdbcEncoder, - B: JdbcEncoder, - C: JdbcEncoder, - D: JdbcEncoder, - E: JdbcEncoder, - F: JdbcEncoder, - G: JdbcEncoder, - H: JdbcEncoder, - I: JdbcEncoder - ]: JdbcEncoder[(A, B, C, D, E, F, G, H, I)] = - JdbcEncoder(tuple => - JdbcEncoder[A]().encode(tuple._1) ++ SqlFragment.comma ++ JdbcEncoder[B]().encode( - tuple._2 - ) ++ SqlFragment.comma ++ JdbcEncoder[C]().encode(tuple._3) ++ SqlFragment.comma ++ JdbcEncoder[D]().encode( - tuple._4 - ) ++ SqlFragment.comma ++ JdbcEncoder[E]().encode(tuple._5) ++ SqlFragment.comma ++ JdbcEncoder[F]().encode( - tuple._6 - ) ++ SqlFragment.comma ++ JdbcEncoder[G]().encode(tuple._7) ++ SqlFragment.comma ++ JdbcEncoder[H]().encode( - tuple._8 - ) ++ SqlFragment.comma ++ JdbcEncoder[I]().encode(tuple._9) - ) + override def sql(a: I): String = iterator(a).map(_ => "?").mkString(",") - implicit def tuple10Encoder[ - A: JdbcEncoder, - B: JdbcEncoder, - C: JdbcEncoder, - D: JdbcEncoder, - E: JdbcEncoder, - F: JdbcEncoder, - G: JdbcEncoder, - H: JdbcEncoder, - I: JdbcEncoder, - J: JdbcEncoder - ]: JdbcEncoder[(A, B, C, D, E, F, G, H, I, J)] = - JdbcEncoder(tuple => - JdbcEncoder[A]().encode(tuple._1) ++ SqlFragment.comma ++ JdbcEncoder[B]().encode( - tuple._2 - ) ++ SqlFragment.comma ++ JdbcEncoder[C]().encode(tuple._3) ++ SqlFragment.comma ++ JdbcEncoder[D]().encode( - tuple._4 - ) ++ SqlFragment.comma ++ JdbcEncoder[E]().encode(tuple._5) ++ SqlFragment.comma ++ JdbcEncoder[F]().encode( - tuple._6 - ) ++ SqlFragment.comma ++ JdbcEncoder[G]().encode(tuple._7) ++ SqlFragment.comma ++ JdbcEncoder[H]().encode( - tuple._8 - ) ++ SqlFragment.comma ++ JdbcEncoder[I]().encode(tuple._9) ++ SqlFragment.comma ++ JdbcEncoder[J]().encode( - tuple._10 - ) - ) + override def prettyValuePrinter(value: I): String = iterator(value) + .mkString(", ") + } - implicit def tuple11Encoder[ - A: JdbcEncoder, - B: JdbcEncoder, - C: JdbcEncoder, - D: JdbcEncoder, - E: JdbcEncoder, - F: JdbcEncoder, - G: JdbcEncoder, - H: JdbcEncoder, - I: JdbcEncoder, - J: JdbcEncoder, - K: JdbcEncoder - ]: JdbcEncoder[(A, B, C, D, E, F, G, H, I, J, K)] = - JdbcEncoder(tuple => - JdbcEncoder[A]().encode(tuple._1) ++ SqlFragment.comma ++ JdbcEncoder[B]().encode( - tuple._2 - ) ++ SqlFragment.comma ++ JdbcEncoder[C]().encode(tuple._3) ++ SqlFragment.comma ++ JdbcEncoder[D]().encode( - tuple._4 - ) ++ SqlFragment.comma ++ JdbcEncoder[E]().encode(tuple._5) ++ SqlFragment.comma ++ JdbcEncoder[F]().encode( - tuple._6 - ) ++ SqlFragment.comma ++ JdbcEncoder[G]().encode(tuple._7) ++ SqlFragment.comma ++ JdbcEncoder[H]().encode( - tuple._8 - ) ++ SqlFragment.comma ++ JdbcEncoder[I]().encode(tuple._9) ++ SqlFragment.comma ++ JdbcEncoder[J]().encode( - tuple._10 - ) ++ SqlFragment.comma ++ JdbcEncoder[K]().encode(tuple._11) - ) + def other[A](onValue: (PreparedStatement, Int, A) => Unit, sqlType: String): JdbcEncoder[A] = new JdbcEncoder[A] { - implicit def tuple12Encoder[ - A: JdbcEncoder, - B: JdbcEncoder, - C: JdbcEncoder, - D: JdbcEncoder, - E: JdbcEncoder, - F: JdbcEncoder, - G: JdbcEncoder, - H: JdbcEncoder, - I: JdbcEncoder, - J: JdbcEncoder, - K: JdbcEncoder, - L: JdbcEncoder - ]: JdbcEncoder[(A, B, C, D, E, F, G, H, I, J, K, L)] = - JdbcEncoder(tuple => - JdbcEncoder[A]().encode(tuple._1) ++ SqlFragment.comma ++ JdbcEncoder[B]().encode( - tuple._2 - ) ++ SqlFragment.comma ++ JdbcEncoder[C]().encode(tuple._3) ++ SqlFragment.comma ++ JdbcEncoder[D]().encode( - tuple._4 - ) ++ SqlFragment.comma ++ JdbcEncoder[E]().encode(tuple._5) ++ SqlFragment.comma ++ JdbcEncoder[F]().encode( - tuple._6 - ) ++ SqlFragment.comma ++ JdbcEncoder[G]().encode(tuple._7) ++ SqlFragment.comma ++ JdbcEncoder[H]().encode( - tuple._8 - ) ++ SqlFragment.comma ++ JdbcEncoder[I]().encode(tuple._9) ++ SqlFragment.comma ++ JdbcEncoder[J]().encode( - tuple._10 - ) ++ SqlFragment.comma ++ JdbcEncoder[K]().encode(tuple._11) ++ SqlFragment.comma ++ JdbcEncoder[L]().encode( - tuple._12 - ) + override def encode(value: A): SqlFragment = SqlFragment( + Chunk.apply(Segment.Param(value, this.asInstanceOf[JdbcEncoder[Any]])) ) - implicit def tuple13Encoder[ - A: JdbcEncoder, - B: JdbcEncoder, - C: JdbcEncoder, - D: JdbcEncoder, - E: JdbcEncoder, - F: JdbcEncoder, - G: JdbcEncoder, - H: JdbcEncoder, - I: JdbcEncoder, - J: JdbcEncoder, - K: JdbcEncoder, - L: JdbcEncoder, - M: JdbcEncoder - ]: JdbcEncoder[(A, B, C, D, E, F, G, H, I, J, K, L, M)] = - JdbcEncoder(tuple => - JdbcEncoder[A]().encode(tuple._1) ++ SqlFragment.comma ++ JdbcEncoder[B]().encode( - tuple._2 - ) ++ SqlFragment.comma ++ JdbcEncoder[C]().encode(tuple._3) ++ SqlFragment.comma ++ JdbcEncoder[D]().encode( - tuple._4 - ) ++ SqlFragment.comma ++ JdbcEncoder[E]().encode(tuple._5) ++ SqlFragment.comma ++ JdbcEncoder[F]().encode( - tuple._6 - ) ++ SqlFragment.comma ++ JdbcEncoder[G]().encode(tuple._7) ++ SqlFragment.comma ++ JdbcEncoder[H]().encode( - tuple._8 - ) ++ SqlFragment.comma ++ JdbcEncoder[I]().encode(tuple._9) ++ SqlFragment.comma ++ JdbcEncoder[J]().encode( - tuple._10 - ) ++ SqlFragment.comma ++ JdbcEncoder[K]().encode(tuple._11) ++ SqlFragment.comma ++ JdbcEncoder[L]().encode( - tuple._12 - ) ++ SqlFragment.comma ++ JdbcEncoder[M]().encode(tuple._13) - ) + def unsafeSetValue(ps: PreparedStatement, index: Int, value: A): Int = { + onValue(ps, index, value) + index + 1 + } - implicit def tuple14Encoder[ - A: JdbcEncoder, - B: JdbcEncoder, - C: JdbcEncoder, - D: JdbcEncoder, - E: JdbcEncoder, - F: JdbcEncoder, - G: JdbcEncoder, - H: JdbcEncoder, - I: JdbcEncoder, - J: JdbcEncoder, - K: JdbcEncoder, - L: JdbcEncoder, - M: JdbcEncoder, - N: JdbcEncoder - ]: JdbcEncoder[(A, B, C, D, E, F, G, H, I, J, K, L, M, N)] = - JdbcEncoder(tuple => - JdbcEncoder[A]().encode(tuple._1) ++ SqlFragment.comma ++ JdbcEncoder[B]().encode( - tuple._2 - ) ++ SqlFragment.comma ++ JdbcEncoder[C]().encode(tuple._3) ++ SqlFragment.comma ++ JdbcEncoder[D]().encode( - tuple._4 - ) ++ SqlFragment.comma ++ JdbcEncoder[E]().encode(tuple._5) ++ SqlFragment.comma ++ JdbcEncoder[F]().encode( - tuple._6 - ) ++ SqlFragment.comma ++ JdbcEncoder[G]().encode(tuple._7) ++ SqlFragment.comma ++ JdbcEncoder[H]().encode( - tuple._8 - ) ++ SqlFragment.comma ++ JdbcEncoder[I]().encode(tuple._9) ++ SqlFragment.comma ++ JdbcEncoder[J]().encode( - tuple._10 - ) ++ SqlFragment.comma ++ JdbcEncoder[K]().encode(tuple._11) ++ SqlFragment.comma ++ JdbcEncoder[L]().encode( - tuple._12 - ) ++ SqlFragment.comma ++ JdbcEncoder[M]().encode(tuple._13) ++ SqlFragment.comma ++ JdbcEncoder[N]().encode( - tuple._14 - ) - ) + def unsafeSetNull(ps: PreparedStatement, index: Int): Int = { + ps.setNull(index, Types.OTHER, sqlType) + index + 1 + } - implicit def tuple15Encoder[ - A: JdbcEncoder, - B: JdbcEncoder, - C: JdbcEncoder, - D: JdbcEncoder, - E: JdbcEncoder, - F: JdbcEncoder, - G: JdbcEncoder, - H: JdbcEncoder, - I: JdbcEncoder, - J: JdbcEncoder, - K: JdbcEncoder, - L: JdbcEncoder, - M: JdbcEncoder, - N: JdbcEncoder, - O: JdbcEncoder - ]: JdbcEncoder[(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O)] = - JdbcEncoder(tuple => - JdbcEncoder[A]().encode(tuple._1) ++ SqlFragment.comma ++ JdbcEncoder[B]().encode( - tuple._2 - ) ++ SqlFragment.comma ++ JdbcEncoder[C]().encode(tuple._3) ++ SqlFragment.comma ++ JdbcEncoder[D]().encode( - tuple._4 - ) ++ SqlFragment.comma ++ JdbcEncoder[E]().encode(tuple._5) ++ SqlFragment.comma ++ JdbcEncoder[F]().encode( - tuple._6 - ) ++ SqlFragment.comma ++ JdbcEncoder[G]().encode(tuple._7) ++ SqlFragment.comma ++ JdbcEncoder[H]().encode( - tuple._8 - ) ++ SqlFragment.comma ++ JdbcEncoder[I]().encode(tuple._9) ++ SqlFragment.comma ++ JdbcEncoder[J]().encode( - tuple._10 - ) ++ SqlFragment.comma ++ JdbcEncoder[K]().encode(tuple._11) ++ SqlFragment.comma ++ JdbcEncoder[L]().encode( - tuple._12 - ) ++ SqlFragment.comma ++ JdbcEncoder[M]().encode(tuple._13) ++ SqlFragment.comma ++ JdbcEncoder[N]().encode( - tuple._14 - ) ++ SqlFragment.comma ++ JdbcEncoder[O]().encode(tuple._15) - ) + override def sql(a: A): String = "?" + } - implicit def tuple16Encoder[ - A: JdbcEncoder, - B: JdbcEncoder, - C: JdbcEncoder, - D: JdbcEncoder, - E: JdbcEncoder, - F: JdbcEncoder, - G: JdbcEncoder, - H: JdbcEncoder, - I: JdbcEncoder, - J: JdbcEncoder, - K: JdbcEncoder, - L: JdbcEncoder, - M: JdbcEncoder, - N: JdbcEncoder, - O: JdbcEncoder, - P: JdbcEncoder - ]: JdbcEncoder[(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P)] = - JdbcEncoder(tuple => - JdbcEncoder[A]().encode(tuple._1) ++ SqlFragment.comma ++ JdbcEncoder[B]().encode( - tuple._2 - ) ++ SqlFragment.comma ++ JdbcEncoder[C]().encode(tuple._3) ++ SqlFragment.comma ++ JdbcEncoder[D]().encode( - tuple._4 - ) ++ SqlFragment.comma ++ JdbcEncoder[E]().encode(tuple._5) ++ SqlFragment.comma ++ JdbcEncoder[F]().encode( - tuple._6 - ) ++ SqlFragment.comma ++ JdbcEncoder[G]().encode(tuple._7) ++ SqlFragment.comma ++ JdbcEncoder[H]().encode( - tuple._8 - ) ++ SqlFragment.comma ++ JdbcEncoder[I]().encode(tuple._9) ++ SqlFragment.comma ++ JdbcEncoder[J]().encode( - tuple._10 - ) ++ SqlFragment.comma ++ JdbcEncoder[K]().encode(tuple._11) ++ SqlFragment.comma ++ JdbcEncoder[L]().encode( - tuple._12 - ) ++ SqlFragment.comma ++ JdbcEncoder[M]().encode(tuple._13) ++ SqlFragment.comma ++ JdbcEncoder[N]().encode( - tuple._14 - ) ++ SqlFragment.comma ++ JdbcEncoder[O]().encode(tuple._15) ++ SqlFragment.comma ++ JdbcEncoder[P]().encode( - tuple._16 - ) + implicit def optionParamEncoder[A](implicit encoder: JdbcEncoder[A]): JdbcEncoder[Option[A]] = + JdbcEncoder( + _.fold(SqlFragment.nullLiteral)(encoder.encode), + "?", + (ps, i, value) => + value match { + case Some(value) => encoder.unsafeSetValue(ps, i, value) + case None => encoder.unsafeSetNull(ps, i) + }, + (ps, i) => encoder.unsafeSetNull(ps, i) ) - implicit def tuple17Encoder[ - A: JdbcEncoder, - B: JdbcEncoder, - C: JdbcEncoder, - D: JdbcEncoder, - E: JdbcEncoder, - F: JdbcEncoder, - G: JdbcEncoder, - H: JdbcEncoder, - I: JdbcEncoder, - J: JdbcEncoder, - K: JdbcEncoder, - L: JdbcEncoder, - M: JdbcEncoder, - N: JdbcEncoder, - O: JdbcEncoder, - P: JdbcEncoder, - Q: JdbcEncoder - ]: JdbcEncoder[(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P, Q)] = - JdbcEncoder(tuple => - JdbcEncoder[A]().encode(tuple._1) ++ SqlFragment.comma ++ JdbcEncoder[B]().encode( - tuple._2 - ) ++ SqlFragment.comma ++ JdbcEncoder[C]().encode(tuple._3) ++ SqlFragment.comma ++ JdbcEncoder[D]().encode( - tuple._4 - ) ++ SqlFragment.comma ++ JdbcEncoder[E]().encode(tuple._5) ++ SqlFragment.comma ++ JdbcEncoder[F]().encode( - tuple._6 - ) ++ SqlFragment.comma ++ JdbcEncoder[G]().encode(tuple._7) ++ SqlFragment.comma ++ JdbcEncoder[H]().encode( - tuple._8 - ) ++ SqlFragment.comma ++ JdbcEncoder[I]().encode(tuple._9) ++ SqlFragment.comma ++ JdbcEncoder[J]().encode( - tuple._10 - ) ++ SqlFragment.comma ++ JdbcEncoder[K]().encode(tuple._11) ++ SqlFragment.comma ++ JdbcEncoder[L]().encode( - tuple._12 - ) ++ SqlFragment.comma ++ JdbcEncoder[M]().encode(tuple._13) ++ SqlFragment.comma ++ JdbcEncoder[N]().encode( - tuple._14 - ) ++ SqlFragment.comma ++ JdbcEncoder[O]().encode(tuple._15) ++ SqlFragment.comma ++ JdbcEncoder[P]().encode( - tuple._16 - ) ++ SqlFragment.comma ++ JdbcEncoder[Q]().encode(tuple._17) + implicit val intEncoder: JdbcEncoder[Int] = forSqlType((ps, i, value) => ps.setInt(i, value), Types.INTEGER) + implicit val longEncoder: JdbcEncoder[Long] = forSqlType((ps, i, value) => ps.setLong(i, value), Types.BIGINT) + implicit val doubleEncoder: JdbcEncoder[Double] = forSqlType((ps, i, value) => ps.setDouble(i, value), Types.DOUBLE) + implicit val stringEncoder: JdbcEncoder[String] = forSqlType((ps, i, value) => ps.setString(i, value), Types.VARCHAR) + implicit val booleanEncoder: JdbcEncoder[Boolean] = + forSqlType((ps, i, value) => ps.setBoolean(i, value), Types.BOOLEAN) + implicit val shortEncoder: JdbcEncoder[Short] = forSqlType((ps, i, value) => ps.setShort(i, value), Types.SMALLINT) + implicit val floatEncoder: JdbcEncoder[Float] = forSqlType((ps, i, value) => ps.setFloat(i, value), Types.FLOAT) + implicit val byteEncoder: JdbcEncoder[Byte] = forSqlType((ps, i, value) => ps.setByte(i, value), Types.TINYINT) + implicit val byteArrayEncoder: JdbcEncoder[Array[Byte]] = + forSqlType((ps, i, value) => ps.setBytes(i, value), Types.ARRAY) + implicit val blobEncoder: JdbcEncoder[java.sql.Blob] = forSqlType((ps, i, value) => ps.setBlob(i, value), Types.BLOB) + implicit val sqlDateEncoder: JdbcEncoder[java.sql.Date] = + forSqlType((ps, i, value) => ps.setDate(i, value), Types.DATE) + implicit val sqlTimeEncoder: JdbcEncoder[java.sql.Time] = + forSqlType((ps, i, value) => ps.setTime(i, value), Types.TIME) + + implicit def chunkEncoder[A](implicit encoder: JdbcEncoder[A]): JdbcEncoder[Chunk[A]] = iterableEncoder[A, Chunk[A]] + + implicit def listEncoder[A](implicit encoder: JdbcEncoder[A]): JdbcEncoder[List[A]] = iterableEncoder[A, List[A]] + + implicit def vectorEncoder[A](implicit encoder: JdbcEncoder[A]): JdbcEncoder[Vector[A]] = + iterableEncoder[A, Vector[A]] + + implicit def setEncoder[A](implicit encoder: JdbcEncoder[A]): JdbcEncoder[Set[A]] = iterableEncoder[A, Set[A]] + + implicit def arrayEncoder[A](implicit encoder: JdbcEncoder[A]): JdbcEncoder[Array[A]] = + forIterableSqlType( + _.iterator, + Types.OTHER ) - implicit def tuple18Encoder[ - A: JdbcEncoder, - B: JdbcEncoder, - C: JdbcEncoder, - D: JdbcEncoder, - E: JdbcEncoder, - F: JdbcEncoder, - G: JdbcEncoder, - H: JdbcEncoder, - I: JdbcEncoder, - J: JdbcEncoder, - K: JdbcEncoder, - L: JdbcEncoder, - M: JdbcEncoder, - N: JdbcEncoder, - O: JdbcEncoder, - P: JdbcEncoder, - Q: JdbcEncoder, - R: JdbcEncoder - ]: JdbcEncoder[(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P, Q, R)] = - JdbcEncoder(tuple => - JdbcEncoder[A]().encode(tuple._1) ++ SqlFragment.comma ++ JdbcEncoder[B]().encode( - tuple._2 - ) ++ SqlFragment.comma ++ JdbcEncoder[C]().encode(tuple._3) ++ SqlFragment.comma ++ JdbcEncoder[D]().encode( - tuple._4 - ) ++ SqlFragment.comma ++ JdbcEncoder[E]().encode(tuple._5) ++ SqlFragment.comma ++ JdbcEncoder[F]().encode( - tuple._6 - ) ++ SqlFragment.comma ++ JdbcEncoder[G]().encode(tuple._7) ++ SqlFragment.comma ++ JdbcEncoder[H]().encode( - tuple._8 - ) ++ SqlFragment.comma ++ JdbcEncoder[I]().encode(tuple._9) ++ SqlFragment.comma ++ JdbcEncoder[J]().encode( - tuple._10 - ) ++ SqlFragment.comma ++ JdbcEncoder[K]().encode(tuple._11) ++ SqlFragment.comma ++ JdbcEncoder[L]().encode( - tuple._12 - ) ++ SqlFragment.comma ++ JdbcEncoder[M]().encode(tuple._13) ++ SqlFragment.comma ++ JdbcEncoder[N]().encode( - tuple._14 - ) ++ SqlFragment.comma ++ JdbcEncoder[O]().encode(tuple._15) ++ SqlFragment.comma ++ JdbcEncoder[P]().encode( - tuple._16 - ) ++ SqlFragment.comma ++ JdbcEncoder[Q]().encode(tuple._17) ++ SqlFragment.comma ++ JdbcEncoder[R]().encode( - tuple._18 - ) + private def iterableEncoder[A, I <: Iterable[A]](implicit encoder: JdbcEncoder[A]): JdbcEncoder[I] = + forIterableSqlType( + _.iterator, + Types.OTHER ) - implicit def tuple19Encoder[ - A: JdbcEncoder, - B: JdbcEncoder, - C: JdbcEncoder, - D: JdbcEncoder, - E: JdbcEncoder, - F: JdbcEncoder, - G: JdbcEncoder, - H: JdbcEncoder, - I: JdbcEncoder, - J: JdbcEncoder, - K: JdbcEncoder, - L: JdbcEncoder, - M: JdbcEncoder, - N: JdbcEncoder, - O: JdbcEncoder, - P: JdbcEncoder, - Q: JdbcEncoder, - R: JdbcEncoder, - S: JdbcEncoder - ]: JdbcEncoder[(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P, Q, R, S)] = - JdbcEncoder(tuple => - JdbcEncoder[A]().encode(tuple._1) ++ SqlFragment.comma ++ JdbcEncoder[B]().encode( - tuple._2 - ) ++ SqlFragment.comma ++ JdbcEncoder[C]().encode(tuple._3) ++ SqlFragment.comma ++ JdbcEncoder[D]().encode( - tuple._4 - ) ++ SqlFragment.comma ++ JdbcEncoder[E]().encode(tuple._5) ++ SqlFragment.comma ++ JdbcEncoder[F]().encode( - tuple._6 - ) ++ SqlFragment.comma ++ JdbcEncoder[G]().encode(tuple._7) ++ SqlFragment.comma ++ JdbcEncoder[H]().encode( - tuple._8 - ) ++ SqlFragment.comma ++ JdbcEncoder[I]().encode(tuple._9) ++ SqlFragment.comma ++ JdbcEncoder[J]().encode( - tuple._10 - ) ++ SqlFragment.comma ++ JdbcEncoder[K]().encode(tuple._11) ++ SqlFragment.comma ++ JdbcEncoder[L]().encode( - tuple._12 - ) ++ SqlFragment.comma ++ JdbcEncoder[M]().encode(tuple._13) ++ SqlFragment.comma ++ JdbcEncoder[N]().encode( - tuple._14 - ) ++ SqlFragment.comma ++ JdbcEncoder[O]().encode(tuple._15) ++ SqlFragment.comma ++ JdbcEncoder[P]().encode( - tuple._16 - ) ++ SqlFragment.comma ++ JdbcEncoder[Q]().encode(tuple._17) ++ SqlFragment.comma ++ JdbcEncoder[R]().encode( - tuple._18 - ) ++ SqlFragment.comma ++ JdbcEncoder[S]().encode(tuple._19) - ) + implicit val bigDecimalEncoder: JdbcEncoder[java.math.BigDecimal] = + forSqlType((ps, i, value) => ps.setBigDecimal(i, value), Types.NUMERIC) + implicit val sqlTimestampEncoder: JdbcEncoder[java.sql.Timestamp] = + forSqlType((ps, i, value) => ps.setTimestamp(i, value), Types.TIMESTAMP) - implicit def tuple20Encoder[ - A: JdbcEncoder, - B: JdbcEncoder, - C: JdbcEncoder, - D: JdbcEncoder, - E: JdbcEncoder, - F: JdbcEncoder, - G: JdbcEncoder, - H: JdbcEncoder, - I: JdbcEncoder, - J: JdbcEncoder, - K: JdbcEncoder, - L: JdbcEncoder, - M: JdbcEncoder, - N: JdbcEncoder, - O: JdbcEncoder, - P: JdbcEncoder, - Q: JdbcEncoder, - R: JdbcEncoder, - S: JdbcEncoder, - T: JdbcEncoder - ]: JdbcEncoder[(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P, Q, R, S, T)] = - JdbcEncoder(tuple => - JdbcEncoder[A]().encode(tuple._1) ++ SqlFragment.comma ++ JdbcEncoder[B]().encode( - tuple._2 - ) ++ SqlFragment.comma ++ JdbcEncoder[C]().encode(tuple._3) ++ SqlFragment.comma ++ JdbcEncoder[D]().encode( - tuple._4 - ) ++ SqlFragment.comma ++ JdbcEncoder[E]().encode(tuple._5) ++ SqlFragment.comma ++ JdbcEncoder[F]().encode( - tuple._6 - ) ++ SqlFragment.comma ++ JdbcEncoder[G]().encode(tuple._7) ++ SqlFragment.comma ++ JdbcEncoder[H]().encode( - tuple._8 - ) ++ SqlFragment.comma ++ JdbcEncoder[I]().encode(tuple._9) ++ SqlFragment.comma ++ JdbcEncoder[J]().encode( - tuple._10 - ) ++ SqlFragment.comma ++ JdbcEncoder[K]().encode(tuple._11) ++ SqlFragment.comma ++ JdbcEncoder[L]().encode( - tuple._12 - ) ++ SqlFragment.comma ++ JdbcEncoder[M]().encode(tuple._13) ++ SqlFragment.comma ++ JdbcEncoder[N]().encode( - tuple._14 - ) ++ SqlFragment.comma ++ JdbcEncoder[O]().encode(tuple._15) ++ SqlFragment.comma ++ JdbcEncoder[P]().encode( - tuple._16 - ) ++ SqlFragment.comma ++ JdbcEncoder[Q]().encode(tuple._17) ++ SqlFragment.comma ++ JdbcEncoder[R]().encode( - tuple._18 - ) ++ SqlFragment.comma ++ JdbcEncoder[S]().encode(tuple._19) ++ SqlFragment.comma ++ JdbcEncoder[T]().encode( - tuple._20 - ) - ) + implicit val uuidParamEncoder: JdbcEncoder[java.util.UUID] = other((ps, i, value) => ps.setObject(i, value), "uuid") - implicit def tuple21Encoder[ - A: JdbcEncoder, - B: JdbcEncoder, - C: JdbcEncoder, - D: JdbcEncoder, - E: JdbcEncoder, - F: JdbcEncoder, - G: JdbcEncoder, - H: JdbcEncoder, - I: JdbcEncoder, - J: JdbcEncoder, - K: JdbcEncoder, - L: JdbcEncoder, - M: JdbcEncoder, - N: JdbcEncoder, - O: JdbcEncoder, - P: JdbcEncoder, - Q: JdbcEncoder, - R: JdbcEncoder, - S: JdbcEncoder, - T: JdbcEncoder, - U: JdbcEncoder - ]: JdbcEncoder[(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P, Q, R, S, T, U)] = - JdbcEncoder(tuple => - JdbcEncoder[A]().encode(tuple._1) ++ SqlFragment.comma ++ JdbcEncoder[B]().encode( - tuple._2 - ) ++ SqlFragment.comma ++ JdbcEncoder[C]().encode(tuple._3) ++ SqlFragment.comma ++ JdbcEncoder[D]().encode( - tuple._4 - ) ++ SqlFragment.comma ++ JdbcEncoder[E]().encode(tuple._5) ++ SqlFragment.comma ++ JdbcEncoder[F]().encode( - tuple._6 - ) ++ SqlFragment.comma ++ JdbcEncoder[G]().encode(tuple._7) ++ SqlFragment.comma ++ JdbcEncoder[H]().encode( - tuple._8 - ) ++ SqlFragment.comma ++ JdbcEncoder[I]().encode(tuple._9) ++ SqlFragment.comma ++ JdbcEncoder[J]().encode( - tuple._10 - ) ++ SqlFragment.comma ++ JdbcEncoder[K]().encode(tuple._11) ++ SqlFragment.comma ++ JdbcEncoder[L]().encode( - tuple._12 - ) ++ SqlFragment.comma ++ JdbcEncoder[M]().encode(tuple._13) ++ SqlFragment.comma ++ JdbcEncoder[N]().encode( - tuple._14 - ) ++ SqlFragment.comma ++ JdbcEncoder[O]().encode(tuple._15) ++ SqlFragment.comma ++ JdbcEncoder[P]().encode( - tuple._16 - ) ++ SqlFragment.comma ++ JdbcEncoder[Q]().encode(tuple._17) ++ SqlFragment.comma ++ JdbcEncoder[R]().encode( - tuple._18 - ) ++ SqlFragment.comma ++ JdbcEncoder[S]().encode(tuple._19) ++ SqlFragment.comma ++ JdbcEncoder[T]().encode( - tuple._20 - ) ++ SqlFragment.comma ++ JdbcEncoder[U]().encode(tuple._21) - ) + implicit val charEncoder: JdbcEncoder[Char] = stringEncoder.contramap(_.toString) + implicit val bigIntEncoder: JdbcEncoder[java.math.BigInteger] = + bigDecimalEncoder.contramap(new java.math.BigDecimal(_)) + implicit val bigDecimalScalaEncoder: JdbcEncoder[scala.math.BigDecimal] = bigDecimalEncoder.contramap(_.bigDecimal) + implicit val byteChunkEncoder: JdbcEncoder[Chunk[Byte]] = byteArrayEncoder.contramap(_.toArray) + implicit val instantEncoder: JdbcEncoder[java.time.Instant] = sqlTimestampEncoder.contramap(java.sql.Timestamp.from) - implicit def tuple22Encoder[ - A: JdbcEncoder, - B: JdbcEncoder, - C: JdbcEncoder, - D: JdbcEncoder, - E: JdbcEncoder, - F: JdbcEncoder, - G: JdbcEncoder, - H: JdbcEncoder, - I: JdbcEncoder, - J: JdbcEncoder, - K: JdbcEncoder, - L: JdbcEncoder, - M: JdbcEncoder, - N: JdbcEncoder, - O: JdbcEncoder, - P: JdbcEncoder, - Q: JdbcEncoder, - R: JdbcEncoder, - S: JdbcEncoder, - T: JdbcEncoder, - U: JdbcEncoder, - V: JdbcEncoder - ]: JdbcEncoder[(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P, Q, R, S, T, U, V)] = - JdbcEncoder(tuple => - JdbcEncoder[A]().encode(tuple._1) ++ SqlFragment.comma ++ JdbcEncoder[B]().encode( - tuple._2 - ) ++ SqlFragment.comma ++ JdbcEncoder[C]().encode(tuple._3) ++ SqlFragment.comma ++ JdbcEncoder[D]().encode( - tuple._4 - ) ++ SqlFragment.comma ++ JdbcEncoder[E]().encode(tuple._5) ++ SqlFragment.comma ++ JdbcEncoder[F]().encode( - tuple._6 - ) ++ SqlFragment.comma ++ JdbcEncoder[G]().encode(tuple._7) ++ SqlFragment.comma ++ JdbcEncoder[H]().encode( - tuple._8 - ) ++ SqlFragment.comma ++ JdbcEncoder[I]().encode(tuple._9) ++ SqlFragment.comma ++ JdbcEncoder[J]().encode( - tuple._10 - ) ++ SqlFragment.comma ++ JdbcEncoder[K]().encode(tuple._11) ++ SqlFragment.comma ++ JdbcEncoder[L]().encode( - tuple._12 - ) ++ SqlFragment.comma ++ JdbcEncoder[M]().encode(tuple._13) ++ SqlFragment.comma ++ JdbcEncoder[N]().encode( - tuple._14 - ) ++ SqlFragment.comma ++ JdbcEncoder[O]().encode(tuple._15) ++ SqlFragment.comma ++ JdbcEncoder[P]().encode( - tuple._16 - ) ++ SqlFragment.comma ++ JdbcEncoder[Q]().encode(tuple._17) ++ SqlFragment.comma ++ JdbcEncoder[R]().encode( - tuple._18 - ) ++ SqlFragment.comma ++ JdbcEncoder[S]().encode(tuple._19) ++ SqlFragment.comma ++ JdbcEncoder[T]().encode( - tuple._20 - ) ++ SqlFragment.comma ++ JdbcEncoder[U]().encode(tuple._21) ++ SqlFragment.comma ++ JdbcEncoder[V]().encode( - tuple._22 - ) - ) + // TODO: review for cases like Option of a tuple + implicit def tuple2Encoder[A: JdbcEncoder, B: JdbcEncoder]: JdbcEncoder[(A, B)] = + tupleNEncoder(JdbcEncoder[A](), JdbcEncoder[B]()) + + implicit def tuple3Encoder[A: JdbcEncoder, B: JdbcEncoder, C: JdbcEncoder]: JdbcEncoder[(A, B, C)] = + tupleNEncoder(JdbcEncoder[A](), JdbcEncoder[B](), JdbcEncoder[C]()) + implicit def tuple4Encoder[A: JdbcEncoder, B: JdbcEncoder, C: JdbcEncoder, D: JdbcEncoder] + : JdbcEncoder[(A, B, C, D)] = + tupleNEncoder(JdbcEncoder[A](), JdbcEncoder[B](), JdbcEncoder[C](), JdbcEncoder[D]()) + + private def tupleNEncoder[A <: Product](encoders: JdbcEncoder[_]*): JdbcEncoder[A] = + new JdbcEncoder[A] { + override def encode(value: A): SqlFragment = + SqlFragment.intersperse( + SqlFragment.comma, + value.productIterator + .zip(encoders) + .map { case (value, encoder) => encoder.asInstanceOf[JdbcEncoder[Any]].encode(value) } + .toSeq + ) + + override def unsafeSetValue(ps: PreparedStatement, index: Int, value: A): Int = + value.productIterator + .zip(encoders) + .foldLeft(index) { case (i, (value, encoder)) => + encoder.asInstanceOf[JdbcEncoder[Any]].unsafeSetValue(ps, i, value) + } + + override def unsafeSetNull(ps: PreparedStatement, index: Int): Int = + encoders + .foldLeft(index) { case (i, encoder) => + encoder.asInstanceOf[JdbcEncoder[Any]].unsafeSetNull(ps, i) + } + + override def sql(a: A): String = Seq.fill(encoders.size)("?").mkString(", ") + + override def prettyValuePrinter(value: A): String = value.productIterator.mkString(", ") + } } trait JdbcEncoder0LowPriorityImplicits { self => @@ -719,7 +297,7 @@ trait JdbcEncoder0LowPriorityImplicits { self => case StandardType.BigIntegerType => JdbcEncoder.bigIntEncoder case StandardType.BinaryType => JdbcEncoder.byteChunkEncoder case StandardType.BigDecimalType => JdbcEncoder.bigDecimalEncoder - case StandardType.UUIDType => JdbcEncoder.uuidEncoder + case StandardType.UUIDType => JdbcEncoder.uuidParamEncoder // TODO: Standard Types which are missing are the date time types, not sure what would be the best way to handle them case _ => throw JdbcEncoderError(s"Unsupported type: $standardType", new IllegalArgumentException) } @@ -730,7 +308,7 @@ trait JdbcEncoder0LowPriorityImplicits { self => case Schema.Primitive(standardType, _) => primitiveCodec(standardType) case Schema.Optional(schema, _) => - JdbcEncoder.optionEncoder(self.fromSchema(schema)) + JdbcEncoder.optionParamEncoder(self.fromSchema(schema)) case Schema.Tuple2(left, right, _) => JdbcEncoder.tuple2Encoder(self.fromSchema(left), self.fromSchema(right)) // format: off @@ -764,11 +342,26 @@ trait JdbcEncoder0LowPriorityImplicits { self => throw JdbcEncoderError(s"Failed to encode schema ${schema}", new IllegalArgumentException) } - private[jdbc] def caseClassEncoder[A](fields: Chunk[Schema.Field[A, _]]): JdbcEncoder[A] = JdbcEncoder(a => - fields.map { f => - val encoder = self.fromSchema(f.schema.asInstanceOf[Schema[Any]]) - encoder.encode(f.get(a)) - }.reduce(_ ++ SqlFragment.comma ++ _) - ) + private[jdbc] def caseClassEncoder[A](fields: Chunk[Schema.Field[A, _]]): JdbcEncoder[A] = new JdbcEncoder[A] { + override def encode(a: A): SqlFragment = + fields.map { f => + val encoder = self.fromSchema(f.schema.asInstanceOf[Schema[Any]]) + encoder.encode(f.get(a)) + }.reduce(_ ++ SqlFragment.comma ++ _) + + override def unsafeSetValue(ps: PreparedStatement, index: Int, value: A): Int = + fields.foldLeft(index) { case (i, f) => + val encoder = self.fromSchema(f.schema.asInstanceOf[Schema[Any]]) + encoder.unsafeSetValue(ps, i, f.get(value)) + } + + override def unsafeSetNull(ps: PreparedStatement, index: Int): Int = + fields.foldLeft(index) { case (i, f) => + val encoder = self.fromSchema(f.schema.asInstanceOf[Schema[Any]]) + encoder.unsafeSetNull(ps, i) + } + + override def sql(a: A): String = Seq.fill(fields.iterator.size)("?").mkString(",") + } } diff --git a/core/src/main/scala/zio/jdbc/SqlFragment.scala b/core/src/main/scala/zio/jdbc/SqlFragment.scala index adf04194..b646e002 100644 --- a/core/src/main/scala/zio/jdbc/SqlFragment.scala +++ b/core/src/main/scala/zio/jdbc/SqlFragment.scala @@ -18,7 +18,6 @@ package zio.jdbc import zio._ import zio.jdbc.SqlFragment.Segment -import java.sql.{ PreparedStatement, Types } import scala.language.implicitConversions /** @@ -84,6 +83,8 @@ sealed trait SqlFragment { self => def notIn[B](b: B, bs: B*)(implicit encoder: JdbcEncoder[B]): SqlFragment = notIn(b +: bs) + def and(): SqlFragment = self ++ SqlFragment.and + def notIn[B](bs: Iterable[B])(implicit encoder: JdbcEncoder[B]): SqlFragment = in0(SqlFragment.notIn, bs) @@ -112,27 +113,8 @@ sealed trait SqlFragment { self => foreachSegment { syntax => sql.append(syntax.value) } { param => - param.value match { - case iterable: Iterable[_] => - iterable.iterator.foreach { item => - paramsBuilder += item.toString - } - sql.append( - Seq.fill(iterable.iterator.size)("?").mkString(",") - ) - - case array: Array[_] => - array.foreach { item => - paramsBuilder += item.toString - } - sql.append( - Seq.fill(array.length)("?").mkString(",") - ) - - case _ => - sql.append("?") - paramsBuilder += param.value.toString - } + sql.append(param.setter.sql(param.value)) + paramsBuilder += param.setter.prettyValuePrinter(param.value) } val params = paramsBuilder.result() @@ -246,105 +228,17 @@ object SqlFragment { sealed trait Segment object Segment { - final case class Syntax(value: String) extends Segment - final case class Param(value: Any, setter: Setter[Any]) extends Segment - final case class Nested(sql: SqlFragment) extends Segment - - implicit def jdbcEncoderSegment[A](obj: A)(implicit encoder: JdbcEncoder[A]): Segment = - encoder.setter match { - case Some(value) => Segment.Param(obj, value.asInstanceOf[Setter[Any]]) - case None => encoder.encode(obj) - } - - implicit def nestedSqlSegment[A](sql: SqlFragment): Segment.Nested = Segment.Nested(sql) - - } - - trait Setter[-A] { self => - def unsafeSetValue(ps: PreparedStatement, index: Int, value: A): Unit - def unsafeSetNull(ps: PreparedStatement, index: Int): Unit - - final def contramap[B](f: B => A): Setter[B] = - Setter((ps, i, value) => self.unsafeSetValue(ps, i, f(value)), (ps, i) => self.unsafeSetNull(ps, i)) - } - - object Setter { - def apply[A]()(implicit setter: Setter[A]): Setter[A] = setter - - def apply[A](onValue: (PreparedStatement, Int, A) => Unit, onNull: (PreparedStatement, Int) => Unit): Setter[A] = - new Setter[A] { - def unsafeSetValue(ps: PreparedStatement, index: Int, value: A): Unit = onValue(ps, index, value) - def unsafeSetNull(ps: PreparedStatement, index: Int): Unit = onNull(ps, index) - } - - def forSqlType[A](onValue: (PreparedStatement, Int, A) => Unit, sqlType: Int): Setter[A] = new Setter[A] { - def unsafeSetValue(ps: PreparedStatement, index: Int, value: A): Unit = onValue(ps, index, value) - def unsafeSetNull(ps: PreparedStatement, index: Int): Unit = ps.setNull(index, sqlType) - } - - def other[A](onValue: (PreparedStatement, Int, A) => Unit, sqlType: String): Setter[A] = new Setter[A] { - def unsafeSetValue(ps: PreparedStatement, index: Int, value: A): Unit = onValue(ps, index, value) - def unsafeSetNull(ps: PreparedStatement, index: Int): Unit = ps.setNull(index, Types.OTHER, sqlType) - } - - implicit def optionParamSetter[A](implicit setter: Setter[A]): Setter[Option[A]] = - Setter( - (ps, i, value) => - value match { - case Some(value) => setter.unsafeSetValue(ps, i, value) - case None => setter.unsafeSetNull(ps, i) - }, - (ps, i) => setter.unsafeSetNull(ps, i) - ) - - implicit val intSetter: Setter[Int] = forSqlType((ps, i, value) => ps.setInt(i, value), Types.INTEGER) - implicit val longSetter: Setter[Long] = forSqlType((ps, i, value) => ps.setLong(i, value), Types.BIGINT) - implicit val doubleSetter: Setter[Double] = forSqlType((ps, i, value) => ps.setDouble(i, value), Types.DOUBLE) - implicit val stringSetter: Setter[String] = forSqlType((ps, i, value) => ps.setString(i, value), Types.VARCHAR) - implicit val booleanSetter: Setter[Boolean] = forSqlType((ps, i, value) => ps.setBoolean(i, value), Types.BOOLEAN) - implicit val shortSetter: Setter[Short] = forSqlType((ps, i, value) => ps.setShort(i, value), Types.SMALLINT) - implicit val floatSetter: Setter[Float] = forSqlType((ps, i, value) => ps.setFloat(i, value), Types.FLOAT) - implicit val byteSetter: Setter[Byte] = forSqlType((ps, i, value) => ps.setByte(i, value), Types.TINYINT) - implicit val byteArraySetter: Setter[Array[Byte]] = forSqlType((ps, i, value) => ps.setBytes(i, value), Types.ARRAY) - implicit val blobSetter: Setter[java.sql.Blob] = forSqlType((ps, i, value) => ps.setBlob(i, value), Types.BLOB) - implicit val sqlDateSetter: Setter[java.sql.Date] = forSqlType((ps, i, value) => ps.setDate(i, value), Types.DATE) - implicit val sqlTimeSetter: Setter[java.sql.Time] = forSqlType((ps, i, value) => ps.setTime(i, value), Types.TIME) - - implicit def chunkSetter[A](implicit setter: Setter[A]): Setter[Chunk[A]] = iterableSetter[A, Chunk[A]] - implicit def listSetter[A](implicit setter: Setter[A]): Setter[List[A]] = iterableSetter[A, List[A]] - implicit def vectorSetter[A](implicit setter: Setter[A]): Setter[Vector[A]] = iterableSetter[A, Vector[A]] - implicit def setSetter[A](implicit setter: Setter[A]): Setter[Set[A]] = iterableSetter[A, Set[A]] - - implicit def arraySetter[A](implicit setter: Setter[A]): Setter[Array[A]] = - forSqlType( - (ps, i, iterable) => - iterable.zipWithIndex.foreach { case (value, valueIdx) => - setter.setValue(ps, i + valueIdx, value) - }, - Types.OTHER - ) + final case class Syntax(value: String) extends Segment + final case class Param(value: Any, setter: JdbcEncoder[Any]) extends Segment + final case class Nested(sql: SqlFragment) extends Segment - private def iterableSetter[A, I <: Iterable[A]](implicit setter: Setter[A]): Setter[I] = - forSqlType( - (ps, i, iterable) => - iterable.zipWithIndex.foreach { case (value, valueIdx) => - setter.setValue(ps, i + valueIdx, value) - }, - Types.OTHER - ) - - implicit val bigDecimalSetter: Setter[java.math.BigDecimal] = - forSqlType((ps, i, value) => ps.setBigDecimal(i, value), Types.NUMERIC) - implicit val sqlTimestampSetter: Setter[java.sql.Timestamp] = - forSqlType((ps, i, value) => ps.setTimestamp(i, value), Types.TIMESTAMP) + implicit def paramSegment[A](a: A)(implicit encoder: JdbcEncoder[A]): Segment.Param = + Segment.Param(a, encoder.asInstanceOf[JdbcEncoder[Any]]) - implicit val uuidParamSetter: Setter[java.util.UUID] = other((ps, i, value) => ps.setObject(i, value), "uuid") +// implicit def paramSegment[A](a: A)(implicit setter: JdbcEncoder[A]): Segment = +// Nested(setter.encode(a)) - implicit val charSetter: Setter[Char] = stringSetter.contramap(_.toString) - implicit val bigIntSetter: Setter[java.math.BigInteger] = bigDecimalSetter.contramap(new java.math.BigDecimal(_)) - implicit val bigDecimalScalaSetter: Setter[scala.math.BigDecimal] = bigDecimalSetter.contramap(_.bigDecimal) - implicit val byteChunkSetter: Setter[Chunk[Byte]] = byteArraySetter.contramap(_.toArray) - implicit val instantSetter: Setter[java.time.Instant] = sqlTimestampSetter.contramap(java.sql.Timestamp.from) + implicit def nestedSqlSegment[A](sql: SqlFragment): Segment.Nested = Segment.Nested(sql) } def apply(sql: String): SqlFragment = sql @@ -364,7 +258,7 @@ object SqlFragment { def update(table: String): SqlFragment = s"UPDATE $table" - private[jdbc] def intersperse( + def intersperse( sep: SqlFragment, elements: Iterable[SqlFragment] ): SqlFragment = { diff --git a/core/src/main/scala/zio/jdbc/ZConnection.scala b/core/src/main/scala/zio/jdbc/ZConnection.scala index f6167f7d..2a674c81 100644 --- a/core/src/main/scala/zio/jdbc/ZConnection.scala +++ b/core/src/main/scala/zio/jdbc/ZConnection.scala @@ -45,19 +45,7 @@ final class ZConnection(private[jdbc] val underlying: ZConnection.Restorable) ex statement <- ZIO.acquireRelease(ZIO.attempt { val sb = new StringBuilder() sql.foreachSegment(syntax => sb.append(syntax.value)) { param => - param.value match { - case iterable: Iterable[_] => - sb.append( - Seq.fill(iterable.iterator.size)("?").mkString(", ") - ) - - case array: Array[_] => - sb.append( - Seq.fill(array.length)("?").mkString(", ") - ) - - case _ => sb.append("?") - } + sb.append(param.setter.sql(param.value)) } transactionIsolationLevel.foreach { transactionIsolationLevel => connection.setTransactionIsolation(transactionIsolationLevel.toInt) @@ -67,8 +55,7 @@ final class ZConnection(private[jdbc] val underlying: ZConnection.Restorable) ex _ <- ZIO.attempt { var paramIndex = 1 sql.foreachSegment(_ => ()) { param => - param.setter.unsafeSetValue(statement, paramIndex, param.value) - paramIndex += 1 + paramIndex = param.setter.unsafeSetValue(statement, paramIndex, param.value) } } result <- f(statement) diff --git a/core/src/test/scala/zio/jdbc/SqlFragmentSpec.scala b/core/src/test/scala/zio/jdbc/SqlFragmentSpec.scala index d0538f2e..2ee9d889 100644 --- a/core/src/test/scala/zio/jdbc/SqlFragmentSpec.scala +++ b/core/src/test/scala/zio/jdbc/SqlFragmentSpec.scala @@ -1,9 +1,8 @@ package zio.jdbc import zio.Chunk -import zio.jdbc.SqlFragment.Setter -import zio.jdbc.{ transaction => transact } -import zio.schema.{ Schema, TypeId } +import zio.jdbc.{transaction => transact} +import zio.schema.{Schema, TypeId} import zio.test.Assertion._ import zio.test._ @@ -48,11 +47,28 @@ object SqlFragmentSpec extends ZIOSpecDefault { } + test("type safe interpolation") { final case class Foo(value: String) - implicit val fooParamSetter: JdbcEncoder[Foo] = JdbcEncoder[String]().contramap(_.toString) + implicit val fooParamJdbcEncoder: JdbcEncoder[Foo] = JdbcEncoder[String]().contramap(_.toString) val testSql = sql"${Foo("test")}" + assertTrue(testSql.segments.collect { case SqlFragment.Segment.Param(_, setter) => + setter + }.head eq fooParamJdbcEncoder) && assertTrue(testSql.toString == "Sql(?, Foo(test))") + } + test("Custom JdbcEncoder") { + + implicit val encoder: JdbcEncoder[Array[Byte]] = JdbcEncoder.single( + s"FROM_BASE64('?')", + value => Base64.getEncoder.encodeToString(value) + ) + + val bytes = Array[Byte](1, 2, 3) + + val result = sql"UPDATE foo SET bytes = $bytes" + + assertTrue( + result.toString == "Sql(UPDATE foo SET bytes = FROM_BASE64('?'), AQID)" + ) } + suite(" SqlFragment.ParamSetter instances") { // TODO figure out how to test at PrepareStatement level test("Option") { @@ -140,15 +156,24 @@ object SqlFragmentSpec extends ZIOSpecDefault { .toString == "Sql(select name, age from users where id IN (?,?,?), 1, 2, 3)" ) + } + test("fragment method where with multiple iterator params") { + val seq = Seq(1, 2, 3) + assertTrue( + sql"select name, age from users where id" + .in(seq) + .and() + .notIn(seq) + .toString == + "Sql(select name, age from users where id IN (?,?,?) AND NOT IN (?,?,?), 1, 2, 3, 1, 2, 3)" + ) } + test("interpolation param is supported collection") { - def assertIn[A: Setter](collection: A) = { + def assertIn[A: JdbcEncoder](collection: A) = { println(sql"select name, age from users where id in ($collection)".toString) assertTrue( sql"select name, age from users where id in ($collection)".toString == "Sql(select name, age from users where id in (?,?,?), 1, 2, 3)" ) } - assertIn(Chunk(1, 2, 3)) && assertIn(List(1, 2, 3)) && assertIn(Vector(1, 2, 3)) && @@ -246,25 +271,9 @@ object SqlFragmentSpec extends ZIOSpecDefault { assertTrue( result.toString == "Sql(UPDATE persons)" ) - } + - test("Custom JdbcEncoder") { - - implicit val byteArrayjdbcEncoder: JdbcEncoder[Array[Byte]] = - JdbcEncoder(value => s"FROM_BASE64('${Base64.getEncoder.encodeToString(value)}')") - - val bytes = Array[Byte](1, 2, 3) - - val result = sql"UPDATE foo SET bytes = $bytes" - - assertTrue( - result.toString == "Sql(UPDATE foo SET bytes = FROM_BASE64('AQID'))" - ) } } - - case class Base64Array(arr: Array[Byte]) extends AnyVal - } object Models { diff --git a/core/src/test/scala/zio/jdbc/ZConnectionPoolSpec.scala b/core/src/test/scala/zio/jdbc/ZConnectionPoolSpec.scala index 8a9cee60..6f82c661 100644 --- a/core/src/test/scala/zio/jdbc/ZConnectionPoolSpec.scala +++ b/core/src/test/scala/zio/jdbc/ZConnectionPoolSpec.scala @@ -1,7 +1,6 @@ package zio.jdbc -import zio._ -import zio.jdbc.SqlFragment.Setter +import zio.{ jdbc, _ } import zio.schema._ import zio.test.Assertion._ import zio.test.TestAspect._ @@ -127,11 +126,8 @@ object ZConnectionPoolSpec extends ZIOSpecDefault { implicit val jdbcDecoder: JdbcDecoder[User] = JdbcDecoder[(String, Int)]().map[User](t => User(t._1, t._2)) - implicit val jdbcEncoder: JdbcEncoder[User] = JdbcEncoder { value => - val name = value.name - val age = value.age - sql"""$name""" ++ ", " ++ s"$age" - } + implicit val jdbcEncoder: JdbcEncoder[User] = + JdbcEncoder[(String, Int)]().contramap[User](t => (t.name, t.age)) } final case class UserNoId(name: String, age: Int) @@ -140,11 +136,9 @@ object ZConnectionPoolSpec extends ZIOSpecDefault { implicit val jdbcDecoder: JdbcDecoder[UserNoId] = JdbcDecoder[(String, Int)]().map[UserNoId](t => UserNoId(t._1, t._2)) - implicit val jdbcEncoder: JdbcEncoder[UserNoId] = JdbcEncoder { value => - val name = value.name - val age = value.age - sql"""$name""" ++ ", " ++ s"$age" - } + implicit val jdbcEncoder: JdbcEncoder[UserNoId] = + JdbcEncoder[(String, Int)]().contramap[UserNoId](t => (t.name, t.age)) + } def spec: Spec[TestEnvironment, Any] = @@ -298,7 +292,7 @@ object ZConnectionPoolSpec extends ZIOSpecDefault { test("select all in") { val namesToSearch = Chunk(sherlockHolmes.name, johnDoe.name) - def assertUsersFound[A: Setter](collection: A) = + def assertUsersFound[A: JdbcEncoder](collection: A) = for { users <- transaction { sql"select name, age from users where name IN ($collection)".query[User].selectAll @@ -314,6 +308,32 @@ object ZConnectionPoolSpec extends ZIOSpecDefault { assertUsersFound(namesToSearch.toSet) && assertUsersFound(namesToSearch.toArray) + for { + _ <- createUsers *> insertSherlock *> insertWatson *> insertJohn + testResult <- asserttions + } yield testResult + } + test("select all in multiple lists") { + val namesToSearch = Chunk(sherlockHolmes.name, johnDoe.name, johnWatson.name) + val namesToAvoid = Chunk(johnWatson.name) + + def assertUsersFound[A: JdbcEncoder](collection: A, collection2: A) = + for { + users <- transaction { + sql"select name, age from users where name IN ($collection) and name NOT IN ($collection2)" + .query[User] + .selectAll + } + } yield assertTrue( + users.map(_.name) == Chunk(sherlockHolmes.name, johnDoe.name) + ) + + def asserttions = + assertUsersFound(namesToSearch, namesToAvoid) && + assertUsersFound(namesToSearch.toList, namesToAvoid.toList) && + assertUsersFound(namesToSearch.toVector, namesToAvoid.toVector) && + assertUsersFound(namesToSearch.toSet, namesToAvoid.toSet) && + assertUsersFound(namesToSearch.toArray, namesToAvoid.toArray) + for { _ <- createUsers *> insertSherlock *> insertWatson *> insertJohn testResult <- asserttions @@ -338,6 +358,18 @@ object ZConnectionPoolSpec extends ZIOSpecDefault { _ <- createUsers *> insertSherlock num <- transaction(sql"update users set age = 43 where name = ${sherlockHolmes.name}".update) } yield assertTrue(num == 1L) + } + test("select with custom encoder") { + implicit val encoder: JdbcEncoder[Array[Char]] = JdbcEncoder.single( + s"TRIM(?)", + value => new String(value) + ) + + val sherlockArray = " Sherlock Holmes ".toCharArray + + for { + _ <- createUsers *> insertSherlock + num <- transaction(sql"delete from users where name = ${sherlockArray}".delete) + } yield assertTrue(num == 1L) } } + suite("decoding") {