diff --git a/core/src/main/scala/zio/jdbc/JdbcEncoder.scala b/core/src/main/scala/zio/jdbc/JdbcEncoder.scala index 13a18a05..e0cd9df3 100644 --- a/core/src/main/scala/zio/jdbc/JdbcEncoder.scala +++ b/core/src/main/scala/zio/jdbc/JdbcEncoder.scala @@ -16,623 +16,271 @@ package zio.jdbc import zio.Chunk +import zio.jdbc.SqlFragment.{ Segment, values } import zio.schema.{ Schema, StandardType } +import java.sql.{ PreparedStatement, Types } + /** * A type class that describes the ability to convert a value of type `A` into * a fragment of SQL. This is useful for forming SQL insert statements. */ -trait JdbcEncoder[-A] { +trait JdbcEncoder[-A] { self => def encode(value: A): SqlFragment - final def contramap[B](f: B => A): JdbcEncoder[B] = value => encode(f(value)) + /** + * Returns first index after encoder + */ + def unsafeSetValue(ps: PreparedStatement, index: Int, value: A): Int + + /** + * Returns first index after encoder + */ + def unsafeSetNull(ps: PreparedStatement, index: Int): Int + + final def contramap[B](f: B => A): JdbcEncoder[B] = + new JdbcEncoder[B] { + override def encode(value: B): SqlFragment = + self.encode(f(value)) + + override def unsafeSetValue(ps: PreparedStatement, index: Int, value: B): Int = + self.unsafeSetValue(ps, index, f(value)) + + override def unsafeSetNull(ps: PreparedStatement, index: Int): Int = self.unsafeSetNull(ps, index) + + override def sql(a: B): String = "?" + } + + def sql(value: A): String + def prettyValuePrinter(value: A): String = value.toString } object JdbcEncoder extends JdbcEncoder0LowPriorityImplicits { + + /** + * Restrictions: + * - Placeholder must contain only one "?" + */ + def single[A](placeholder: String, valueToString: A => String): JdbcEncoder[A] = { + val questionMarkCount = placeholder.count(_ == '?') + if (questionMarkCount == 0) { + throw new IllegalArgumentException("Placeholder must contain one '?'") + } + if (questionMarkCount > 1) { + throw new IllegalArgumentException("Placeholder can contain only one '?'") + } + new JdbcEncoder[A] { + + override def encode(value: A): SqlFragment = + SqlFragment(Chunk(SqlFragment.Segment.Syntax.apply(placeholder))) + + override def unsafeSetValue(ps: PreparedStatement, index: Int, value: A): Int = { + ps.setString(index, valueToString(value)) + index + 1 + } + + override def unsafeSetNull(ps: PreparedStatement, index: Int): Int = { + ps.setNull(index, Types.VARCHAR) + index + 1 + } + + override def sql(value: A): String = placeholder + + override def prettyValuePrinter(value: A): String = valueToString(value) + } + } + def apply[A]()(implicit encoder: JdbcEncoder[A]): JdbcEncoder[A] = encoder - implicit val intEncoder: JdbcEncoder[Int] = value => sql"$value" - implicit val longEncoder: JdbcEncoder[Long] = value => sql"$value" - implicit val doubleEncoder: JdbcEncoder[Double] = value => sql"$value" - implicit val charEncoder: JdbcEncoder[Char] = value => sql"$value" - implicit val stringEncoder: JdbcEncoder[String] = value => sql"$value" - implicit val booleanEncoder: JdbcEncoder[Boolean] = value => sql"$value" - implicit val bigIntEncoder: JdbcEncoder[java.math.BigInteger] = value => sql"$value" - implicit val bigDecimalEncoder: JdbcEncoder[java.math.BigDecimal] = value => sql"$value" - implicit val bigDecimalEncoderScala: JdbcEncoder[scala.math.BigDecimal] = value => sql"$value" - implicit val shortEncoder: JdbcEncoder[Short] = value => sql"$value" - implicit val floatEncoder: JdbcEncoder[Float] = value => sql"$value" - implicit val byteEncoder: JdbcEncoder[Byte] = value => sql"$value" - implicit val byteArrayEncoder: JdbcEncoder[Array[Byte]] = value => sql"$value" - implicit val byteChunkEncoder: JdbcEncoder[Chunk[Byte]] = value => sql"$value" - implicit val blobEncoder: JdbcEncoder[java.sql.Blob] = value => sql"$value" - implicit val uuidEncoder: JdbcEncoder[java.util.UUID] = value => sql"$value" - - implicit def singleParamEncoder[A: SqlFragment.Setter]: JdbcEncoder[A] = value => sql"$value" + def apply[A]( + onEncode: A => SqlFragment, + onSql: => String, + onValue: (PreparedStatement, Int, A) => Unit, + onNull: (PreparedStatement, Int) => Unit + ): JdbcEncoder[A] = + new JdbcEncoder[A] { - // TODO: review for cases like Option of a tuple - def optionEncoder[A](implicit encoder: JdbcEncoder[A]): JdbcEncoder[Option[A]] = - value => value.fold(SqlFragment.nullLiteral)(encoder.encode) + override def encode(value: A): SqlFragment = onEncode(value) - implicit def tuple2Encoder[A: JdbcEncoder, B: JdbcEncoder]: JdbcEncoder[(A, B)] = - tuple => JdbcEncoder[A]().encode(tuple._1) ++ SqlFragment.comma ++ JdbcEncoder[B]().encode(tuple._2) + override def unsafeSetValue(ps: PreparedStatement, index: Int, value: A): Int = { + onValue(ps, index, value) + index + 1 + } - implicit def tuple3Encoder[A: JdbcEncoder, B: JdbcEncoder, C: JdbcEncoder]: JdbcEncoder[(A, B, C)] = - tuple => - JdbcEncoder[A]().encode(tuple._1) ++ SqlFragment.comma ++ JdbcEncoder[B]().encode( - tuple._2 - ) ++ SqlFragment.comma ++ JdbcEncoder[C]().encode(tuple._3) + override def unsafeSetNull(ps: PreparedStatement, index: Int): Int = { + onNull(ps, index) + index + 1 + } - implicit def tuple4Encoder[A: JdbcEncoder, B: JdbcEncoder, C: JdbcEncoder, D: JdbcEncoder] - : JdbcEncoder[(A, B, C, D)] = - tuple => - JdbcEncoder[A]().encode(tuple._1) ++ SqlFragment.comma ++ JdbcEncoder[B]().encode( - tuple._2 - ) ++ SqlFragment.comma ++ JdbcEncoder[C]().encode(tuple._3) ++ SqlFragment.comma ++ JdbcEncoder[D]().encode( - tuple._4 - ) + override def sql(a: A): String = onSql + } - implicit def tuple5Encoder[A: JdbcEncoder, B: JdbcEncoder, C: JdbcEncoder, D: JdbcEncoder, E: JdbcEncoder] - : JdbcEncoder[(A, B, C, D, E)] = - tuple => - JdbcEncoder[A]().encode(tuple._1) ++ SqlFragment.comma ++ JdbcEncoder[B]().encode( - tuple._2 - ) ++ SqlFragment.comma ++ JdbcEncoder[C]().encode(tuple._3) ++ SqlFragment.comma ++ JdbcEncoder[D]().encode( - tuple._4 - ) ++ SqlFragment.comma ++ JdbcEncoder[E]().encode(tuple._5) - - implicit def tuple6Encoder[ - A: JdbcEncoder, - B: JdbcEncoder, - C: JdbcEncoder, - D: JdbcEncoder, - E: JdbcEncoder, - F: JdbcEncoder - ]: JdbcEncoder[(A, B, C, D, E, F)] = - tuple => - JdbcEncoder[A]().encode(tuple._1) ++ SqlFragment.comma ++ JdbcEncoder[B]().encode( - tuple._2 - ) ++ SqlFragment.comma ++ JdbcEncoder[C]().encode(tuple._3) ++ SqlFragment.comma ++ JdbcEncoder[D]().encode( - tuple._4 - ) ++ SqlFragment.comma ++ JdbcEncoder[E]().encode(tuple._5) ++ SqlFragment.comma ++ JdbcEncoder[F]().encode( - tuple._6 - ) + private def forSqlType[A](onValue: (PreparedStatement, Int, A) => Unit, sqlType: Int): JdbcEncoder[A] = + new JdbcEncoder[A] { - implicit def tuple7Encoder[ - A: JdbcEncoder, - B: JdbcEncoder, - C: JdbcEncoder, - D: JdbcEncoder, - E: JdbcEncoder, - F: JdbcEncoder, - G: JdbcEncoder - ]: JdbcEncoder[(A, B, C, D, E, F, G)] = - tuple => - JdbcEncoder[A]().encode(tuple._1) ++ SqlFragment.comma ++ JdbcEncoder[B]().encode( - tuple._2 - ) ++ SqlFragment.comma ++ JdbcEncoder[C]().encode(tuple._3) ++ SqlFragment.comma ++ JdbcEncoder[D]().encode( - tuple._4 - ) ++ SqlFragment.comma ++ JdbcEncoder[E]().encode(tuple._5) ++ SqlFragment.comma ++ JdbcEncoder[F]().encode( - tuple._6 - ) ++ SqlFragment.comma ++ JdbcEncoder[G]().encode(tuple._7) - - implicit def tuple8Encoder[ - A: JdbcEncoder, - B: JdbcEncoder, - C: JdbcEncoder, - D: JdbcEncoder, - E: JdbcEncoder, - F: JdbcEncoder, - G: JdbcEncoder, - H: JdbcEncoder - ]: JdbcEncoder[(A, B, C, D, E, F, G, H)] = - tuple => - JdbcEncoder[A]().encode(tuple._1) ++ SqlFragment.comma ++ JdbcEncoder[B]().encode( - tuple._2 - ) ++ SqlFragment.comma ++ JdbcEncoder[C]().encode(tuple._3) ++ SqlFragment.comma ++ JdbcEncoder[D]().encode( - tuple._4 - ) ++ SqlFragment.comma ++ JdbcEncoder[E]().encode(tuple._5) ++ SqlFragment.comma ++ JdbcEncoder[F]().encode( - tuple._6 - ) ++ SqlFragment.comma ++ JdbcEncoder[G]().encode(tuple._7) ++ SqlFragment.comma ++ JdbcEncoder[H]().encode( - tuple._8 + override def encode(value: A): SqlFragment = SqlFragment( + Chunk.apply(Segment.Param(value, this.asInstanceOf[JdbcEncoder[Any]])) ) - implicit def tuple9Encoder[ - A: JdbcEncoder, - B: JdbcEncoder, - C: JdbcEncoder, - D: JdbcEncoder, - E: JdbcEncoder, - F: JdbcEncoder, - G: JdbcEncoder, - H: JdbcEncoder, - I: JdbcEncoder - ]: JdbcEncoder[(A, B, C, D, E, F, G, H, I)] = - tuple => - JdbcEncoder[A]().encode(tuple._1) ++ SqlFragment.comma ++ JdbcEncoder[B]().encode( - tuple._2 - ) ++ SqlFragment.comma ++ JdbcEncoder[C]().encode(tuple._3) ++ SqlFragment.comma ++ JdbcEncoder[D]().encode( - tuple._4 - ) ++ SqlFragment.comma ++ JdbcEncoder[E]().encode(tuple._5) ++ SqlFragment.comma ++ JdbcEncoder[F]().encode( - tuple._6 - ) ++ SqlFragment.comma ++ JdbcEncoder[G]().encode(tuple._7) ++ SqlFragment.comma ++ JdbcEncoder[H]().encode( - tuple._8 - ) ++ SqlFragment.comma ++ JdbcEncoder[I]().encode(tuple._9) - - implicit def tuple10Encoder[ - A: JdbcEncoder, - B: JdbcEncoder, - C: JdbcEncoder, - D: JdbcEncoder, - E: JdbcEncoder, - F: JdbcEncoder, - G: JdbcEncoder, - H: JdbcEncoder, - I: JdbcEncoder, - J: JdbcEncoder - ]: JdbcEncoder[(A, B, C, D, E, F, G, H, I, J)] = - tuple => - JdbcEncoder[A]().encode(tuple._1) ++ SqlFragment.comma ++ JdbcEncoder[B]().encode( - tuple._2 - ) ++ SqlFragment.comma ++ JdbcEncoder[C]().encode(tuple._3) ++ SqlFragment.comma ++ JdbcEncoder[D]().encode( - tuple._4 - ) ++ SqlFragment.comma ++ JdbcEncoder[E]().encode(tuple._5) ++ SqlFragment.comma ++ JdbcEncoder[F]().encode( - tuple._6 - ) ++ SqlFragment.comma ++ JdbcEncoder[G]().encode(tuple._7) ++ SqlFragment.comma ++ JdbcEncoder[H]().encode( - tuple._8 - ) ++ SqlFragment.comma ++ JdbcEncoder[I]().encode(tuple._9) ++ SqlFragment.comma ++ JdbcEncoder[J]().encode( - tuple._10 - ) + def unsafeSetValue(ps: PreparedStatement, index: Int, value: A): Int = { + onValue(ps, index, value) + index + 1 + } - implicit def tuple11Encoder[ - A: JdbcEncoder, - B: JdbcEncoder, - C: JdbcEncoder, - D: JdbcEncoder, - E: JdbcEncoder, - F: JdbcEncoder, - G: JdbcEncoder, - H: JdbcEncoder, - I: JdbcEncoder, - J: JdbcEncoder, - K: JdbcEncoder - ]: JdbcEncoder[(A, B, C, D, E, F, G, H, I, J, K)] = - tuple => - JdbcEncoder[A]().encode(tuple._1) ++ SqlFragment.comma ++ JdbcEncoder[B]().encode( - tuple._2 - ) ++ SqlFragment.comma ++ JdbcEncoder[C]().encode(tuple._3) ++ SqlFragment.comma ++ JdbcEncoder[D]().encode( - tuple._4 - ) ++ SqlFragment.comma ++ JdbcEncoder[E]().encode(tuple._5) ++ SqlFragment.comma ++ JdbcEncoder[F]().encode( - tuple._6 - ) ++ SqlFragment.comma ++ JdbcEncoder[G]().encode(tuple._7) ++ SqlFragment.comma ++ JdbcEncoder[H]().encode( - tuple._8 - ) ++ SqlFragment.comma ++ JdbcEncoder[I]().encode(tuple._9) ++ SqlFragment.comma ++ JdbcEncoder[J]().encode( - tuple._10 - ) ++ SqlFragment.comma ++ JdbcEncoder[K]().encode(tuple._11) - - implicit def tuple12Encoder[ - A: JdbcEncoder, - B: JdbcEncoder, - C: JdbcEncoder, - D: JdbcEncoder, - E: JdbcEncoder, - F: JdbcEncoder, - G: JdbcEncoder, - H: JdbcEncoder, - I: JdbcEncoder, - J: JdbcEncoder, - K: JdbcEncoder, - L: JdbcEncoder - ]: JdbcEncoder[(A, B, C, D, E, F, G, H, I, J, K, L)] = - tuple => - JdbcEncoder[A]().encode(tuple._1) ++ SqlFragment.comma ++ JdbcEncoder[B]().encode( - tuple._2 - ) ++ SqlFragment.comma ++ JdbcEncoder[C]().encode(tuple._3) ++ SqlFragment.comma ++ JdbcEncoder[D]().encode( - tuple._4 - ) ++ SqlFragment.comma ++ JdbcEncoder[E]().encode(tuple._5) ++ SqlFragment.comma ++ JdbcEncoder[F]().encode( - tuple._6 - ) ++ SqlFragment.comma ++ JdbcEncoder[G]().encode(tuple._7) ++ SqlFragment.comma ++ JdbcEncoder[H]().encode( - tuple._8 - ) ++ SqlFragment.comma ++ JdbcEncoder[I]().encode(tuple._9) ++ SqlFragment.comma ++ JdbcEncoder[J]().encode( - tuple._10 - ) ++ SqlFragment.comma ++ JdbcEncoder[K]().encode(tuple._11) ++ SqlFragment.comma ++ JdbcEncoder[L]().encode( - tuple._12 - ) + def unsafeSetNull(ps: PreparedStatement, index: Int): Int = { + ps.setNull(index, sqlType) + index + 1 + } - implicit def tuple13Encoder[ - A: JdbcEncoder, - B: JdbcEncoder, - C: JdbcEncoder, - D: JdbcEncoder, - E: JdbcEncoder, - F: JdbcEncoder, - G: JdbcEncoder, - H: JdbcEncoder, - I: JdbcEncoder, - J: JdbcEncoder, - K: JdbcEncoder, - L: JdbcEncoder, - M: JdbcEncoder - ]: JdbcEncoder[(A, B, C, D, E, F, G, H, I, J, K, L, M)] = - tuple => - JdbcEncoder[A]().encode(tuple._1) ++ SqlFragment.comma ++ JdbcEncoder[B]().encode( - tuple._2 - ) ++ SqlFragment.comma ++ JdbcEncoder[C]().encode(tuple._3) ++ SqlFragment.comma ++ JdbcEncoder[D]().encode( - tuple._4 - ) ++ SqlFragment.comma ++ JdbcEncoder[E]().encode(tuple._5) ++ SqlFragment.comma ++ JdbcEncoder[F]().encode( - tuple._6 - ) ++ SqlFragment.comma ++ JdbcEncoder[G]().encode(tuple._7) ++ SqlFragment.comma ++ JdbcEncoder[H]().encode( - tuple._8 - ) ++ SqlFragment.comma ++ JdbcEncoder[I]().encode(tuple._9) ++ SqlFragment.comma ++ JdbcEncoder[J]().encode( - tuple._10 - ) ++ SqlFragment.comma ++ JdbcEncoder[K]().encode(tuple._11) ++ SqlFragment.comma ++ JdbcEncoder[L]().encode( - tuple._12 - ) ++ SqlFragment.comma ++ JdbcEncoder[M]().encode(tuple._13) - - implicit def tuple14Encoder[ - A: JdbcEncoder, - B: JdbcEncoder, - C: JdbcEncoder, - D: JdbcEncoder, - E: JdbcEncoder, - F: JdbcEncoder, - G: JdbcEncoder, - H: JdbcEncoder, - I: JdbcEncoder, - J: JdbcEncoder, - K: JdbcEncoder, - L: JdbcEncoder, - M: JdbcEncoder, - N: JdbcEncoder - ]: JdbcEncoder[(A, B, C, D, E, F, G, H, I, J, K, L, M, N)] = - tuple => - JdbcEncoder[A]().encode(tuple._1) ++ SqlFragment.comma ++ JdbcEncoder[B]().encode( - tuple._2 - ) ++ SqlFragment.comma ++ JdbcEncoder[C]().encode(tuple._3) ++ SqlFragment.comma ++ JdbcEncoder[D]().encode( - tuple._4 - ) ++ SqlFragment.comma ++ JdbcEncoder[E]().encode(tuple._5) ++ SqlFragment.comma ++ JdbcEncoder[F]().encode( - tuple._6 - ) ++ SqlFragment.comma ++ JdbcEncoder[G]().encode(tuple._7) ++ SqlFragment.comma ++ JdbcEncoder[H]().encode( - tuple._8 - ) ++ SqlFragment.comma ++ JdbcEncoder[I]().encode(tuple._9) ++ SqlFragment.comma ++ JdbcEncoder[J]().encode( - tuple._10 - ) ++ SqlFragment.comma ++ JdbcEncoder[K]().encode(tuple._11) ++ SqlFragment.comma ++ JdbcEncoder[L]().encode( - tuple._12 - ) ++ SqlFragment.comma ++ JdbcEncoder[M]().encode(tuple._13) ++ SqlFragment.comma ++ JdbcEncoder[N]().encode( - tuple._14 - ) + override def sql(a: A): String = "?" + } - implicit def tuple15Encoder[ - A: JdbcEncoder, - B: JdbcEncoder, - C: JdbcEncoder, - D: JdbcEncoder, - E: JdbcEncoder, - F: JdbcEncoder, - G: JdbcEncoder, - H: JdbcEncoder, - I: JdbcEncoder, - J: JdbcEncoder, - K: JdbcEncoder, - L: JdbcEncoder, - M: JdbcEncoder, - N: JdbcEncoder, - O: JdbcEncoder - ]: JdbcEncoder[(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O)] = - tuple => - JdbcEncoder[A]().encode(tuple._1) ++ SqlFragment.comma ++ JdbcEncoder[B]().encode( - tuple._2 - ) ++ SqlFragment.comma ++ JdbcEncoder[C]().encode(tuple._3) ++ SqlFragment.comma ++ JdbcEncoder[D]().encode( - tuple._4 - ) ++ SqlFragment.comma ++ JdbcEncoder[E]().encode(tuple._5) ++ SqlFragment.comma ++ JdbcEncoder[F]().encode( - tuple._6 - ) ++ SqlFragment.comma ++ JdbcEncoder[G]().encode(tuple._7) ++ SqlFragment.comma ++ JdbcEncoder[H]().encode( - tuple._8 - ) ++ SqlFragment.comma ++ JdbcEncoder[I]().encode(tuple._9) ++ SqlFragment.comma ++ JdbcEncoder[J]().encode( - tuple._10 - ) ++ SqlFragment.comma ++ JdbcEncoder[K]().encode(tuple._11) ++ SqlFragment.comma ++ JdbcEncoder[L]().encode( - tuple._12 - ) ++ SqlFragment.comma ++ JdbcEncoder[M]().encode(tuple._13) ++ SqlFragment.comma ++ JdbcEncoder[N]().encode( - tuple._14 - ) ++ SqlFragment.comma ++ JdbcEncoder[O]().encode(tuple._15) - - implicit def tuple16Encoder[ - A: JdbcEncoder, - B: JdbcEncoder, - C: JdbcEncoder, - D: JdbcEncoder, - E: JdbcEncoder, - F: JdbcEncoder, - G: JdbcEncoder, - H: JdbcEncoder, - I: JdbcEncoder, - J: JdbcEncoder, - K: JdbcEncoder, - L: JdbcEncoder, - M: JdbcEncoder, - N: JdbcEncoder, - O: JdbcEncoder, - P: JdbcEncoder - ]: JdbcEncoder[(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P)] = - tuple => - JdbcEncoder[A]().encode(tuple._1) ++ SqlFragment.comma ++ JdbcEncoder[B]().encode( - tuple._2 - ) ++ SqlFragment.comma ++ JdbcEncoder[C]().encode(tuple._3) ++ SqlFragment.comma ++ JdbcEncoder[D]().encode( - tuple._4 - ) ++ SqlFragment.comma ++ JdbcEncoder[E]().encode(tuple._5) ++ SqlFragment.comma ++ JdbcEncoder[F]().encode( - tuple._6 - ) ++ SqlFragment.comma ++ JdbcEncoder[G]().encode(tuple._7) ++ SqlFragment.comma ++ JdbcEncoder[H]().encode( - tuple._8 - ) ++ SqlFragment.comma ++ JdbcEncoder[I]().encode(tuple._9) ++ SqlFragment.comma ++ JdbcEncoder[J]().encode( - tuple._10 - ) ++ SqlFragment.comma ++ JdbcEncoder[K]().encode(tuple._11) ++ SqlFragment.comma ++ JdbcEncoder[L]().encode( - tuple._12 - ) ++ SqlFragment.comma ++ JdbcEncoder[M]().encode(tuple._13) ++ SqlFragment.comma ++ JdbcEncoder[N]().encode( - tuple._14 - ) ++ SqlFragment.comma ++ JdbcEncoder[O]().encode(tuple._15) ++ SqlFragment.comma ++ JdbcEncoder[P]().encode( - tuple._16 - ) + private def forIterableSqlType[A, I]( + iterator: I => Iterator[A], + sqlType: Int + )(implicit encoder: JdbcEncoder[A]): JdbcEncoder[I] = + new JdbcEncoder[I] { - implicit def tuple17Encoder[ - A: JdbcEncoder, - B: JdbcEncoder, - C: JdbcEncoder, - D: JdbcEncoder, - E: JdbcEncoder, - F: JdbcEncoder, - G: JdbcEncoder, - H: JdbcEncoder, - I: JdbcEncoder, - J: JdbcEncoder, - K: JdbcEncoder, - L: JdbcEncoder, - M: JdbcEncoder, - N: JdbcEncoder, - O: JdbcEncoder, - P: JdbcEncoder, - Q: JdbcEncoder - ]: JdbcEncoder[(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P, Q)] = - tuple => - JdbcEncoder[A]().encode(tuple._1) ++ SqlFragment.comma ++ JdbcEncoder[B]().encode( - tuple._2 - ) ++ SqlFragment.comma ++ JdbcEncoder[C]().encode(tuple._3) ++ SqlFragment.comma ++ JdbcEncoder[D]().encode( - tuple._4 - ) ++ SqlFragment.comma ++ JdbcEncoder[E]().encode(tuple._5) ++ SqlFragment.comma ++ JdbcEncoder[F]().encode( - tuple._6 - ) ++ SqlFragment.comma ++ JdbcEncoder[G]().encode(tuple._7) ++ SqlFragment.comma ++ JdbcEncoder[H]().encode( - tuple._8 - ) ++ SqlFragment.comma ++ JdbcEncoder[I]().encode(tuple._9) ++ SqlFragment.comma ++ JdbcEncoder[J]().encode( - tuple._10 - ) ++ SqlFragment.comma ++ JdbcEncoder[K]().encode(tuple._11) ++ SqlFragment.comma ++ JdbcEncoder[L]().encode( - tuple._12 - ) ++ SqlFragment.comma ++ JdbcEncoder[M]().encode(tuple._13) ++ SqlFragment.comma ++ JdbcEncoder[N]().encode( - tuple._14 - ) ++ SqlFragment.comma ++ JdbcEncoder[O]().encode(tuple._15) ++ SqlFragment.comma ++ JdbcEncoder[P]().encode( - tuple._16 - ) ++ SqlFragment.comma ++ JdbcEncoder[Q]().encode(tuple._17) - - implicit def tuple18Encoder[ - A: JdbcEncoder, - B: JdbcEncoder, - C: JdbcEncoder, - D: JdbcEncoder, - E: JdbcEncoder, - F: JdbcEncoder, - G: JdbcEncoder, - H: JdbcEncoder, - I: JdbcEncoder, - J: JdbcEncoder, - K: JdbcEncoder, - L: JdbcEncoder, - M: JdbcEncoder, - N: JdbcEncoder, - O: JdbcEncoder, - P: JdbcEncoder, - Q: JdbcEncoder, - R: JdbcEncoder - ]: JdbcEncoder[(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P, Q, R)] = - tuple => - JdbcEncoder[A]().encode(tuple._1) ++ SqlFragment.comma ++ JdbcEncoder[B]().encode( - tuple._2 - ) ++ SqlFragment.comma ++ JdbcEncoder[C]().encode(tuple._3) ++ SqlFragment.comma ++ JdbcEncoder[D]().encode( - tuple._4 - ) ++ SqlFragment.comma ++ JdbcEncoder[E]().encode(tuple._5) ++ SqlFragment.comma ++ JdbcEncoder[F]().encode( - tuple._6 - ) ++ SqlFragment.comma ++ JdbcEncoder[G]().encode(tuple._7) ++ SqlFragment.comma ++ JdbcEncoder[H]().encode( - tuple._8 - ) ++ SqlFragment.comma ++ JdbcEncoder[I]().encode(tuple._9) ++ SqlFragment.comma ++ JdbcEncoder[J]().encode( - tuple._10 - ) ++ SqlFragment.comma ++ JdbcEncoder[K]().encode(tuple._11) ++ SqlFragment.comma ++ JdbcEncoder[L]().encode( - tuple._12 - ) ++ SqlFragment.comma ++ JdbcEncoder[M]().encode(tuple._13) ++ SqlFragment.comma ++ JdbcEncoder[N]().encode( - tuple._14 - ) ++ SqlFragment.comma ++ JdbcEncoder[O]().encode(tuple._15) ++ SqlFragment.comma ++ JdbcEncoder[P]().encode( - tuple._16 - ) ++ SqlFragment.comma ++ JdbcEncoder[Q]().encode(tuple._17) ++ SqlFragment.comma ++ JdbcEncoder[R]().encode( - tuple._18 + override def encode(value: I): SqlFragment = SqlFragment( + Chunk.apply(Segment.Param(value, this.asInstanceOf[JdbcEncoder[Any]])) ) - implicit def tuple19Encoder[ - A: JdbcEncoder, - B: JdbcEncoder, - C: JdbcEncoder, - D: JdbcEncoder, - E: JdbcEncoder, - F: JdbcEncoder, - G: JdbcEncoder, - H: JdbcEncoder, - I: JdbcEncoder, - J: JdbcEncoder, - K: JdbcEncoder, - L: JdbcEncoder, - M: JdbcEncoder, - N: JdbcEncoder, - O: JdbcEncoder, - P: JdbcEncoder, - Q: JdbcEncoder, - R: JdbcEncoder, - S: JdbcEncoder - ]: JdbcEncoder[(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P, Q, R, S)] = - tuple => - JdbcEncoder[A]().encode(tuple._1) ++ SqlFragment.comma ++ JdbcEncoder[B]().encode( - tuple._2 - ) ++ SqlFragment.comma ++ JdbcEncoder[C]().encode(tuple._3) ++ SqlFragment.comma ++ JdbcEncoder[D]().encode( - tuple._4 - ) ++ SqlFragment.comma ++ JdbcEncoder[E]().encode(tuple._5) ++ SqlFragment.comma ++ JdbcEncoder[F]().encode( - tuple._6 - ) ++ SqlFragment.comma ++ JdbcEncoder[G]().encode(tuple._7) ++ SqlFragment.comma ++ JdbcEncoder[H]().encode( - tuple._8 - ) ++ SqlFragment.comma ++ JdbcEncoder[I]().encode(tuple._9) ++ SqlFragment.comma ++ JdbcEncoder[J]().encode( - tuple._10 - ) ++ SqlFragment.comma ++ JdbcEncoder[K]().encode(tuple._11) ++ SqlFragment.comma ++ JdbcEncoder[L]().encode( - tuple._12 - ) ++ SqlFragment.comma ++ JdbcEncoder[M]().encode(tuple._13) ++ SqlFragment.comma ++ JdbcEncoder[N]().encode( - tuple._14 - ) ++ SqlFragment.comma ++ JdbcEncoder[O]().encode(tuple._15) ++ SqlFragment.comma ++ JdbcEncoder[P]().encode( - tuple._16 - ) ++ SqlFragment.comma ++ JdbcEncoder[Q]().encode(tuple._17) ++ SqlFragment.comma ++ JdbcEncoder[R]().encode( - tuple._18 - ) ++ SqlFragment.comma ++ JdbcEncoder[S]().encode(tuple._19) - - implicit def tuple20Encoder[ - A: JdbcEncoder, - B: JdbcEncoder, - C: JdbcEncoder, - D: JdbcEncoder, - E: JdbcEncoder, - F: JdbcEncoder, - G: JdbcEncoder, - H: JdbcEncoder, - I: JdbcEncoder, - J: JdbcEncoder, - K: JdbcEncoder, - L: JdbcEncoder, - M: JdbcEncoder, - N: JdbcEncoder, - O: JdbcEncoder, - P: JdbcEncoder, - Q: JdbcEncoder, - R: JdbcEncoder, - S: JdbcEncoder, - T: JdbcEncoder - ]: JdbcEncoder[(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P, Q, R, S, T)] = - tuple => - JdbcEncoder[A]().encode(tuple._1) ++ SqlFragment.comma ++ JdbcEncoder[B]().encode( - tuple._2 - ) ++ SqlFragment.comma ++ JdbcEncoder[C]().encode(tuple._3) ++ SqlFragment.comma ++ JdbcEncoder[D]().encode( - tuple._4 - ) ++ SqlFragment.comma ++ JdbcEncoder[E]().encode(tuple._5) ++ SqlFragment.comma ++ JdbcEncoder[F]().encode( - tuple._6 - ) ++ SqlFragment.comma ++ JdbcEncoder[G]().encode(tuple._7) ++ SqlFragment.comma ++ JdbcEncoder[H]().encode( - tuple._8 - ) ++ SqlFragment.comma ++ JdbcEncoder[I]().encode(tuple._9) ++ SqlFragment.comma ++ JdbcEncoder[J]().encode( - tuple._10 - ) ++ SqlFragment.comma ++ JdbcEncoder[K]().encode(tuple._11) ++ SqlFragment.comma ++ JdbcEncoder[L]().encode( - tuple._12 - ) ++ SqlFragment.comma ++ JdbcEncoder[M]().encode(tuple._13) ++ SqlFragment.comma ++ JdbcEncoder[N]().encode( - tuple._14 - ) ++ SqlFragment.comma ++ JdbcEncoder[O]().encode(tuple._15) ++ SqlFragment.comma ++ JdbcEncoder[P]().encode( - tuple._16 - ) ++ SqlFragment.comma ++ JdbcEncoder[Q]().encode(tuple._17) ++ SqlFragment.comma ++ JdbcEncoder[R]().encode( - tuple._18 - ) ++ SqlFragment.comma ++ JdbcEncoder[S]().encode(tuple._19) ++ SqlFragment.comma ++ JdbcEncoder[T]().encode( - tuple._20 - ) + def unsafeSetValue(ps: PreparedStatement, index: Int, value: I): Int = + iterator(value) + .foldLeft(index) { case (i, value) => encoder.unsafeSetValue(ps, i, value) } - implicit def tuple21Encoder[ - A: JdbcEncoder, - B: JdbcEncoder, - C: JdbcEncoder, - D: JdbcEncoder, - E: JdbcEncoder, - F: JdbcEncoder, - G: JdbcEncoder, - H: JdbcEncoder, - I: JdbcEncoder, - J: JdbcEncoder, - K: JdbcEncoder, - L: JdbcEncoder, - M: JdbcEncoder, - N: JdbcEncoder, - O: JdbcEncoder, - P: JdbcEncoder, - Q: JdbcEncoder, - R: JdbcEncoder, - S: JdbcEncoder, - T: JdbcEncoder, - U: JdbcEncoder - ]: JdbcEncoder[(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P, Q, R, S, T, U)] = - tuple => - JdbcEncoder[A]().encode(tuple._1) ++ SqlFragment.comma ++ JdbcEncoder[B]().encode( - tuple._2 - ) ++ SqlFragment.comma ++ JdbcEncoder[C]().encode(tuple._3) ++ SqlFragment.comma ++ JdbcEncoder[D]().encode( - tuple._4 - ) ++ SqlFragment.comma ++ JdbcEncoder[E]().encode(tuple._5) ++ SqlFragment.comma ++ JdbcEncoder[F]().encode( - tuple._6 - ) ++ SqlFragment.comma ++ JdbcEncoder[G]().encode(tuple._7) ++ SqlFragment.comma ++ JdbcEncoder[H]().encode( - tuple._8 - ) ++ SqlFragment.comma ++ JdbcEncoder[I]().encode(tuple._9) ++ SqlFragment.comma ++ JdbcEncoder[J]().encode( - tuple._10 - ) ++ SqlFragment.comma ++ JdbcEncoder[K]().encode(tuple._11) ++ SqlFragment.comma ++ JdbcEncoder[L]().encode( - tuple._12 - ) ++ SqlFragment.comma ++ JdbcEncoder[M]().encode(tuple._13) ++ SqlFragment.comma ++ JdbcEncoder[N]().encode( - tuple._14 - ) ++ SqlFragment.comma ++ JdbcEncoder[O]().encode(tuple._15) ++ SqlFragment.comma ++ JdbcEncoder[P]().encode( - tuple._16 - ) ++ SqlFragment.comma ++ JdbcEncoder[Q]().encode(tuple._17) ++ SqlFragment.comma ++ JdbcEncoder[R]().encode( - tuple._18 - ) ++ SqlFragment.comma ++ JdbcEncoder[S]().encode(tuple._19) ++ SqlFragment.comma ++ JdbcEncoder[T]().encode( - tuple._20 - ) ++ SqlFragment.comma ++ JdbcEncoder[U]().encode(tuple._21) - - implicit def tuple22Encoder[ - A: JdbcEncoder, - B: JdbcEncoder, - C: JdbcEncoder, - D: JdbcEncoder, - E: JdbcEncoder, - F: JdbcEncoder, - G: JdbcEncoder, - H: JdbcEncoder, - I: JdbcEncoder, - J: JdbcEncoder, - K: JdbcEncoder, - L: JdbcEncoder, - M: JdbcEncoder, - N: JdbcEncoder, - O: JdbcEncoder, - P: JdbcEncoder, - Q: JdbcEncoder, - R: JdbcEncoder, - S: JdbcEncoder, - T: JdbcEncoder, - U: JdbcEncoder, - V: JdbcEncoder - ]: JdbcEncoder[(A, B, C, D, E, F, G, H, I, J, K, L, M, N, O, P, Q, R, S, T, U, V)] = - tuple => - JdbcEncoder[A]().encode(tuple._1) ++ SqlFragment.comma ++ JdbcEncoder[B]().encode( - tuple._2 - ) ++ SqlFragment.comma ++ JdbcEncoder[C]().encode(tuple._3) ++ SqlFragment.comma ++ JdbcEncoder[D]().encode( - tuple._4 - ) ++ SqlFragment.comma ++ JdbcEncoder[E]().encode(tuple._5) ++ SqlFragment.comma ++ JdbcEncoder[F]().encode( - tuple._6 - ) ++ SqlFragment.comma ++ JdbcEncoder[G]().encode(tuple._7) ++ SqlFragment.comma ++ JdbcEncoder[H]().encode( - tuple._8 - ) ++ SqlFragment.comma ++ JdbcEncoder[I]().encode(tuple._9) ++ SqlFragment.comma ++ JdbcEncoder[J]().encode( - tuple._10 - ) ++ SqlFragment.comma ++ JdbcEncoder[K]().encode(tuple._11) ++ SqlFragment.comma ++ JdbcEncoder[L]().encode( - tuple._12 - ) ++ SqlFragment.comma ++ JdbcEncoder[M]().encode(tuple._13) ++ SqlFragment.comma ++ JdbcEncoder[N]().encode( - tuple._14 - ) ++ SqlFragment.comma ++ JdbcEncoder[O]().encode(tuple._15) ++ SqlFragment.comma ++ JdbcEncoder[P]().encode( - tuple._16 - ) ++ SqlFragment.comma ++ JdbcEncoder[Q]().encode(tuple._17) ++ SqlFragment.comma ++ JdbcEncoder[R]().encode( - tuple._18 - ) ++ SqlFragment.comma ++ JdbcEncoder[S]().encode(tuple._19) ++ SqlFragment.comma ++ JdbcEncoder[T]().encode( - tuple._20 - ) ++ SqlFragment.comma ++ JdbcEncoder[U]().encode(tuple._21) ++ SqlFragment.comma ++ JdbcEncoder[V]().encode( - tuple._22 - ) + def unsafeSetNull(ps: PreparedStatement, index: Int): Int = { + ps.setNull(index, sqlType) + index + 1 + } + + override def sql(a: I): String = iterator(a).map(_ => "?").mkString(",") + + override def prettyValuePrinter(value: I): String = iterator(value) + .mkString(", ") + } + + def other[A](onValue: (PreparedStatement, Int, A) => Unit, sqlType: String): JdbcEncoder[A] = new JdbcEncoder[A] { + + override def encode(value: A): SqlFragment = SqlFragment( + Chunk.apply(Segment.Param(value, this.asInstanceOf[JdbcEncoder[Any]])) + ) + + def unsafeSetValue(ps: PreparedStatement, index: Int, value: A): Int = { + onValue(ps, index, value) + index + 1 + } + + def unsafeSetNull(ps: PreparedStatement, index: Int): Int = { + ps.setNull(index, Types.OTHER, sqlType) + index + 1 + } + + override def sql(a: A): String = "?" + } + + implicit def optionParamEncoder[A](implicit encoder: JdbcEncoder[A]): JdbcEncoder[Option[A]] = + JdbcEncoder( + _.fold(SqlFragment.nullLiteral)(encoder.encode), + "?", + (ps, i, value) => + value match { + case Some(value) => encoder.unsafeSetValue(ps, i, value) + case None => encoder.unsafeSetNull(ps, i) + }, + (ps, i) => encoder.unsafeSetNull(ps, i) + ) + + implicit val intEncoder: JdbcEncoder[Int] = forSqlType((ps, i, value) => ps.setInt(i, value), Types.INTEGER) + implicit val longEncoder: JdbcEncoder[Long] = forSqlType((ps, i, value) => ps.setLong(i, value), Types.BIGINT) + implicit val doubleEncoder: JdbcEncoder[Double] = forSqlType((ps, i, value) => ps.setDouble(i, value), Types.DOUBLE) + implicit val stringEncoder: JdbcEncoder[String] = forSqlType((ps, i, value) => ps.setString(i, value), Types.VARCHAR) + implicit val booleanEncoder: JdbcEncoder[Boolean] = + forSqlType((ps, i, value) => ps.setBoolean(i, value), Types.BOOLEAN) + implicit val shortEncoder: JdbcEncoder[Short] = forSqlType((ps, i, value) => ps.setShort(i, value), Types.SMALLINT) + implicit val floatEncoder: JdbcEncoder[Float] = forSqlType((ps, i, value) => ps.setFloat(i, value), Types.FLOAT) + implicit val byteEncoder: JdbcEncoder[Byte] = forSqlType((ps, i, value) => ps.setByte(i, value), Types.TINYINT) + implicit val byteArrayEncoder: JdbcEncoder[Array[Byte]] = + forSqlType((ps, i, value) => ps.setBytes(i, value), Types.ARRAY) + implicit val blobEncoder: JdbcEncoder[java.sql.Blob] = forSqlType((ps, i, value) => ps.setBlob(i, value), Types.BLOB) + implicit val sqlDateEncoder: JdbcEncoder[java.sql.Date] = + forSqlType((ps, i, value) => ps.setDate(i, value), Types.DATE) + implicit val sqlTimeEncoder: JdbcEncoder[java.sql.Time] = + forSqlType((ps, i, value) => ps.setTime(i, value), Types.TIME) + + implicit def chunkEncoder[A](implicit encoder: JdbcEncoder[A]): JdbcEncoder[Chunk[A]] = iterableEncoder[A, Chunk[A]] + + implicit def listEncoder[A](implicit encoder: JdbcEncoder[A]): JdbcEncoder[List[A]] = iterableEncoder[A, List[A]] + + implicit def vectorEncoder[A](implicit encoder: JdbcEncoder[A]): JdbcEncoder[Vector[A]] = + iterableEncoder[A, Vector[A]] + + implicit def setEncoder[A](implicit encoder: JdbcEncoder[A]): JdbcEncoder[Set[A]] = iterableEncoder[A, Set[A]] + + implicit def arrayEncoder[A](implicit encoder: JdbcEncoder[A]): JdbcEncoder[Array[A]] = + forIterableSqlType( + _.iterator, + Types.OTHER + ) + + private def iterableEncoder[A, I <: Iterable[A]](implicit encoder: JdbcEncoder[A]): JdbcEncoder[I] = + forIterableSqlType( + _.iterator, + Types.OTHER + ) + + implicit val bigDecimalEncoder: JdbcEncoder[java.math.BigDecimal] = + forSqlType((ps, i, value) => ps.setBigDecimal(i, value), Types.NUMERIC) + implicit val sqlTimestampEncoder: JdbcEncoder[java.sql.Timestamp] = + forSqlType((ps, i, value) => ps.setTimestamp(i, value), Types.TIMESTAMP) + + implicit val uuidParamEncoder: JdbcEncoder[java.util.UUID] = other((ps, i, value) => ps.setObject(i, value), "uuid") + + implicit val charEncoder: JdbcEncoder[Char] = stringEncoder.contramap(_.toString) + implicit val bigIntEncoder: JdbcEncoder[java.math.BigInteger] = + bigDecimalEncoder.contramap(new java.math.BigDecimal(_)) + implicit val bigDecimalScalaEncoder: JdbcEncoder[scala.math.BigDecimal] = bigDecimalEncoder.contramap(_.bigDecimal) + implicit val byteChunkEncoder: JdbcEncoder[Chunk[Byte]] = byteArrayEncoder.contramap(_.toArray) + implicit val instantEncoder: JdbcEncoder[java.time.Instant] = sqlTimestampEncoder.contramap(java.sql.Timestamp.from) + + // TODO: review for cases like Option of a tuple + implicit def tuple2Encoder[A: JdbcEncoder, B: JdbcEncoder]: JdbcEncoder[(A, B)] = + tupleNEncoder(JdbcEncoder[A](), JdbcEncoder[B]()) + + implicit def tuple3Encoder[A: JdbcEncoder, B: JdbcEncoder, C: JdbcEncoder]: JdbcEncoder[(A, B, C)] = + tupleNEncoder(JdbcEncoder[A](), JdbcEncoder[B](), JdbcEncoder[C]()) + implicit def tuple4Encoder[A: JdbcEncoder, B: JdbcEncoder, C: JdbcEncoder, D: JdbcEncoder] + : JdbcEncoder[(A, B, C, D)] = + tupleNEncoder(JdbcEncoder[A](), JdbcEncoder[B](), JdbcEncoder[C](), JdbcEncoder[D]()) + + private def tupleNEncoder[A <: Product](encoders: JdbcEncoder[_]*): JdbcEncoder[A] = + new JdbcEncoder[A] { + override def encode(value: A): SqlFragment = + SqlFragment.intersperse( + SqlFragment.comma, + value.productIterator + .zip(encoders) + .map { case (value, encoder) => encoder.asInstanceOf[JdbcEncoder[Any]].encode(value) } + .toSeq + ) + + override def unsafeSetValue(ps: PreparedStatement, index: Int, value: A): Int = + value.productIterator + .zip(encoders) + .foldLeft(index) { case (i, (value, encoder)) => + encoder.asInstanceOf[JdbcEncoder[Any]].unsafeSetValue(ps, i, value) + } + + override def unsafeSetNull(ps: PreparedStatement, index: Int): Int = + encoders + .foldLeft(index) { case (i, encoder) => + encoder.asInstanceOf[JdbcEncoder[Any]].unsafeSetNull(ps, i) + } + + override def sql(a: A): String = Seq.fill(encoders.size)("?").mkString(", ") + + override def prettyValuePrinter(value: A): String = value.productIterator.mkString(", ") + } } trait JdbcEncoder0LowPriorityImplicits { self => @@ -649,7 +297,7 @@ trait JdbcEncoder0LowPriorityImplicits { self => case StandardType.BigIntegerType => JdbcEncoder.bigIntEncoder case StandardType.BinaryType => JdbcEncoder.byteChunkEncoder case StandardType.BigDecimalType => JdbcEncoder.bigDecimalEncoder - case StandardType.UUIDType => JdbcEncoder.uuidEncoder + case StandardType.UUIDType => JdbcEncoder.uuidParamEncoder // TODO: Standard Types which are missing are the date time types, not sure what would be the best way to handle them case _ => throw JdbcEncoderError(s"Unsupported type: $standardType", new IllegalArgumentException) } @@ -660,7 +308,7 @@ trait JdbcEncoder0LowPriorityImplicits { self => case Schema.Primitive(standardType, _) => primitiveCodec(standardType) case Schema.Optional(schema, _) => - JdbcEncoder.optionEncoder(self.fromSchema(schema)) + JdbcEncoder.optionParamEncoder(self.fromSchema(schema)) case Schema.Tuple2(left, right, _) => JdbcEncoder.tuple2Encoder(self.fromSchema(left), self.fromSchema(right)) // format: off @@ -694,10 +342,26 @@ trait JdbcEncoder0LowPriorityImplicits { self => throw JdbcEncoderError(s"Failed to encode schema ${schema}", new IllegalArgumentException) } - private[jdbc] def caseClassEncoder[A](fields: Chunk[Schema.Field[A, _]]): JdbcEncoder[A] = { (a: A) => - fields.map { f => - val encoder = self.fromSchema(f.schema.asInstanceOf[Schema[Any]]) - encoder.encode(f.get(a)) - }.reduce(_ ++ SqlFragment.comma ++ _) + private[jdbc] def caseClassEncoder[A](fields: Chunk[Schema.Field[A, _]]): JdbcEncoder[A] = new JdbcEncoder[A] { + override def encode(a: A): SqlFragment = + fields.map { f => + val encoder = self.fromSchema(f.schema.asInstanceOf[Schema[Any]]) + encoder.encode(f.get(a)) + }.reduce(_ ++ SqlFragment.comma ++ _) + + override def unsafeSetValue(ps: PreparedStatement, index: Int, value: A): Int = + fields.foldLeft(index) { case (i, f) => + val encoder = self.fromSchema(f.schema.asInstanceOf[Schema[Any]]) + encoder.unsafeSetValue(ps, i, f.get(value)) + } + + override def unsafeSetNull(ps: PreparedStatement, index: Int): Int = + fields.foldLeft(index) { case (i, f) => + val encoder = self.fromSchema(f.schema.asInstanceOf[Schema[Any]]) + encoder.unsafeSetNull(ps, i) + } + + override def sql(a: A): String = Seq.fill(fields.iterator.size)("?").mkString(",") } + } diff --git a/core/src/main/scala/zio/jdbc/SqlFragment.scala b/core/src/main/scala/zio/jdbc/SqlFragment.scala index 91346c6f..b646e002 100644 --- a/core/src/main/scala/zio/jdbc/SqlFragment.scala +++ b/core/src/main/scala/zio/jdbc/SqlFragment.scala @@ -18,7 +18,6 @@ package zio.jdbc import zio._ import zio.jdbc.SqlFragment.Segment -import java.sql.{ PreparedStatement, Types } import scala.language.implicitConversions /** @@ -84,6 +83,8 @@ sealed trait SqlFragment { self => def notIn[B](b: B, bs: B*)(implicit encoder: JdbcEncoder[B]): SqlFragment = notIn(b +: bs) + def and(): SqlFragment = self ++ SqlFragment.and + def notIn[B](bs: Iterable[B])(implicit encoder: JdbcEncoder[B]): SqlFragment = in0(SqlFragment.notIn, bs) @@ -112,27 +113,8 @@ sealed trait SqlFragment { self => foreachSegment { syntax => sql.append(syntax.value) } { param => - param.value match { - case iterable: Iterable[_] => - iterable.iterator.foreach { item => - paramsBuilder += item.toString - } - sql.append( - Seq.fill(iterable.iterator.size)("?").mkString(",") - ) - - case array: Array[_] => - array.foreach { item => - paramsBuilder += item.toString - } - sql.append( - Seq.fill(array.length)("?").mkString(",") - ) - - case _ => - sql.append("?") - paramsBuilder += param.value.toString - } + sql.append(param.setter.sql(param.value)) + paramsBuilder += param.setter.prettyValuePrinter(param.value) } val params = paramsBuilder.result() @@ -246,101 +228,17 @@ object SqlFragment { sealed trait Segment object Segment { - final case class Syntax(value: String) extends Segment - final case class Param(value: Any, setter: Setter[Any]) extends Segment - final case class Nested(sql: SqlFragment) extends Segment - - implicit def paramSegment[A](a: A)(implicit setter: Setter[A]): Segment.Param = - Segment.Param(a, setter.asInstanceOf[Setter[Any]]) - - implicit def nestedSqlSegment[A](sql: SqlFragment): Segment.Nested = Segment.Nested(sql) - } + final case class Syntax(value: String) extends Segment + final case class Param(value: Any, setter: JdbcEncoder[Any]) extends Segment + final case class Nested(sql: SqlFragment) extends Segment - trait Setter[A] { self => - def setValue(ps: PreparedStatement, index: Int, value: A): Unit - def setNull(ps: PreparedStatement, index: Int): Unit + implicit def paramSegment[A](a: A)(implicit encoder: JdbcEncoder[A]): Segment.Param = + Segment.Param(a, encoder.asInstanceOf[JdbcEncoder[Any]]) - final def contramap[B](f: B => A): Setter[B] = - Setter((ps, i, value) => self.setValue(ps, i, f(value)), (ps, i) => self.setNull(ps, i)) - } - - object Setter { - def apply[A]()(implicit setter: Setter[A]): Setter[A] = setter - - def apply[A](onValue: (PreparedStatement, Int, A) => Unit, onNull: (PreparedStatement, Int) => Unit): Setter[A] = - new Setter[A] { - def setValue(ps: PreparedStatement, index: Int, value: A): Unit = onValue(ps, index, value) - def setNull(ps: PreparedStatement, index: Int): Unit = onNull(ps, index) - } - - def forSqlType[A](onValue: (PreparedStatement, Int, A) => Unit, sqlType: Int): Setter[A] = new Setter[A] { - def setValue(ps: PreparedStatement, index: Int, value: A): Unit = onValue(ps, index, value) - def setNull(ps: PreparedStatement, index: Int): Unit = ps.setNull(index, sqlType) - } +// implicit def paramSegment[A](a: A)(implicit setter: JdbcEncoder[A]): Segment = +// Nested(setter.encode(a)) - def other[A](onValue: (PreparedStatement, Int, A) => Unit, sqlType: String): Setter[A] = new Setter[A] { - def setValue(ps: PreparedStatement, index: Int, value: A): Unit = onValue(ps, index, value) - def setNull(ps: PreparedStatement, index: Int): Unit = ps.setNull(index, Types.OTHER, sqlType) - } - - implicit def optionParamSetter[A](implicit setter: Setter[A]): Setter[Option[A]] = - Setter( - (ps, i, value) => - value match { - case Some(value) => setter.setValue(ps, i, value) - case None => setter.setNull(ps, i) - }, - (ps, i) => setter.setNull(ps, i) - ) - - implicit val intSetter: Setter[Int] = forSqlType((ps, i, value) => ps.setInt(i, value), Types.INTEGER) - implicit val longSetter: Setter[Long] = forSqlType((ps, i, value) => ps.setLong(i, value), Types.BIGINT) - implicit val doubleSetter: Setter[Double] = forSqlType((ps, i, value) => ps.setDouble(i, value), Types.DOUBLE) - implicit val stringSetter: Setter[String] = forSqlType((ps, i, value) => ps.setString(i, value), Types.VARCHAR) - implicit val booleanSetter: Setter[Boolean] = forSqlType((ps, i, value) => ps.setBoolean(i, value), Types.BOOLEAN) - implicit val shortSetter: Setter[Short] = forSqlType((ps, i, value) => ps.setShort(i, value), Types.SMALLINT) - implicit val floatSetter: Setter[Float] = forSqlType((ps, i, value) => ps.setFloat(i, value), Types.FLOAT) - implicit val byteSetter: Setter[Byte] = forSqlType((ps, i, value) => ps.setByte(i, value), Types.TINYINT) - implicit val byteArraySetter: Setter[Array[Byte]] = forSqlType((ps, i, value) => ps.setBytes(i, value), Types.ARRAY) - implicit val blobSetter: Setter[java.sql.Blob] = forSqlType((ps, i, value) => ps.setBlob(i, value), Types.BLOB) - implicit val sqlDateSetter: Setter[java.sql.Date] = forSqlType((ps, i, value) => ps.setDate(i, value), Types.DATE) - implicit val sqlTimeSetter: Setter[java.sql.Time] = forSqlType((ps, i, value) => ps.setTime(i, value), Types.TIME) - - implicit def chunkSetter[A](implicit setter: Setter[A]): Setter[Chunk[A]] = iterableSetter[A, Chunk[A]] - implicit def listSetter[A](implicit setter: Setter[A]): Setter[List[A]] = iterableSetter[A, List[A]] - implicit def vectorSetter[A](implicit setter: Setter[A]): Setter[Vector[A]] = iterableSetter[A, Vector[A]] - implicit def setSetter[A](implicit setter: Setter[A]): Setter[Set[A]] = iterableSetter[A, Set[A]] - - implicit def arraySetter[A](implicit setter: Setter[A]): Setter[Array[A]] = - forSqlType( - (ps, i, iterable) => - iterable.zipWithIndex.foreach { case (value, valueIdx) => - setter.setValue(ps, i + valueIdx, value) - }, - Types.OTHER - ) - - private def iterableSetter[A, I <: Iterable[A]](implicit setter: Setter[A]): Setter[I] = - forSqlType( - (ps, i, iterable) => - iterable.zipWithIndex.foreach { case (value, valueIdx) => - setter.setValue(ps, i + valueIdx, value) - }, - Types.OTHER - ) - - implicit val bigDecimalSetter: Setter[java.math.BigDecimal] = - forSqlType((ps, i, value) => ps.setBigDecimal(i, value), Types.NUMERIC) - implicit val sqlTimestampSetter: Setter[java.sql.Timestamp] = - forSqlType((ps, i, value) => ps.setTimestamp(i, value), Types.TIMESTAMP) - - implicit val uuidParamSetter: Setter[java.util.UUID] = other((ps, i, value) => ps.setObject(i, value), "uuid") - - implicit val charSetter: Setter[Char] = stringSetter.contramap(_.toString) - implicit val bigIntSetter: Setter[java.math.BigInteger] = bigDecimalSetter.contramap(new java.math.BigDecimal(_)) - implicit val bigDecimalScalaSetter: Setter[scala.math.BigDecimal] = bigDecimalSetter.contramap(_.bigDecimal) - implicit val byteChunkSetter: Setter[Chunk[Byte]] = byteArraySetter.contramap(_.toArray) - implicit val instantSetter: Setter[java.time.Instant] = sqlTimestampSetter.contramap(java.sql.Timestamp.from) + implicit def nestedSqlSegment[A](sql: SqlFragment): Segment.Nested = Segment.Nested(sql) } def apply(sql: String): SqlFragment = sql @@ -360,7 +258,7 @@ object SqlFragment { def update(table: String): SqlFragment = s"UPDATE $table" - private[jdbc] def intersperse( + def intersperse( sep: SqlFragment, elements: Iterable[SqlFragment] ): SqlFragment = { diff --git a/core/src/main/scala/zio/jdbc/ZConnection.scala b/core/src/main/scala/zio/jdbc/ZConnection.scala index 2114bf9d..2a674c81 100644 --- a/core/src/main/scala/zio/jdbc/ZConnection.scala +++ b/core/src/main/scala/zio/jdbc/ZConnection.scala @@ -45,19 +45,7 @@ final class ZConnection(private[jdbc] val underlying: ZConnection.Restorable) ex statement <- ZIO.acquireRelease(ZIO.attempt { val sb = new StringBuilder() sql.foreachSegment(syntax => sb.append(syntax.value)) { param => - param.value match { - case iterable: Iterable[_] => - sb.append( - Seq.fill(iterable.iterator.size)("?").mkString(", ") - ) - - case array: Array[_] => - sb.append( - Seq.fill(array.length)("?").mkString(", ") - ) - - case _ => sb.append("?") - } + sb.append(param.setter.sql(param.value)) } transactionIsolationLevel.foreach { transactionIsolationLevel => connection.setTransactionIsolation(transactionIsolationLevel.toInt) @@ -67,8 +55,7 @@ final class ZConnection(private[jdbc] val underlying: ZConnection.Restorable) ex _ <- ZIO.attempt { var paramIndex = 1 sql.foreachSegment(_ => ()) { param => - param.setter.setValue(statement, paramIndex, param.value) - paramIndex += 1 + paramIndex = param.setter.unsafeSetValue(statement, paramIndex, param.value) } } result <- f(statement) diff --git a/core/src/test/scala/zio/jdbc/SqlFragmentSpec.scala b/core/src/test/scala/zio/jdbc/SqlFragmentSpec.scala index a42a7271..2ee9d889 100644 --- a/core/src/test/scala/zio/jdbc/SqlFragmentSpec.scala +++ b/core/src/test/scala/zio/jdbc/SqlFragmentSpec.scala @@ -1,13 +1,13 @@ package zio.jdbc import zio.Chunk -import zio.jdbc.SqlFragment.Setter -import zio.jdbc.{ transaction => transact } -import zio.schema.{ Schema, TypeId } +import zio.jdbc.{transaction => transact} +import zio.schema.{Schema, TypeId} import zio.test.Assertion._ import zio.test._ import java.sql.SQLException +import java.util.Base64 final case class Person(name: String, age: Int) final case class UserLogin(username: String, password: String) @@ -47,14 +47,28 @@ object SqlFragmentSpec extends ZIOSpecDefault { } + test("type safe interpolation") { final case class Foo(value: String) - implicit val fooParamSetter: SqlFragment.Setter[Foo] = SqlFragment.Setter[String]().contramap(_.toString) + implicit val fooParamJdbcEncoder: JdbcEncoder[Foo] = JdbcEncoder[String]().contramap(_.toString) val testSql = sql"${Foo("test")}" assertTrue(testSql.segments.collect { case SqlFragment.Segment.Param(_, setter) => setter - }.head eq fooParamSetter) && + }.head eq fooParamJdbcEncoder) && assertTrue(testSql.toString == "Sql(?, Foo(test))") + } + test("Custom JdbcEncoder") { + + implicit val encoder: JdbcEncoder[Array[Byte]] = JdbcEncoder.single( + s"FROM_BASE64('?')", + value => Base64.getEncoder.encodeToString(value) + ) + + val bytes = Array[Byte](1, 2, 3) + + val result = sql"UPDATE foo SET bytes = $bytes" + + assertTrue( + result.toString == "Sql(UPDATE foo SET bytes = FROM_BASE64('?'), AQID)" + ) } + suite(" SqlFragment.ParamSetter instances") { // TODO figure out how to test at PrepareStatement level test("Option") { @@ -142,15 +156,24 @@ object SqlFragmentSpec extends ZIOSpecDefault { .toString == "Sql(select name, age from users where id IN (?,?,?), 1, 2, 3)" ) + } + test("fragment method where with multiple iterator params") { + val seq = Seq(1, 2, 3) + assertTrue( + sql"select name, age from users where id" + .in(seq) + .and() + .notIn(seq) + .toString == + "Sql(select name, age from users where id IN (?,?,?) AND NOT IN (?,?,?), 1, 2, 3, 1, 2, 3)" + ) } + test("interpolation param is supported collection") { - def assertIn[A: Setter](collection: A) = { + def assertIn[A: JdbcEncoder](collection: A) = { println(sql"select name, age from users where id in ($collection)".toString) assertTrue( sql"select name, age from users where id in ($collection)".toString == "Sql(select name, age from users where id in (?,?,?), 1, 2, 3)" ) } - assertIn(Chunk(1, 2, 3)) && assertIn(List(1, 2, 3)) && assertIn(Vector(1, 2, 3)) && diff --git a/core/src/test/scala/zio/jdbc/ZConnectionPoolSpec.scala b/core/src/test/scala/zio/jdbc/ZConnectionPoolSpec.scala index fe364ac6..6f82c661 100644 --- a/core/src/test/scala/zio/jdbc/ZConnectionPoolSpec.scala +++ b/core/src/test/scala/zio/jdbc/ZConnectionPoolSpec.scala @@ -1,7 +1,6 @@ package zio.jdbc -import zio._ -import zio.jdbc.SqlFragment.Setter +import zio.{ jdbc, _ } import zio.schema._ import zio.test.Assertion._ import zio.test.TestAspect._ @@ -127,11 +126,8 @@ object ZConnectionPoolSpec extends ZIOSpecDefault { implicit val jdbcDecoder: JdbcDecoder[User] = JdbcDecoder[(String, Int)]().map[User](t => User(t._1, t._2)) - implicit val jdbcEncoder: JdbcEncoder[User] = (value: User) => { - val name = value.name - val age = value.age - sql"""${name}""" ++ ", " ++ s"${age}" - } + implicit val jdbcEncoder: JdbcEncoder[User] = + JdbcEncoder[(String, Int)]().contramap[User](t => (t.name, t.age)) } final case class UserNoId(name: String, age: Int) @@ -140,11 +136,9 @@ object ZConnectionPoolSpec extends ZIOSpecDefault { implicit val jdbcDecoder: JdbcDecoder[UserNoId] = JdbcDecoder[(String, Int)]().map[UserNoId](t => UserNoId(t._1, t._2)) - implicit val jdbcEncoder: JdbcEncoder[UserNoId] = (value: UserNoId) => { - val name = value.name - val age = value.age - sql"""${name}""" ++ ", " ++ s"${age}" - } + implicit val jdbcEncoder: JdbcEncoder[UserNoId] = + JdbcEncoder[(String, Int)]().contramap[UserNoId](t => (t.name, t.age)) + } def spec: Spec[TestEnvironment, Any] = @@ -298,7 +292,7 @@ object ZConnectionPoolSpec extends ZIOSpecDefault { test("select all in") { val namesToSearch = Chunk(sherlockHolmes.name, johnDoe.name) - def assertUsersFound[A: Setter](collection: A) = + def assertUsersFound[A: JdbcEncoder](collection: A) = for { users <- transaction { sql"select name, age from users where name IN ($collection)".query[User].selectAll @@ -314,6 +308,32 @@ object ZConnectionPoolSpec extends ZIOSpecDefault { assertUsersFound(namesToSearch.toSet) && assertUsersFound(namesToSearch.toArray) + for { + _ <- createUsers *> insertSherlock *> insertWatson *> insertJohn + testResult <- asserttions + } yield testResult + } + test("select all in multiple lists") { + val namesToSearch = Chunk(sherlockHolmes.name, johnDoe.name, johnWatson.name) + val namesToAvoid = Chunk(johnWatson.name) + + def assertUsersFound[A: JdbcEncoder](collection: A, collection2: A) = + for { + users <- transaction { + sql"select name, age from users where name IN ($collection) and name NOT IN ($collection2)" + .query[User] + .selectAll + } + } yield assertTrue( + users.map(_.name) == Chunk(sherlockHolmes.name, johnDoe.name) + ) + + def asserttions = + assertUsersFound(namesToSearch, namesToAvoid) && + assertUsersFound(namesToSearch.toList, namesToAvoid.toList) && + assertUsersFound(namesToSearch.toVector, namesToAvoid.toVector) && + assertUsersFound(namesToSearch.toSet, namesToAvoid.toSet) && + assertUsersFound(namesToSearch.toArray, namesToAvoid.toArray) + for { _ <- createUsers *> insertSherlock *> insertWatson *> insertJohn testResult <- asserttions @@ -338,6 +358,18 @@ object ZConnectionPoolSpec extends ZIOSpecDefault { _ <- createUsers *> insertSherlock num <- transaction(sql"update users set age = 43 where name = ${sherlockHolmes.name}".update) } yield assertTrue(num == 1L) + } + test("select with custom encoder") { + implicit val encoder: JdbcEncoder[Array[Char]] = JdbcEncoder.single( + s"TRIM(?)", + value => new String(value) + ) + + val sherlockArray = " Sherlock Holmes ".toCharArray + + for { + _ <- createUsers *> insertSherlock + num <- transaction(sql"delete from users where name = ${sherlockArray}".delete) + } yield assertTrue(num == 1L) } } + suite("decoding") {