From 536221470174250810079dd7e4e0778eebb9530a Mon Sep 17 00:00:00 2001 From: augustnagro Date: Mon, 25 Nov 2024 06:12:55 -0800 Subject: [PATCH] Spec Refactor: * Spec.build method is removed. Specs are now fully declarative and the DbType is responsible for implementation. * Tests are improved and now shared between database specs instead of duplicated. The increased coverage exposed some tangential issues which are also resolved in this MR (see below). Note: switched from munit assertEquals -> assert until fix for https://github.com/scalameta/munit/issues/855#issuecomment-2506243568 is released. Additional changes: * No longer need to handle null when using DbCodec.biMap, or implementing DbCodec.readSingle * When implementing DbCodec, new method readSingleOption must be defined * Support optional products in outer-join queries (see test OptionalProductTests) * Switched to latest scalafmt version to prevent OOM in OracleTests --- .scalafmt.conf | 2 +- build.sbt | 2 +- .../com/augustnagro/magnum/pg/PgCodec.scala | 95 +++- .../magnum/pg/json/JsonBDbCodec.scala | 10 +- .../magnum/pg/json/JsonDbCodec.scala | 10 +- .../magnum/pg/xml/XmlDbCodec.scala | 10 +- magnum-pg/src/test/scala/PgCodecTests.scala | 19 +- .../augustnagro/magnum/ClickhouseDbType.scala | 13 +- .../com/augustnagro/magnum/DbCodec.scala | 205 ++++++- .../scala/com/augustnagro/magnum/Frag.scala | 39 +- .../com/augustnagro/magnum/FragWriter.scala | 3 + .../com/augustnagro/magnum/H2DbType.scala | 12 +- .../com/augustnagro/magnum/MySqlDbType.scala | 81 +-- .../com/augustnagro/magnum/NullOrder.scala | 14 +- .../com/augustnagro/magnum/OracleDbType.scala | 25 +- .../augustnagro/magnum/PostgresDbType.scala | 13 +- .../scala/com/augustnagro/magnum/Query.scala | 2 +- .../com/augustnagro/magnum/Returning.scala | 61 +- .../scala/com/augustnagro/magnum/Seek.scala | 10 + .../com/augustnagro/magnum/SeekDir.scala | 10 +- .../scala/com/augustnagro/magnum/Sort.scala | 8 +- .../com/augustnagro/magnum/SortOrder.scala | 11 +- 
.../scala/com/augustnagro/magnum/Spec.scala | 72 +-- .../com/augustnagro/magnum/SpecImpl.scala | 88 +++ .../com/augustnagro/magnum/SqliteDbType.scala | 23 +- .../com/augustnagro/magnum/UUIDCodec.scala | 6 +- .../scala/com/augustnagro/magnum/Update.scala | 2 +- .../scala/com/augustnagro/magnum/util.scala | 20 +- magnum/src/test/resources/clickhouse-car.sql | 38 -- .../src/test/resources/clickhouse-person.sql | 78 --- .../src/test/resources/clickhouse/big-dec.sql | 12 + magnum/src/test/resources/clickhouse/car.sql | 17 + .../src/test/resources/clickhouse/no-id.sql | 14 + .../src/test/resources/clickhouse/person.sql | 22 + magnum/src/test/resources/h2-car.sql | 14 - magnum/src/test/resources/h2-person.sql | 20 - magnum/src/test/resources/h2/big-dec.sql | 10 + magnum/src/test/resources/h2/car.sql | 15 + magnum/src/test/resources/h2/my-user.sql | 11 + magnum/src/test/resources/h2/no-id.sql | 12 + magnum/src/test/resources/h2/person.sql | 20 + magnum/src/test/resources/mysql-car.sql | 14 - magnum/src/test/resources/mysql-person.sql | 20 - .../{pg-bigdec.sql => mysql/big-dec.sql} | 0 magnum/src/test/resources/mysql/car.sql | 15 + magnum/src/test/resources/mysql/my-user.sql | 11 + magnum/src/test/resources/mysql/no-id.sql | 12 + magnum/src/test/resources/mysql/person.sql | 20 + magnum/src/test/resources/pg-car.sql | 17 - magnum/src/test/resources/pg-person.sql | 20 - magnum/src/test/resources/pg/big-dec.sql | 10 + magnum/src/test/resources/pg/car.sql | 15 + magnum/src/test/resources/pg/my-user.sql | 11 + .../resources/{pg-no-id.sql => pg/no-id.sql} | 0 magnum/src/test/resources/pg/person.sql | 20 + magnum/src/test/scala/ClickHouseTests.scala | 413 +------------- magnum/src/test/scala/H2Tests.scala | 403 +------------- magnum/src/test/scala/MySqlTests.scala | 429 +------------- magnum/src/test/scala/OracleTests.scala | 504 +++-------------- magnum/src/test/scala/PgTests.scala | 522 +----------------- magnum/src/test/scala/SpecTests.scala | 116 ---- 
magnum/src/test/scala/SqliteTests.scala | 499 +++-------------- .../src/test/scala/shared/BigDecTests.scala | 22 + magnum/src/test/scala/shared/Color.scala | 6 + .../test/scala/shared/EmbeddedFragTests.scala | 24 + .../scala/shared/EntityCreatorTests.scala | 114 ++++ .../scala/shared/ImmutableRepoTests.scala | 149 +++++ magnum/src/test/scala/shared/NoIdTests.scala | 28 + .../scala/shared/OptionalProductTests.scala | 35 ++ magnum/src/test/scala/shared/RepoTests.scala | 389 +++++++++++++ .../src/test/scala/shared/SharedTests.scala | 24 + magnum/src/test/scala/shared/SpecTests.scala | 158 ++++++ .../src/test/scala/shared/SqlNameTests.scala | 29 + 73 files changed, 2091 insertions(+), 3107 deletions(-) create mode 100644 magnum/src/main/scala/com/augustnagro/magnum/Seek.scala create mode 100644 magnum/src/main/scala/com/augustnagro/magnum/SpecImpl.scala delete mode 100644 magnum/src/test/resources/clickhouse-car.sql delete mode 100644 magnum/src/test/resources/clickhouse-person.sql create mode 100644 magnum/src/test/resources/clickhouse/big-dec.sql create mode 100644 magnum/src/test/resources/clickhouse/car.sql create mode 100644 magnum/src/test/resources/clickhouse/no-id.sql create mode 100644 magnum/src/test/resources/clickhouse/person.sql delete mode 100644 magnum/src/test/resources/h2-car.sql delete mode 100644 magnum/src/test/resources/h2-person.sql create mode 100644 magnum/src/test/resources/h2/big-dec.sql create mode 100644 magnum/src/test/resources/h2/car.sql create mode 100644 magnum/src/test/resources/h2/my-user.sql create mode 100644 magnum/src/test/resources/h2/no-id.sql create mode 100644 magnum/src/test/resources/h2/person.sql delete mode 100644 magnum/src/test/resources/mysql-car.sql delete mode 100644 magnum/src/test/resources/mysql-person.sql rename magnum/src/test/resources/{pg-bigdec.sql => mysql/big-dec.sql} (100%) create mode 100644 magnum/src/test/resources/mysql/car.sql create mode 100644 magnum/src/test/resources/mysql/my-user.sql create mode 
100644 magnum/src/test/resources/mysql/no-id.sql create mode 100644 magnum/src/test/resources/mysql/person.sql delete mode 100644 magnum/src/test/resources/pg-car.sql delete mode 100644 magnum/src/test/resources/pg-person.sql create mode 100644 magnum/src/test/resources/pg/big-dec.sql create mode 100644 magnum/src/test/resources/pg/car.sql create mode 100644 magnum/src/test/resources/pg/my-user.sql rename magnum/src/test/resources/{pg-no-id.sql => pg/no-id.sql} (100%) create mode 100644 magnum/src/test/resources/pg/person.sql delete mode 100644 magnum/src/test/scala/SpecTests.scala create mode 100644 magnum/src/test/scala/shared/BigDecTests.scala create mode 100644 magnum/src/test/scala/shared/Color.scala create mode 100644 magnum/src/test/scala/shared/EmbeddedFragTests.scala create mode 100644 magnum/src/test/scala/shared/EntityCreatorTests.scala create mode 100644 magnum/src/test/scala/shared/ImmutableRepoTests.scala create mode 100644 magnum/src/test/scala/shared/NoIdTests.scala create mode 100644 magnum/src/test/scala/shared/OptionalProductTests.scala create mode 100644 magnum/src/test/scala/shared/RepoTests.scala create mode 100644 magnum/src/test/scala/shared/SharedTests.scala create mode 100644 magnum/src/test/scala/shared/SpecTests.scala create mode 100644 magnum/src/test/scala/shared/SqlNameTests.scala diff --git a/.scalafmt.conf b/.scalafmt.conf index 6680f8e..1a44fdb 100644 --- a/.scalafmt.conf +++ b/.scalafmt.conf @@ -1,4 +1,4 @@ -version = 3.8.3 +version = 3.8.4-RC3 runner.dialect = scala3 rewrite.scala3.insertEndMarkerMinLines = 20 rewrite.scala3.removeEndMarkerMaxLines = 19 diff --git a/build.sbt b/build.sbt index da7bc8f..937f1fb 100644 --- a/build.sbt +++ b/build.sbt @@ -36,7 +36,7 @@ ThisBuild / publishTo := { } ThisBuild / publish / skip := true -Global / onChangedBuildSource := ReloadOnSourceChanges +addCommandAlias("fmt", "scalafmtAll") val testcontainersVersion = "0.41.4" val circeVersion = "0.14.10" diff --git 
a/magnum-pg/src/main/scala/com/augustnagro/magnum/pg/PgCodec.scala b/magnum-pg/src/main/scala/com/augustnagro/magnum/pg/PgCodec.scala index 8f70401..6e9ccab 100644 --- a/magnum-pg/src/main/scala/com/augustnagro/magnum/pg/PgCodec.scala +++ b/magnum-pg/src/main/scala/com/augustnagro/magnum/pg/PgCodec.scala @@ -15,7 +15,7 @@ import org.postgresql.util.PGInterval import java.sql import java.sql.{JDBCType, PreparedStatement, ResultSet, Types} import scala.reflect.ClassTag -import scala.collection.mutable as m +import scala.collection.{mutable as m} import scala.compiletime.* object PgCodec: @@ -64,6 +64,15 @@ object PgCodec: val arr = aArrayCodec.readArray(jdbcArray.getArray) IArray.unsafeFromArray(arr) finally jdbcArray.free() + def readSingleOption(resultSet: ResultSet, pos: Int): Option[IArray[A]] = + val jdbcArray = resultSet.getArray(pos) + if resultSet.wasNull then None + else + try + val arr = aArrayCodec.readArray(jdbcArray.getArray) + Some(IArray.unsafeFromArray(arr)) + finally jdbcArray.free() + def writeSingle(entity: IArray[A], ps: PreparedStatement, pos: Int): Unit = ps.setObject(pos, entity) @@ -81,6 +90,14 @@ object PgCodec: val arr = aArrayCodec.readArray(jdbcArray.getArray) IArray.unsafeFromArray(arr) finally jdbcArray.free() + def readSingleOption(resultSet: ResultSet, pos: Int): Option[IArray[A]] = + val jdbcArray = resultSet.getArray(pos) + if resultSet.wasNull then None + else + try + val arr = aArrayCodec.readArray(jdbcArray.getArray) + Some(IArray.unsafeFromArray(arr)) + finally jdbcArray.free() def writeSingle(entity: IArray[A], ps: PreparedStatement, pos: Int): Unit = val arr = entity.iterator.map(aArrayCodec.toArrayObj).toArray val jdbcArr = @@ -99,6 +116,12 @@ object PgCodec: val jdbcArray = resultSet.getArray(pos) try aArrayCodec.readArray(jdbcArray.getArray) finally jdbcArray.free() + def readSingleOption(resultSet: ResultSet, pos: Int): Option[Array[A]] = + val jdbcArray = resultSet.getArray(pos) + if resultSet.wasNull then None + else + 
try Some(aArrayCodec.readArray(jdbcArray.getArray)) + finally jdbcArray.free() def writeSingle(entity: Array[A], ps: PreparedStatement, pos: Int): Unit = ps.setObject(pos, entity) @@ -114,6 +137,12 @@ object PgCodec: val jdbcArray = resultSet.getArray(pos) try aArrayCodec.readArray(jdbcArray.getArray) finally jdbcArray.free() + def readSingleOption(resultSet: ResultSet, pos: Int): Option[Array[A]] = + val jdbcArray = resultSet.getArray(pos) + if resultSet.wasNull then None + else + try Some(aArrayCodec.readArray(jdbcArray.getArray)) + finally jdbcArray.free() def writeSingle(entity: Array[A], ps: PreparedStatement, pos: Int): Unit = val arr = entity.iterator.map(aArrayCodec.toArrayObj).toArray val jdbcArr = @@ -133,6 +162,14 @@ object PgCodec: val arr = aArrayCodec.readArray(jdbcArray.getArray) List.from(arr) finally jdbcArray.free() + def readSingleOption(resultSet: ResultSet, pos: Int): Option[Seq[A]] = + val jdbcArray = resultSet.getArray(pos) + if resultSet.wasNull then None + else + try + val arr = aArrayCodec.readArray(jdbcArray.getArray) + Some(List.from(arr)) + finally jdbcArray.free() def writeSingle(entity: Seq[A], ps: PreparedStatement, pos: Int): Unit = val arr = entity.iterator.map(aArrayCodec.toArrayObj).toArray val jdbcArr = @@ -152,6 +189,14 @@ object PgCodec: val arr = aArrayCodec.readArray(jdbcArray.getArray) List.from(arr) finally jdbcArray.free() + def readSingleOption(resultSet: ResultSet, pos: Int): Option[List[A]] = + val jdbcArray = resultSet.getArray(pos) + if resultSet.wasNull then None + else + try + val arr = aArrayCodec.readArray(jdbcArray.getArray) + Some(List.from(arr)) + finally jdbcArray.free() def writeSingle(entity: List[A], ps: PreparedStatement, pos: Int): Unit = val arr = entity.iterator.map(aArrayCodec.toArrayObj).toArray val jdbcArr = @@ -171,6 +216,14 @@ object PgCodec: val arr = aArrayCodec.readArray(jdbcArray.getArray) Vector.from(arr) finally jdbcArray.free() + def readSingleOption(resultSet: ResultSet, pos: Int): 
Option[Vector[A]] = + val jdbcArray = resultSet.getArray(pos) + if resultSet.wasNull then None + else + try + val arr = aArrayCodec.readArray(jdbcArray.getArray) + Some(Vector.from(arr)) + finally jdbcArray.free() def writeSingle(entity: Vector[A], ps: PreparedStatement, pos: Int): Unit = val arr = entity.iterator.map(aArrayCodec.toArrayObj).toArray val jdbcArr = @@ -191,6 +244,14 @@ object PgCodec: val arr = aArrayCodec.readArray(jdbcArray.getArray) m.Buffer.from(arr) finally jdbcArray.free() + def readSingleOption(resultSet: ResultSet, pos: Int): Option[m.Buffer[A]] = + val jdbcArray = resultSet.getArray(pos) + if resultSet.wasNull then None + else + try + val arr = aArrayCodec.readArray(jdbcArray.getArray) + Some(m.Buffer.from(arr)) + finally jdbcArray.free() def writeSingle( entity: m.Buffer[A], ps: PreparedStatement, @@ -205,6 +266,10 @@ object PgCodec: val cols: IArray[Int] = IArray(Types.JAVA_OBJECT) def readSingle(resultSet: ResultSet, pos: Int): PGbox = resultSet.getObject(pos, classOf[PGbox]) + def readSingleOption(resultSet: ResultSet, pos: Int): Option[PGbox] = + val res = resultSet.getObject(pos, classOf[PGbox]) + if resultSet.wasNull then None + else Some(res) def writeSingle(entity: PGbox, ps: PreparedStatement, pos: Int): Unit = ps.setObject(pos, entity) @@ -213,6 +278,10 @@ object PgCodec: val cols: IArray[Int] = IArray(Types.JAVA_OBJECT) def readSingle(resultSet: ResultSet, pos: Int): PGcircle = resultSet.getObject(pos, classOf[PGcircle]) + def readSingleOption(resultSet: ResultSet, pos: Int): Option[PGcircle] = + val res = resultSet.getObject(pos, classOf[PGcircle]) + if resultSet.wasNull then None + else Some(res) def writeSingle(entity: PGcircle, ps: PreparedStatement, pos: Int): Unit = ps.setObject(pos, entity) @@ -221,6 +290,10 @@ object PgCodec: val cols: IArray[Int] = IArray(Types.JAVA_OBJECT) def readSingle(resultSet: ResultSet, pos: Int): PGInterval = resultSet.getObject(pos, classOf[PGInterval]) + def readSingleOption(resultSet: 
ResultSet, pos: Int): Option[PGInterval] = + val res = resultSet.getObject(pos, classOf[PGInterval]) + if resultSet.wasNull then None + else Some(res) def writeSingle(entity: PGInterval, ps: PreparedStatement, pos: Int): Unit = ps.setObject(pos, entity) @@ -229,6 +302,10 @@ object PgCodec: val cols: IArray[Int] = IArray(Types.JAVA_OBJECT) def readSingle(resultSet: ResultSet, pos: Int): PGline = resultSet.getObject(pos, classOf[PGline]) + def readSingleOption(resultSet: ResultSet, pos: Int): Option[PGline] = + val res = resultSet.getObject(pos, classOf[PGline]) + if resultSet.wasNull then None + else Some(res) def writeSingle(entity: PGline, ps: PreparedStatement, pos: Int): Unit = ps.setObject(pos, entity) @@ -237,6 +314,10 @@ object PgCodec: val cols: IArray[Int] = IArray(Types.JAVA_OBJECT) def readSingle(resultSet: ResultSet, pos: Int): PGlseg = resultSet.getObject(pos, classOf[PGlseg]) + def readSingleOption(resultSet: ResultSet, pos: Int): Option[PGlseg] = + val res = resultSet.getObject(pos, classOf[PGlseg]) + if resultSet.wasNull then None + else Some(res) def writeSingle(entity: PGlseg, ps: PreparedStatement, pos: Int): Unit = ps.setObject(pos, entity) @@ -245,6 +326,10 @@ object PgCodec: val cols: IArray[Int] = IArray(Types.JAVA_OBJECT) def readSingle(resultSet: ResultSet, pos: Int): PGpath = resultSet.getObject(pos, classOf[PGpath]) + def readSingleOption(resultSet: ResultSet, pos: Int): Option[PGpath] = + val res = resultSet.getObject(pos, classOf[PGpath]) + if resultSet.wasNull then None + else Some(res) def writeSingle(entity: PGpath, ps: PreparedStatement, pos: Int): Unit = ps.setObject(pos, entity) @@ -253,6 +338,10 @@ object PgCodec: val cols: IArray[Int] = IArray(Types.JAVA_OBJECT) def readSingle(resultSet: ResultSet, pos: Int): PGpoint = resultSet.getObject(pos, classOf[PGpoint]) + def readSingleOption(resultSet: ResultSet, pos: Int): Option[PGpoint] = + val res = resultSet.getObject(pos, classOf[PGpoint]) + if resultSet.wasNull then None + else 
Some(res) def writeSingle(entity: PGpoint, ps: PreparedStatement, pos: Int): Unit = ps.setObject(pos, entity) @@ -261,6 +350,10 @@ object PgCodec: val cols: IArray[Int] = IArray(Types.JAVA_OBJECT) def readSingle(resultSet: ResultSet, pos: Int): PGpolygon = resultSet.getObject(pos, classOf[PGpolygon]) + def readSingleOption(resultSet: ResultSet, pos: Int): Option[PGpolygon] = + val res = resultSet.getObject(pos, classOf[PGpolygon]) + if resultSet.wasNull then None + else Some(res) def writeSingle(entity: PGpolygon, ps: PreparedStatement, pos: Int): Unit = ps.setObject(pos, entity) end PgCodec diff --git a/magnum-pg/src/main/scala/com/augustnagro/magnum/pg/json/JsonBDbCodec.scala b/magnum-pg/src/main/scala/com/augustnagro/magnum/pg/json/JsonBDbCodec.scala index 615336f..dd0faf5 100644 --- a/magnum-pg/src/main/scala/com/augustnagro/magnum/pg/json/JsonBDbCodec.scala +++ b/magnum-pg/src/main/scala/com/augustnagro/magnum/pg/json/JsonBDbCodec.scala @@ -16,15 +16,17 @@ trait JsonBDbCodec[A] extends DbCodec[A]: override val cols: IArray[Int] = IArray(Types.OTHER) override def readSingle(resultSet: ResultSet, pos: Int): A = + decode(resultSet.getString(pos)) + + override def readSingleOption(resultSet: ResultSet, pos: Int): Option[A] = val rawJson = resultSet.getString(pos) - if rawJson eq null then null.asInstanceOf[A] - else decode(rawJson) + if rawJson == null then None + else Some(decode(rawJson)) override def writeSingle(entity: A, ps: PreparedStatement, pos: Int): Unit = val jsonObject = PGobject() jsonObject.setType("jsonb") - val encoded = if entity == null then null else encode(entity) - jsonObject.setValue(encoded) + jsonObject.setValue(encode(entity)) ps.setObject(pos, jsonObject) end JsonBDbCodec diff --git a/magnum-pg/src/main/scala/com/augustnagro/magnum/pg/json/JsonDbCodec.scala b/magnum-pg/src/main/scala/com/augustnagro/magnum/pg/json/JsonDbCodec.scala index 634a2a4..96c2b81 100644 --- 
a/magnum-pg/src/main/scala/com/augustnagro/magnum/pg/json/JsonDbCodec.scala +++ b/magnum-pg/src/main/scala/com/augustnagro/magnum/pg/json/JsonDbCodec.scala @@ -16,15 +16,17 @@ trait JsonDbCodec[A] extends DbCodec[A]: override val cols: IArray[Int] = IArray(Types.OTHER) override def readSingle(resultSet: ResultSet, pos: Int): A = + decode(resultSet.getString(pos)) + + override def readSingleOption(resultSet: ResultSet, pos: Int): Option[A] = val rawJson = resultSet.getString(pos) - if rawJson eq null then null.asInstanceOf[A] - else decode(rawJson) + if rawJson == null then None + else Some(decode(rawJson)) override def writeSingle(entity: A, ps: PreparedStatement, pos: Int): Unit = val jsonObject = PGobject() jsonObject.setType("json") - val encoded = if entity == null then null else encode(entity) - jsonObject.setValue(encoded) + jsonObject.setValue(encode(entity)) ps.setObject(pos, jsonObject) end JsonDbCodec diff --git a/magnum-pg/src/main/scala/com/augustnagro/magnum/pg/xml/XmlDbCodec.scala b/magnum-pg/src/main/scala/com/augustnagro/magnum/pg/xml/XmlDbCodec.scala index 00940f9..49be270 100644 --- a/magnum-pg/src/main/scala/com/augustnagro/magnum/pg/xml/XmlDbCodec.scala +++ b/magnum-pg/src/main/scala/com/augustnagro/magnum/pg/xml/XmlDbCodec.scala @@ -16,15 +16,17 @@ trait XmlDbCodec[A] extends DbCodec[A]: override val cols: IArray[Int] = IArray(Types.SQLXML) override def readSingle(resultSet: ResultSet, pos: Int): A = + decode(resultSet.getString(pos)) + + override def readSingleOption(resultSet: ResultSet, pos: Int): Option[A] = val xmlString = resultSet.getString(pos) - if xmlString == null then null.asInstanceOf[A] - else decode(xmlString) + if xmlString == null then None + else Some(decode(xmlString)) override def writeSingle(entity: A, ps: PreparedStatement, pos: Int): Unit = val xmlObject = PGobject() xmlObject.setType("xml") - val encoded = if entity == null then null else encode(entity) - xmlObject.setValue(encoded) + xmlObject.setValue(encode(entity)) 
ps.setObject(pos, xmlObject) end XmlDbCodec diff --git a/magnum-pg/src/test/scala/PgCodecTests.scala b/magnum-pg/src/test/scala/PgCodecTests.scala index b16128d..4a67354 100644 --- a/magnum-pg/src/test/scala/PgCodecTests.scala +++ b/magnum-pg/src/test/scala/PgCodecTests.scala @@ -96,11 +96,11 @@ class PgCodecTests extends FunSuite, TestContainersFixtures: test("select all MagUser"): connect(ds()): - assertEquals(userRepo.findAll, allUsers) + assert(userRepo.findAll == allUsers) test("select all MagCar"): connect(ds()): - assertEquals(carRepo.findAll, allCars) + assert(carRepo.findAll == allCars) test("insert MagUser"): connect(ds()): @@ -124,7 +124,7 @@ class PgCodecTests extends FunSuite, TestContainersFixtures: ) userRepo.insert(u) val dbU = userRepo.findById(3L).get - assertEquals(dbU, u) + assert(dbU == u) test("insert MagCar"): connect(ds()): @@ -141,7 +141,7 @@ class PgCodecTests extends FunSuite, TestContainersFixtures: ) carRepo.insert(c) val dbC = carRepo.findById(3L).get - assertEquals(dbC, c) + assert(dbC == c) test("update MagUser arrays"): connect(ds()): @@ -158,7 +158,7 @@ class PgCodecTests extends FunSuite, TestContainersFixtures: sql"UPDATE mag_car SET text_color_map = $newTextColorMap WHERE id = 2".update .run() val newCar = carRepo.findById(2L).get - assertEquals(newCar.textColorMap, newTextColorMap) + assert(newCar.textColorMap == newTextColorMap) test("MagCar xml string values"): connect(ds()): @@ -170,7 +170,14 @@ class PgCodecTests extends FunSuite, TestContainersFixtures: .map(_.elem.toString) val expected = allCars.flatMap(_.myXml).map(_.elem.toString) println(found) - assertEquals(found, expected) + assert(found == expected) + + test("where = ANY()"): + connect(ds()): + val ids = Vector(1L, 2L) + val cars = + sql"SELECT * FROM mag_car WHERE id = ANY($ids)".query[MagCar].run() + assert(cars == allCars) val pgContainer = ForAllContainerFixture( PostgreSQLContainer diff --git 
a/magnum/src/main/scala/com/augustnagro/magnum/ClickhouseDbType.scala b/magnum/src/main/scala/com/augustnagro/magnum/ClickhouseDbType.scala index c65e42b..f8b965a 100644 --- a/magnum/src/main/scala/com/augustnagro/magnum/ClickhouseDbType.scala +++ b/magnum/src/main/scala/com/augustnagro/magnum/ClickhouseDbType.scala @@ -2,6 +2,7 @@ package com.augustnagro.magnum import java.sql.{Connection, PreparedStatement, ResultSet, Statement} import java.time.OffsetDateTime +import java.util.StringJoiner import scala.collection.View import scala.deriving.Mirror import scala.reflect.ClassTag @@ -33,17 +34,18 @@ object ClickhouseDbType extends DbType: val ecInsertKeys = ecElemNamesSql.mkString("(", ", ", ")") val countSql = s"SELECT count(*) FROM $tableNameSql" - val countQuery = Frag(countSql).query[Long] + val countQuery = Frag(countSql, Vector.empty, FragWriter.empty).query[Long] val existsByIdSql = s"SELECT 1 FROM $tableNameSql WHERE $idName = ${idCodec.queryRepr}" val findAllSql = s"SELECT $selectKeys FROM $tableNameSql" - val findAllQuery = Frag(findAllSql).query[E] + val findAllQuery = Frag(findAllSql, Vector.empty, FragWriter.empty).query[E] val findByIdSql = s"SELECT $selectKeys FROM $tableNameSql WHERE $idName = ${idCodec.queryRepr}" val deleteByIdSql = s"DELETE FROM $tableNameSql WHERE $idName = ${idCodec.queryRepr}" val truncateSql = s"TRUNCATE TABLE $tableNameSql" - val truncateUpdate = Frag(truncateSql).update + val truncateUpdate = + Frag(truncateSql, Vector.empty, FragWriter.empty).update val insertSql = s"INSERT INTO $tableNameSql $ecInsertKeys VALUES (${ecCodec.queryRepr})" @@ -63,10 +65,7 @@ object ClickhouseDbType extends DbType: def findAll(using DbCon): Vector[E] = findAllQuery.run() def findAll(spec: Spec[E])(using DbCon): Vector[E] = - val f = spec.build - Frag(s"SELECT * FROM $tableNameSql ${f.sqlString}", f.params, f.writer) - .query[E] - .run() + SpecImpl.Default.findAll(spec, tableNameSql) def findById(id: ID)(using DbCon): Option[E] = 
Frag(findByIdSql, IArray(id), idWriter(id)) diff --git a/magnum/src/main/scala/com/augustnagro/magnum/DbCodec.scala b/magnum/src/main/scala/com/augustnagro/magnum/DbCodec.scala index 41248c8..d069a48 100644 --- a/magnum/src/main/scala/com/augustnagro/magnum/DbCodec.scala +++ b/magnum/src/main/scala/com/augustnagro/magnum/DbCodec.scala @@ -50,6 +50,12 @@ trait DbCodec[E]: */ def readSingle(resultSet: ResultSet): E = readSingle(resultSet, 1) + /** Read an Option[E] from the ResultSet starting at position `pos` and ending + * after reading `cols` number of columns. Make sure the ResultSet is in a + * valid state (ie, ResultSet::next has been called). + */ + def readSingleOption(resultSet: ResultSet, pos: Int): Option[E] + /** Build every row in the ResultSet into a sequence of E. The ResultSet * should be in its initial position before calling (ie, ResultSet::next not * called). @@ -79,6 +85,8 @@ trait DbCodec[E]: val cols: IArray[Int] = self.cols def readSingle(rs: ResultSet, pos: Int): E2 = to(self.readSingle(rs, pos)) + def readSingleOption(rs: ResultSet, pos: Int): Option[E2] = + self.readSingleOption(rs, pos).map(to) def writeSingle(e: E2, ps: PreparedStatement, pos: Int): Unit = self.writeSingle(from(e), ps, pos) def queryRepr: String = self.queryRepr @@ -91,6 +99,8 @@ object DbCodec: given AnyCodec: DbCodec[Any] with val cols: IArray[Int] = IArray(Types.JAVA_OBJECT) def readSingle(rs: ResultSet, pos: Int): Any = rs.getObject(pos) + def readSingleOption(rs: ResultSet, pos: Int): Option[Any] = + Option(rs.getObject(pos)) def writeSingle(a: Any, ps: PreparedStatement, pos: Int): Unit = ps.setObject(pos, a) def queryRepr: String = "?" 
@@ -98,6 +108,8 @@ object DbCodec: given StringCodec: DbCodec[String] with val cols: IArray[Int] = IArray(Types.VARCHAR) def readSingle(rs: ResultSet, pos: Int): String = rs.getString(pos) + def readSingleOption(rs: ResultSet, pos: Int): Option[String] = + Option(rs.getString(pos)) def writeSingle(s: String, ps: PreparedStatement, pos: Int): Unit = ps.setString(pos, s) def queryRepr: String = "?" @@ -105,6 +117,10 @@ object DbCodec: given BooleanCodec: DbCodec[Boolean] with val cols: IArray[Int] = IArray(Types.BOOLEAN) def readSingle(rs: ResultSet, pos: Int): Boolean = rs.getBoolean(pos) + def readSingleOption(rs: ResultSet, pos: Int): Option[Boolean] = + val res = rs.getBoolean(pos) + if rs.wasNull then None + else Some(res) def writeSingle(b: Boolean, ps: PreparedStatement, pos: Int): Unit = ps.setBoolean(pos, b) def queryRepr: String = "?" @@ -112,6 +128,10 @@ object DbCodec: given ByteCodec: DbCodec[Byte] with val cols: IArray[Int] = IArray(Types.TINYINT) def readSingle(rs: ResultSet, pos: Int): Byte = rs.getByte(pos) + def readSingleOption(rs: ResultSet, pos: Int): Option[Byte] = + val res = rs.getByte(pos) + if rs.wasNull then None + else Some(res) def writeSingle(b: Byte, ps: PreparedStatement, pos: Int): Unit = ps.setByte(pos, b) def queryRepr: String = "?" @@ -119,6 +139,10 @@ object DbCodec: given ShortCodec: DbCodec[Short] with val cols: IArray[Int] = IArray(Types.SMALLINT) def readSingle(rs: ResultSet, pos: Int): Short = rs.getShort(pos) + def readSingleOption(rs: ResultSet, pos: Int): Option[Short] = + val res = rs.getShort(pos) + if rs.wasNull then None + else Some(res) def writeSingle(s: Short, ps: PreparedStatement, pos: Int): Unit = ps.setShort(pos, s) def queryRepr: String = "?" 
@@ -126,6 +150,10 @@ object DbCodec: given IntCodec: DbCodec[Int] with val cols: IArray[Int] = IArray(Types.INTEGER) def readSingle(rs: ResultSet, pos: Int): Int = rs.getInt(pos) + def readSingleOption(rs: ResultSet, pos: Int): Option[Int] = + val res = rs.getInt(pos) + if rs.wasNull then None + else Some(res) def writeSingle(i: Int, ps: PreparedStatement, pos: Int): Unit = ps.setInt(pos, i) def queryRepr: String = "?" @@ -133,6 +161,10 @@ object DbCodec: given LongCodec: DbCodec[Long] with val cols: IArray[Int] = IArray(Types.BIGINT) def readSingle(rs: ResultSet, pos: Int): Long = rs.getLong(pos) + def readSingleOption(rs: ResultSet, pos: Int): Option[Long] = + val res = rs.getLong(pos) + if rs.wasNull then None + else Some(res) def writeSingle(l: Long, ps: PreparedStatement, pos: Int): Unit = ps.setLong(pos, l) def queryRepr: String = "?" @@ -140,6 +172,10 @@ object DbCodec: given FloatCodec: DbCodec[Float] with val cols: IArray[Int] = IArray(Types.REAL) def readSingle(rs: ResultSet, pos: Int): Float = rs.getFloat(pos) + def readSingleOption(rs: ResultSet, pos: Int): Option[Float] = + val res = rs.getFloat(pos) + if rs.wasNull then None + else Some(res) def writeSingle(f: Float, ps: PreparedStatement, pos: Int): Unit = ps.setFloat(pos, f) def queryRepr: String = "?" @@ -147,6 +183,10 @@ object DbCodec: given DoubleCodec: DbCodec[Double] with val cols: IArray[Int] = IArray(Types.DOUBLE) def readSingle(rs: ResultSet, pos: Int): Double = rs.getDouble(pos) + def readSingleOption(rs: ResultSet, pos: Int): Option[Double] = + val res = rs.getDouble(pos) + if rs.wasNull then None + else Some(res) def writeSingle(d: Double, ps: PreparedStatement, pos: Int): Unit = ps.setDouble(pos, d) def queryRepr: String = "?" 
@@ -154,6 +194,8 @@ object DbCodec: given ByteArrayCodec: DbCodec[Array[Byte]] with val cols: IArray[Int] = IArray(Types.BINARY) def readSingle(rs: ResultSet, pos: Int): Array[Byte] = rs.getBytes(pos) + def readSingleOption(rs: ResultSet, pos: Int): Option[Array[Byte]] = + Option(rs.getBytes(pos)) def writeSingle(bytes: Array[Byte], ps: PreparedStatement, pos: Int): Unit = ps.setBytes(pos, bytes) def queryRepr: String = "?" @@ -162,6 +204,8 @@ object DbCodec: val cols: IArray[Int] = IArray(Types.BINARY) def readSingle(rs: ResultSet, pos: Int): IArray[Byte] = IArray.unsafeFromArray(rs.getBytes(pos)) + def readSingleOption(rs: ResultSet, pos: Int): Option[IArray[Byte]] = + ByteArrayCodec.readSingleOption(rs, pos).map(IArray.unsafeFromArray) def writeSingle( bytes: IArray[Byte], ps: PreparedStatement, @@ -173,6 +217,8 @@ object DbCodec: given SqlDateCodec: DbCodec[java.sql.Date] with val cols: IArray[Int] = IArray(Types.DATE) def readSingle(rs: ResultSet, pos: Int): java.sql.Date = rs.getDate(pos) + def readSingleOption(rs: ResultSet, pos: Int): Option[java.sql.Date] = + Option(rs.getDate(pos)) def writeSingle( date: java.sql.Date, ps: PreparedStatement, @@ -184,6 +230,8 @@ object DbCodec: val cols: IArray[Int] = IArray(Types.TIME) def readSingle(rs: ResultSet, pos: Int): java.sql.Time = rs.getTime(pos) + def readSingleOption(rs: ResultSet, pos: Int): Option[java.sql.Time] = + Option(rs.getTime(pos)) def writeSingle( time: java.sql.Time, ps: PreparedStatement, @@ -195,6 +243,8 @@ object DbCodec: val cols: IArray[Int] = IArray(Types.TIMESTAMP) def readSingle(rs: ResultSet, pos: Int): java.sql.Timestamp = rs.getTimestamp(pos) + def readSingleOption(rs: ResultSet, pos: Int): Option[java.sql.Timestamp] = + Option(rs.getTimestamp(pos)) def writeSingle( t: java.sql.Timestamp, ps: PreparedStatement, @@ -206,6 +256,10 @@ object DbCodec: val cols: IArray[Int] = IArray(Types.TIMESTAMP_WITH_TIMEZONE) def readSingle(rs: ResultSet, pos: Int): OffsetDateTime = rs.getObject(pos, 
classOf[OffsetDateTime]) + def readSingleOption(rs: ResultSet, pos: Int): Option[OffsetDateTime] = + val res = rs.getObject(pos, classOf[OffsetDateTime]) + if rs.wasNull then None + else Some(res) def writeSingle(dt: OffsetDateTime, ps: PreparedStatement, pos: Int): Unit = ps.setObject(pos, dt) def queryRepr: String = "?" @@ -213,6 +267,10 @@ object DbCodec: given SqlRefCodec: DbCodec[java.sql.Ref] with val cols: IArray[Int] = IArray(Types.REF) def readSingle(rs: ResultSet, pos: Int): java.sql.Ref = rs.getRef(pos) + def readSingleOption(rs: ResultSet, pos: Int): Option[java.sql.Ref] = + val res = rs.getRef(pos) + if rs.wasNull then None + else Some(res) def writeSingle(ref: java.sql.Ref, ps: PreparedStatement, pos: Int): Unit = ps.setRef(pos, ref) def queryRepr: String = "?" @@ -220,6 +278,10 @@ object DbCodec: given SqlBlobCodec: DbCodec[java.sql.Blob] with val cols: IArray[Int] = IArray(Types.BLOB) def readSingle(rs: ResultSet, pos: Int): java.sql.Blob = rs.getBlob(pos) + def readSingleOption(rs: ResultSet, pos: Int): Option[java.sql.Blob] = + val res = rs.getBlob(pos) + if rs.wasNull then None + else Some(res) def writeSingle(b: java.sql.Blob, ps: PreparedStatement, pos: Int): Unit = ps.setBlob(pos, b) def queryRepr: String = "?" @@ -227,6 +289,10 @@ object DbCodec: given SqlClobCodec: DbCodec[java.sql.Clob] with val cols: IArray[Int] = IArray(Types.CLOB) def readSingle(rs: ResultSet, pos: Int): java.sql.Clob = rs.getClob(pos) + def readSingleOption(rs: ResultSet, pos: Int): Option[java.sql.Clob] = + val res = rs.getClob(pos) + if rs.wasNull then None + else Some(res) def writeSingle(c: java.sql.Clob, ps: PreparedStatement, pos: Int): Unit = ps.setClob(pos, c) def queryRepr: String = "?" 
@@ -234,6 +300,8 @@ object DbCodec: given URLCodec: DbCodec[URL] with val cols: IArray[Int] = IArray(Types.VARCHAR) def readSingle(rs: ResultSet, pos: Int): URL = rs.getURL(pos) + def readSingleOption(rs: ResultSet, pos: Int): Option[URL] = + Option(rs.getURL(pos)) def writeSingle(url: URL, ps: PreparedStatement, pos: Int): Unit = ps.setURL(pos, url) def queryRepr: String = "?" @@ -241,6 +309,8 @@ object DbCodec: given RowIdCodec: DbCodec[java.sql.RowId] with val cols: IArray[Int] = IArray(Types.ROWID) def readSingle(rs: ResultSet, pos: Int): java.sql.RowId = rs.getRowId(pos) + def readSingleOption(rs: ResultSet, pos: Int): Option[java.sql.RowId] = + Option(rs.getRowId(pos)) def writeSingle( rowId: java.sql.RowId, ps: PreparedStatement, @@ -252,6 +322,10 @@ object DbCodec: given SqlNClobCodec: DbCodec[java.sql.NClob] with val cols: IArray[Int] = IArray(Types.NCLOB) def readSingle(rs: ResultSet, pos: Int): java.sql.NClob = rs.getNClob(pos) + def readSingleOption(rs: ResultSet, pos: Int): Option[java.sql.NClob] = + val res = rs.getNClob(pos) + if rs.wasNull then None + else Some(res) def writeSingle(nc: java.sql.NClob, ps: PreparedStatement, pos: Int): Unit = ps.setNClob(pos, nc) def queryRepr: String = "?" @@ -259,6 +333,10 @@ object DbCodec: given SqlXmlCodec: DbCodec[java.sql.SQLXML] with val cols: IArray[Int] = IArray(Types.SQLXML) def readSingle(rs: ResultSet, pos: Int): java.sql.SQLXML = rs.getSQLXML(pos) + def readSingleOption(rs: ResultSet, pos: Int): Option[java.sql.SQLXML] = + val res = rs.getSQLXML(pos) + if rs.wasNull then None + else Some(res) def writeSingle(s: java.sql.SQLXML, ps: PreparedStatement, pos: Int): Unit = ps.setSQLXML(pos, s) def queryRepr: String = "?" 
@@ -267,6 +345,11 @@ object DbCodec: val cols: IArray[Int] = IArray(Types.NUMERIC) def readSingle(rs: ResultSet, pos: Int): java.math.BigDecimal = rs.getBigDecimal(pos) + def readSingleOption( + rs: ResultSet, + pos: Int + ): Option[java.math.BigDecimal] = + Option(rs.getBigDecimal(pos)) def writeSingle( bd: java.math.BigDecimal, ps: PreparedStatement, @@ -278,9 +361,11 @@ object DbCodec: given ScalaBigDecimalCodec: DbCodec[scala.math.BigDecimal] with val cols: IArray[Int] = IArray(Types.NUMERIC) def readSingle(rs: ResultSet, pos: Int): scala.math.BigDecimal = - rs.getBigDecimal(pos) match - case null => null - case x => scala.math.BigDecimal(x) + scala.math.BigDecimal(rs.getBigDecimal(pos)) + def readSingleOption(rs: ResultSet, pos: Int): Option[BigDecimal] = + JavaBigDecimalCodec + .readSingleOption(rs, pos) + .map(scala.math.BigDecimal.apply) def writeSingle( bd: scala.math.BigDecimal, ps: PreparedStatement, @@ -294,15 +379,19 @@ object DbCodec: val cols: IArray[Int] = IArray(Types.OTHER) def readSingle(rs: ResultSet, pos: Int): UUID = rs.getObject(pos, classOf[UUID]) + def readSingleOption(rs: ResultSet, pos: Int): Option[UUID] = + val res = rs.getObject(pos, classOf[UUID]) + if rs.wasNull then None + else Some(res) def writeSingle(entity: UUID, ps: PreparedStatement, pos: Int): Unit = ps.setObject(pos, entity) given OptionCodec[A](using codec: DbCodec[A]): DbCodec[Option[A]] with def cols: IArray[Int] = codec.cols def readSingle(rs: ResultSet, pos: Int): Option[A] = - val a = codec.readSingle(rs, pos) - if rs.wasNull then None - else Some(a) + codec.readSingleOption(rs, pos) + def readSingleOption(rs: ResultSet, pos: Int): Option[Option[A]] = + Some(codec.readSingleOption(rs, pos)) def writeSingle(opt: Option[A], ps: PreparedStatement, pos: Int): Unit = opt match case Some(a) => @@ -320,6 +409,12 @@ object DbCodec: aCodec.readSingle(rs, pos), bCodec.readSingle(rs, pos + aCodec.cols.length) ) + def readSingleOption(rs: ResultSet, pos: Int): Option[(A, B)] = + 
val a = aCodec.readSingleOption(rs, pos) + val b = bCodec.readSingleOption(rs, pos + aCodec.cols.length) + (a, b) match + case (Some(a), Some(b)) => Some((a, b)) + case _ => None def writeSingle(tup: (A, B), ps: PreparedStatement, pos: Int): Unit = aCodec.writeSingle(tup._1, ps, pos) bCodec.writeSingle(tup._2, ps, pos + aCodec.cols.length) @@ -340,6 +435,16 @@ object DbCodec: i += bCodec.cols.length val c = cCodec.readSingle(rs, i) (a, b, c) + def readSingleOption(rs: ResultSet, pos: Int): Option[(A, B, C)] = + var i = pos + val a = aCodec.readSingleOption(rs, i) + i += aCodec.cols.length + val b = bCodec.readSingleOption(rs, i) + i += bCodec.cols.length + val c = cCodec.readSingleOption(rs, i) + (a, b, c) match + case (Some(a), Some(b), Some(c)) => Some((a, b, c)) + case _ => None def writeSingle(tup: (A, B, C), ps: PreparedStatement, pos: Int): Unit = var i = pos aCodec.writeSingle(tup._1, ps, i) @@ -369,6 +474,18 @@ object DbCodec: i += cCodec.cols.length val d = dCodec.readSingle(rs, i) (a, b, c, d) + def readSingleOption(rs: ResultSet, pos: Int): Option[(A, B, C, D)] = + var i = pos + val a = aCodec.readSingleOption(rs, i) + i += aCodec.cols.length + val b = bCodec.readSingleOption(rs, i) + i += bCodec.cols.length + val c = cCodec.readSingleOption(rs, i) + i += cCodec.cols.length + val d = dCodec.readSingleOption(rs, i) + (a, b, c, d) match + case (Some(a), Some(b), Some(c), Some(d)) => Some((a, b, c, d)) + case _ => None def writeSingle(tup: (A, B, C, D), ps: PreparedStatement, pos: Int): Unit = var i = pos aCodec.writeSingle(tup._1, ps, i) @@ -406,6 +523,10 @@ object DbCodec: ${ productReadSingle[E, mets]('{ rs }, mp, Vector.empty, '{ pos }) } + def readSingleOption(rs: ResultSet, pos: Int): Option[E] = + ${ + productReadOption[E, mets]('{ rs }, mp, Vector.empty, '{ pos }) + } def writeSingle(e: E, ps: PreparedStatement, pos: Int): Unit = ${ productWriteSingle[E, mets]('{ e }, '{ ps }, '{ pos }, '{ 0 }) @@ -434,6 +555,15 @@ object DbCodec: throw 
IllegalArgumentException( str + " not convertible to " + $melExpr ) + def readSingleOption(rs: ResultSet, pos: Int): Option[E] = + Option(rs.getString(pos)).map(str => + nameMap.find((name, _) => name == str) match + case Some((_, v)) => v + case None => + throw IllegalArgumentException( + str + " not convertible to " + $melExpr + ) + ) def writeSingle(entity: E, ps: PreparedStatement, pos: Int): Unit = nameMap.find((_, v) => v == entity) match case Some((k, _)) => ps.setString(pos, k) @@ -541,6 +671,69 @@ object DbCodec: end match end productReadSingle + private def productReadOption[E: Type, Mets: Type]( + rs: Expr[ResultSet], + m: Expr[Mirror.ProductOf[E]], + res: Vector[Expr[Any]], + pos: Expr[Int] + )(using Quotes): Expr[Option[E]] = + import quotes.reflect.* + Type.of[Mets] match + case '[met *: metTail] => + Expr.summon[DbCodec[met]] match + case Some(codecExpr) => + '{ + val posValue = $pos + val codec = $codecExpr + codec.readSingleOption($rs, posValue) match + case Some(metValue) => + val newPos = posValue + codec.cols.length + ${ + productReadOption[E, metTail]( + rs, + m, + res :+ '{ metValue }, + '{ newPos } + ) + } + case None => None + } + case None => + Expr.summon[ClassTag[met]] match + case Some(clsTagExpr) => + report.info( + s"Could not find DbCodec for ${TypeRepr.of[met].show}. 
Defaulting to ResultSet::[get|set]Object" ) + '{ + val posValue = $pos + val metValue = $rs.getObject( + posValue, + $clsTagExpr.runtimeClass.asInstanceOf[Class[met]] + ) + if $rs.wasNull then None + else + val newPos = posValue + 1 + ${ + productReadOption[E, metTail]( + rs, + m, + res :+ '{ metValue }, + '{ newPos } + ) + } + } + case None => + report.errorAndAbort( + s"Could not find DbCodec or ClassTag for ${TypeRepr.of[met].show}" + ) + case '[EmptyTuple] => + '{ + val product = ${ Expr.ofTupleFromSeq(res) } + Some($m.fromProduct(product)) + } + end match + end productReadOption + private def productWriteSingle[E: Type, Mets: Type]( e: Expr[E], ps: Expr[PreparedStatement], diff --git a/magnum/src/main/scala/com/augustnagro/magnum/Frag.scala b/magnum/src/main/scala/com/augustnagro/magnum/Frag.scala index 4396aec..6793acc 100644 --- a/magnum/src/main/scala/com/augustnagro/magnum/Frag.scala +++ b/magnum/src/main/scala/com/augustnagro/magnum/Frag.scala @@ -6,15 +6,40 @@ import scala.collection.immutable.ArraySeq import scala.util.{Failure, Success, Using} /** Sql fragment */ -case class Frag( - sqlString: String, - params: Seq[Any] = Seq.empty, - writer: FragWriter = Frag.emptyWriter +class Frag( + val sqlString: String, + val params: Seq[Any], + val writer: FragWriter ): def query[E](using reader: DbCodec[E]): Query[E] = Query(this, reader) + def update: Update = Update(this) + + /** For databases like Postgres that support RETURNING statements via + * `getResultSet` + */ def returning[E](using reader: DbCodec[E]): Returning[E] = - Returning(this, reader) + Returning(this, reader, Vector.empty) + + /** For databases that support RETURNING statements via `getGeneratedKeys` + */ + def returningKeys[E](colName: String, xs: String*)(using + reader: DbCodec[E] + ): Returning[E] = + Returning(this, reader, colName +: xs) + + /** For databases that support RETURNING statements via `getGeneratedKeys` + */ + def returningKeys[E](colName: ColumnName, xs: 
ColumnName*)(using + reader: DbCodec[E] + ): Returning[E] = + Returning(this, reader, (colName +: xs).map(_.queryRepr)) + + /** For databases that support RETURNING statements via `getGeneratedKeys` + */ + def returningKeys[E](colNames: ColumnNames)(using + reader: DbCodec[E] + ): Returning[E] = + Returning(this, reader, colNames.columnNames.map(_.queryRepr)) -object Frag: - private val emptyWriter: FragWriter = (_, _) => 0 +end Frag diff --git a/magnum/src/main/scala/com/augustnagro/magnum/FragWriter.scala b/magnum/src/main/scala/com/augustnagro/magnum/FragWriter.scala index fc72ba7..0ad31e4 100644 --- a/magnum/src/main/scala/com/augustnagro/magnum/FragWriter.scala +++ b/magnum/src/main/scala/com/augustnagro/magnum/FragWriter.scala @@ -7,3 +7,6 @@ trait FragWriter: * position. */ def write(ps: PreparedStatement, pos: Int): Int + +object FragWriter: + val empty: FragWriter = (_, pos) => pos diff --git a/magnum/src/main/scala/com/augustnagro/magnum/H2DbType.scala b/magnum/src/main/scala/com/augustnagro/magnum/H2DbType.scala index c659c0f..d5495c9 100644 --- a/magnum/src/main/scala/com/augustnagro/magnum/H2DbType.scala +++ b/magnum/src/main/scala/com/augustnagro/magnum/H2DbType.scala @@ -43,18 +43,19 @@ object H2DbType extends DbType: val insertGenKeys: Array[String] = Array.from(eElemNamesSql) val countSql = s"SELECT count(*) FROM $tableNameSql" - val countQuery = Frag(countSql).query[Long] + val countQuery = Frag(countSql, Vector.empty, FragWriter.empty).query[Long] val existsByIdSql = s"SELECT 1 FROM $tableNameSql WHERE $idName = ${idCodec.queryRepr}" val findAllSql = s"SELECT * FROM $tableNameSql" - val findAllQuery = Frag(findAllSql).query[E] + val findAllQuery = Frag(findAllSql, Vector.empty, FragWriter.empty).query[E] val findByIdSql = s"SELECT * FROM $tableNameSql WHERE $idName = ${idCodec.queryRepr}" val findAllByIdSql = s"SELECT * FROM $tableNameSql WHERE $idName = ANY(?)" val deleteByIdSql = s"DELETE FROM $tableNameSql WHERE $idName = ${idCodec.queryRepr}" 
val truncateSql = s"TRUNCATE TABLE $tableNameSql" - val truncateUpdate = Frag(truncateSql).update + val truncateUpdate = + Frag(truncateSql, Vector.empty, FragWriter.empty).update val insertSql = s"INSERT INTO $tableNameSql $ecInsertKeys VALUES (${ecCodec.queryRepr})" val updateSql = @@ -79,10 +80,7 @@ object H2DbType extends DbType: def findAll(using DbCon): Vector[E] = findAllQuery.run() def findAll(spec: Spec[E])(using DbCon): Vector[E] = - val f = spec.build - Frag(s"SELECT * FROM $tableNameSql ${f.sqlString}", f.params, f.writer) - .query[E] - .run() + SpecImpl.Default.findAll(spec, tableNameSql) def findById(id: ID)(using DbCon): Option[E] = Frag(findByIdSql, IArray(id), idWriter(id)) diff --git a/magnum/src/main/scala/com/augustnagro/magnum/MySqlDbType.scala b/magnum/src/main/scala/com/augustnagro/magnum/MySqlDbType.scala index ef1668e..55c2448 100644 --- a/magnum/src/main/scala/com/augustnagro/magnum/MySqlDbType.scala +++ b/magnum/src/main/scala/com/augustnagro/magnum/MySqlDbType.scala @@ -9,6 +9,31 @@ import scala.util.{Failure, Success, Using} object MySqlDbType extends DbType: + private val specImpl = new SpecImpl: + override def sortSql(sort: Sort): String = + val column = sort.column + val nullSort = sort.nullOrder match + case NullOrder.Default => "" + case NullOrder.First => s"$column IS NOT NULL, " + case NullOrder.Last => s"$column IS NULL, " + case _ => throw UnsupportedOperationException() + val dir = sort.direction match + case SortOrder.Default => "" + case SortOrder.Asc => " ASC" + case SortOrder.Desc => " DESC" + case _ => throw UnsupportedOperationException() + nullSort + column + dir + + override def offsetLimitSql( + offset: Option[Long], + limit: Option[Int] + ): Option[String] = + (offset, limit) match + case (Some(o), Some(l)) => Some(s"LIMIT $o, $l") + case (Some(o), None) => Some(s"LIMIT $o, ${Long.MaxValue}") + case (None, Some(l)) => Some(s"LIMIT $l") + case (None, None) => None + def buildRepoDefaults[EC, E, ID]( tableNameSql: 
String, eElemNames: Seq[String], @@ -43,17 +68,18 @@ object MySqlDbType extends DbType: .asInstanceOf[Seq[DbCodec[Any]]] val countSql = s"SELECT count(*) FROM $tableNameSql" - val countQuery = Frag(countSql).query[Long] + val countQuery = Frag(countSql, Vector.empty, FragWriter.empty).query[Long] val existsByIdSql = s"SELECT 1 FROM $tableNameSql WHERE $idName = ${idCodec.queryRepr}" val findAllSql = s"SELECT * FROM $tableNameSql" - val findAllQuery = Frag(findAllSql).query[E] + val findAllQuery = Frag(findAllSql, Vector.empty, FragWriter.empty).query[E] val findByIdSql = s"SELECT * FROM $tableNameSql WHERE $idName = ${idCodec.queryRepr}" val deleteByIdSql = s"DELETE FROM $tableNameSql WHERE $idName = ${idCodec.queryRepr}" val truncateSql = s"TRUNCATE TABLE $tableNameSql" - val truncateUpdate = Frag(truncateSql).update + val truncateUpdate = + Frag(truncateSql, Vector.empty, FragWriter.empty).update val insertSql = s"INSERT INTO $tableNameSql $ecInsertKeys VALUES (${ecCodec.queryRepr})" val updateSql = @@ -76,10 +102,21 @@ object MySqlDbType extends DbType: def findAll(using DbCon): Vector[E] = findAllQuery.run() def findAll(spec: Spec[E])(using DbCon): Vector[E] = - val f = spec.build - Frag(s"SELECT * FROM $tableNameSql ${f.sqlString}", f.params, f.writer) - .query[E] - .run() + specImpl.findAll(spec, tableNameSql) + + private def sortSql(sort: Sort): String = + val column = sort.column + val nullSort = sort.nullOrder match + case NullOrder.Default => "" + case NullOrder.First => s"$column IS NOT NULL, " + case NullOrder.Last => s"$column IS NULL, " + case _ => throw UnsupportedOperationException() + val dir = sort.direction match + case SortOrder.Default => "" + case SortOrder.Asc => " ASC" + case SortOrder.Desc => " DESC" + case _ => throw UnsupportedOperationException() + nullSort + column + dir def findById(id: ID)(using DbCon): Option[E] = Frag(findByIdSql, IArray(id), idWriter(id)) @@ -134,35 +171,15 @@ object MySqlDbType extends DbType: 
timed(batchUpdateResult(ps.executeBatch())) def insertReturning(entityCreator: EC)(using con: DbCon): E = - handleQuery(insertAndFindByIdSql, entityCreator): - Using.Manager: use => - val ps = - use(con.connection.prepareStatement(insertSql, insertGenKeys)) - ecCodec.writeSingle(entityCreator, ps) - timed: - ps.executeUpdate() - val rs = use(ps.getGeneratedKeys) - rs.next() - val id = idCodec.readSingle(rs) - // unfortunately, mysql only will return auto_incremented keys. - // it doesn't return default columns, and adding other columns to - // the insertGenKeys array doesn't change this behavior. So we need - // to query by ID after every insert. - findById(id).get + // unfortunately, mysql only will return auto_incremented keys. + // it doesn't return default columns, and adding other columns to + // the insertGenKeys array doesn't change this behavior. + throw UnsupportedOperationException() def insertAllReturning( entityCreators: Iterable[EC] )(using con: DbCon): Vector[E] = - handleQuery(insertAndFindByIdSql, entityCreators): - Using.Manager: use => - val ps = - use(con.connection.prepareStatement(insertSql, insertGenKeys)) - ecCodec.write(entityCreators, ps) - timed: - batchUpdateResult(ps.executeBatch()) - val rs = use(ps.getGeneratedKeys) - val ids = idCodec.read(rs) - ids.map(findById(_).get) + throw UnsupportedOperationException() def update(entity: E)(using con: DbCon): Unit = handleQuery(updateSql, entity): diff --git a/magnum/src/main/scala/com/augustnagro/magnum/NullOrder.scala b/magnum/src/main/scala/com/augustnagro/magnum/NullOrder.scala index 302aff3..e3f16e6 100644 --- a/magnum/src/main/scala/com/augustnagro/magnum/NullOrder.scala +++ b/magnum/src/main/scala/com/augustnagro/magnum/NullOrder.scala @@ -1,14 +1,8 @@ package com.augustnagro.magnum -trait NullOrder: - def sql: String +trait NullOrder object NullOrder: - object First extends NullOrder: - def sql: String = "NULLS FIRST" - - object Last extends NullOrder: - def sql: String = "NULLS LAST" - 
- object Empty extends NullOrder: - def sql: String = "" + case object Default extends NullOrder + case object First extends NullOrder + case object Last extends NullOrder diff --git a/magnum/src/main/scala/com/augustnagro/magnum/OracleDbType.scala b/magnum/src/main/scala/com/augustnagro/magnum/OracleDbType.scala index 5858c10..f09d387 100644 --- a/magnum/src/main/scala/com/augustnagro/magnum/OracleDbType.scala +++ b/magnum/src/main/scala/com/augustnagro/magnum/OracleDbType.scala @@ -8,6 +8,19 @@ import scala.reflect.ClassTag import scala.util.{Failure, Success, Using} object OracleDbType extends DbType: + + private val specImpl = new SpecImpl: + override def offsetLimitSql( + offset: Option[Long], + limit: Option[Int] + ): Option[String] = + (offset, limit) match + case (Some(o), Some(l)) => + Some(s"OFFSET $o ROWS FETCH NEXT $l ROWS ONLY") + case (Some(o), None) => Some(s"OFFSET $o ROWS") + case (None, Some(l)) => Some(s"FETCH NEXT $l ROWS ONLY") + case (None, None) => None + def buildRepoDefaults[EC, E, ID]( tableNameSql: String, eElemNames: Seq[String], @@ -42,17 +55,18 @@ object OracleDbType extends DbType: val insertGenKeys = Array.from(eElemNamesSql) val countSql = s"SELECT count(*) FROM $tableNameSql" - val countQuery = Frag(countSql).query[Long] + val countQuery = Frag(countSql, Vector.empty, FragWriter.empty).query[Long] val existsByIdSql = s"SELECT 1 FROM $tableNameSql WHERE $idName = ${idCodec.queryRepr}" val findAllSql = s"SELECT * FROM $tableNameSql" - val findAllQuery = Frag(findAllSql).query[E] + val findAllQuery = Frag(findAllSql, Vector.empty, FragWriter.empty).query[E] val findByIdSql = s"SELECT * FROM $tableNameSql WHERE $idName = ${idCodec.queryRepr}" val deleteByIdSql = s"DELETE FROM $tableNameSql WHERE $idName = ${idCodec.queryRepr}" val truncateSql = s"TRUNCATE TABLE $tableNameSql" - val truncateUpdate = Frag(truncateSql).update + val truncateUpdate = + Frag(truncateSql, Vector.empty, FragWriter.empty).update val insertSql = s"INSERT INTO 
$tableNameSql $ecInsertKeys VALUES (${ecCodec.queryRepr})" val updateSql = @@ -74,10 +88,7 @@ object OracleDbType extends DbType: def findAll(using DbCon): Vector[E] = findAllQuery.run() def findAll(spec: Spec[E])(using DbCon): Vector[E] = - val f = spec.build - Frag(s"SELECT * FROM $tableNameSql ${f.sqlString}", f.params, f.writer) - .query[E] - .run() + specImpl.findAll(spec, tableNameSql) def findById(id: ID)(using DbCon): Option[E] = Frag(findByIdSql, IArray(id), idWriter(id)) diff --git a/magnum/src/main/scala/com/augustnagro/magnum/PostgresDbType.scala b/magnum/src/main/scala/com/augustnagro/magnum/PostgresDbType.scala index c8df857..697e01e 100644 --- a/magnum/src/main/scala/com/augustnagro/magnum/PostgresDbType.scala +++ b/magnum/src/main/scala/com/augustnagro/magnum/PostgresDbType.scala @@ -6,6 +6,7 @@ import scala.collection.View import scala.deriving.Mirror import scala.reflect.ClassTag import scala.util.{Failure, Success, Using} +import java.util.StringJoiner object PostgresDbType extends DbType: @@ -41,11 +42,11 @@ object PostgresDbType extends DbType: .asInstanceOf[Seq[DbCodec[Any]]] val countSql = s"SELECT count(*) FROM $tableNameSql" - val countQuery = Frag(countSql).query[Long] + val countQuery = Frag(countSql, Vector.empty, FragWriter.empty).query[Long] val existsByIdSql = s"SELECT 1 FROM $tableNameSql WHERE $idName = ${idCodec.queryRepr}" val findAllSql = s"SELECT $selectKeys FROM $tableNameSql" - val findAllQuery = Frag(findAllSql).query[E] + val findAllQuery = Frag(findAllSql, Vector.empty, FragWriter.empty).query[E] val findByIdSql = s"SELECT $selectKeys FROM $tableNameSql WHERE $idName = ${idCodec.queryRepr}" val findAllByIdSql = @@ -53,7 +54,8 @@ object PostgresDbType extends DbType: val deleteByIdSql = s"DELETE FROM $tableNameSql WHERE $idName = ${idCodec.queryRepr}" val truncateSql = s"TRUNCATE TABLE $tableNameSql" - val truncateUpdate = Frag(truncateSql).update + val truncateUpdate = + Frag(truncateSql, Vector.empty, 
FragWriter.empty).update val insertSql = s"INSERT INTO $tableNameSql $ecInsertKeys VALUES (${ecCodec.queryRepr})" val updateSql = @@ -78,10 +80,7 @@ object PostgresDbType extends DbType: def findAll(using DbCon): Vector[E] = findAllQuery.run() def findAll(spec: Spec[E])(using DbCon): Vector[E] = - val f = spec.build - Frag(s"SELECT * FROM $tableNameSql ${f.sqlString}", f.params, f.writer) - .query[E] - .run() + SpecImpl.Default.findAll(spec, tableNameSql) def findById(id: ID)(using DbCon): Option[E] = Frag(findByIdSql, IArray(id), idWriter(id)) diff --git a/magnum/src/main/scala/com/augustnagro/magnum/Query.scala b/magnum/src/main/scala/com/augustnagro/magnum/Query.scala index 351f07d..8d07e9f 100644 --- a/magnum/src/main/scala/com/augustnagro/magnum/Query.scala +++ b/magnum/src/main/scala/com/augustnagro/magnum/Query.scala @@ -6,7 +6,7 @@ import scala.util.Using.Manager import scala.util.control.NonFatal import scala.util.{Failure, Success, Try, Using} -case class Query[E](frag: Frag, reader: DbCodec[E]): +class Query[E] private[magnum] (val frag: Frag, reader: DbCodec[E]): def run()(using con: DbCon): Vector[E] = handleQuery(frag.sqlString, frag.params): diff --git a/magnum/src/main/scala/com/augustnagro/magnum/Returning.scala b/magnum/src/main/scala/com/augustnagro/magnum/Returning.scala index 5ce24ec..6477961 100644 --- a/magnum/src/main/scala/com/augustnagro/magnum/Returning.scala +++ b/magnum/src/main/scala/com/augustnagro/magnum/Returning.scala @@ -3,22 +3,17 @@ package com.augustnagro.magnum import scala.util.{Failure, Success, Try, Using} import Using.Manager import java.sql.Statement +import java.sql.ResultSet + +class Returning[E] private[magnum] ( + val frag: Frag, + reader: DbCodec[E], + keyColumns: Iterable[String] +): + private val keyColumsArr = keyColumns.toArray -case class Returning[E](frag: Frag, reader: DbCodec[E]): def run()(using con: DbCon): Vector[E] = - handleQuery(frag.sqlString, frag.params): - Manager: use => - val ps = 
use(con.connection.prepareStatement(frag.sqlString)) - frag.writer.write(ps, 1) - timed: - val hasResults = ps.execute() - if hasResults then - val rs = use(ps.getResultSet()) - reader.read(rs) - else - throw UnsupportedOperationException( - "No results for RETURNING clause" - ) + withResultSet(reader.read) /** Streaming [[Iterator]]. Set [[fetchSize]] to give the JDBC driver a hint * as to how many rows to fetch per request @@ -26,19 +21,31 @@ case class Returning[E](frag: Frag, reader: DbCodec[E]): def iterator( fetchSize: Int = 0 )(using con: DbCon, use: Manager): Iterator[E] = + withResultSet(ResultSetIterator(_, frag, reader, con.sqlLogger)) + + private def withResultSet[A](f: ResultSet => A)(using con: DbCon): A = handleQuery(frag.sqlString, frag.params): - Try: - val ps = use(con.connection.prepareStatement(frag.sqlString)) - ps.setFetchSize(fetchSize) - frag.writer.write(ps, 1) - timed: - val hasResults = ps.execute() - if hasResults then - val rs = use(ps.getResultSet()) - ResultSetIterator(rs, frag, reader, con.sqlLogger) - else - throw UnsupportedOperationException( - "No results for RETURNING clause" - ) + Manager: use => + if keyColumns.isEmpty then + val ps = use(con.connection.prepareStatement(frag.sqlString)) + frag.writer.write(ps, 1) + timed: + val hasResults = ps.execute() + if hasResults then + val rs = use(ps.getResultSet) + f(rs) + else + throw UnsupportedOperationException( + "No results for RETURNING clause" + ) + else + val ps = use( + con.connection.prepareStatement(frag.sqlString, keyColumsArr) + ) + frag.writer.write(ps, 1) + timed: + ps.execute() + val rs = use(ps.getGeneratedKeys) + f(rs) end Returning diff --git a/magnum/src/main/scala/com/augustnagro/magnum/Seek.scala b/magnum/src/main/scala/com/augustnagro/magnum/Seek.scala new file mode 100644 index 0000000..983a69f --- /dev/null +++ b/magnum/src/main/scala/com/augustnagro/magnum/Seek.scala @@ -0,0 +1,10 @@ +package com.augustnagro.magnum + +class Seek( + val column: String, + val 
seekDirection: SeekDir, + val value: Any, + val columnSort: SortOrder, + val nullOrder: NullOrder, + val codec: DbCodec[?] +) diff --git a/magnum/src/main/scala/com/augustnagro/magnum/SeekDir.scala b/magnum/src/main/scala/com/augustnagro/magnum/SeekDir.scala index d5993ba..a94507d 100644 --- a/magnum/src/main/scala/com/augustnagro/magnum/SeekDir.scala +++ b/magnum/src/main/scala/com/augustnagro/magnum/SeekDir.scala @@ -1,11 +1,7 @@ package com.augustnagro.magnum -trait SeekDir: - def sql: String +trait SeekDir object SeekDir: - object Gt extends SeekDir: - def sql: String = ">" - - object Lt extends SeekDir: - def sql: String = "<" + case object Gt extends SeekDir + case object Lt extends SeekDir diff --git a/magnum/src/main/scala/com/augustnagro/magnum/Sort.scala b/magnum/src/main/scala/com/augustnagro/magnum/Sort.scala index 17fd2a1..3523a3e 100644 --- a/magnum/src/main/scala/com/augustnagro/magnum/Sort.scala +++ b/magnum/src/main/scala/com/augustnagro/magnum/Sort.scala @@ -1,7 +1,7 @@ package com.augustnagro.magnum -private case class Sort( - column: String, - direction: SortOrder, - nullOrder: NullOrder +class Sort( + val column: String, + val direction: SortOrder, + val nullOrder: NullOrder ) diff --git a/magnum/src/main/scala/com/augustnagro/magnum/SortOrder.scala b/magnum/src/main/scala/com/augustnagro/magnum/SortOrder.scala index 4cbc171..636c1d9 100644 --- a/magnum/src/main/scala/com/augustnagro/magnum/SortOrder.scala +++ b/magnum/src/main/scala/com/augustnagro/magnum/SortOrder.scala @@ -1,11 +1,8 @@ package com.augustnagro.magnum -trait SortOrder: - def sql: String +trait SortOrder object SortOrder: - object Asc extends SortOrder: - def sql: String = "ASC" - - object Desc extends SortOrder: - def sql: String = "DESC" + case object Default extends SortOrder + case object Asc extends SortOrder + case object Desc extends SortOrder diff --git a/magnum/src/main/scala/com/augustnagro/magnum/Spec.scala b/magnum/src/main/scala/com/augustnagro/magnum/Spec.scala 
index f13b788..886976d 100644 --- a/magnum/src/main/scala/com/augustnagro/magnum/Spec.scala +++ b/magnum/src/main/scala/com/augustnagro/magnum/Spec.scala @@ -1,80 +1,48 @@ package com.augustnagro.magnum -import java.sql.PreparedStatement import java.util.StringJoiner class Spec[E] private ( - predicates: List[Frag], - limit: Option[Int], - offset: Option[Int], - sorts: List[Sort] + val prefix: Option[Frag], + val predicates: Vector[Frag], + val limit: Option[Int], + val offset: Option[Long], + val sorts: Vector[Sort], + val seeks: Vector[Seek] ): + def prefix(sql: Frag): Spec[E] = + new Spec(Some(sql), predicates, limit, offset, sorts, seeks) + def where(sql: Frag): Spec[E] = - new Spec(sql :: predicates, limit, offset, sorts) + new Spec(prefix, predicates :+ sql, limit, offset, sorts, seeks) def orderBy( column: String, - direction: SortOrder = SortOrder.Asc, - nullOrder: NullOrder = NullOrder.Empty + direction: SortOrder = SortOrder.Default, + nullOrder: NullOrder = NullOrder.Default ): Spec[E] = val sort = Sort(column, direction, nullOrder) - new Spec(predicates, limit, offset, sort :: sorts) + new Spec(prefix, predicates, limit, offset, sorts :+ sort, seeks) def limit(limit: Int): Spec[E] = - new Spec(predicates, Some(limit), offset, sorts) + new Spec(prefix, predicates, Some(limit), offset, sorts, seeks) - def offset(offset: Int): Spec[E] = - new Spec(predicates, limit, Some(offset), sorts) + def offset(offset: Long): Spec[E] = + new Spec(prefix, predicates, limit, Some(offset), sorts, seeks) def seek[V]( column: String, seekDirection: SeekDir, value: V, columnSort: SortOrder, - nullOrder: NullOrder = NullOrder.Last + nullOrder: NullOrder = NullOrder.Default )(using codec: DbCodec[V]): Spec[E] = - val sort = Sort(column, columnSort, nullOrder) - val pred = - Frag( - s"$column ${seekDirection.sql} ?", - Vector(value), - (ps, pos) => - codec.writeSingle(value, ps, pos) - pos + codec.cols.length - ) - new Spec(pred :: predicates, limit, offset, sort :: sorts) - - 
def build: Frag = - val whereClause = StringJoiner(" AND ", "WHERE ", "").setEmptyValue("") - val allParams = Vector.newBuilder[Any] - - val validFrags = predicates.reverse.filter(_.sqlString.nonEmpty) - for frag <- validFrags do - whereClause.add("(" + frag.sqlString + ")") - allParams ++= frag.params - - val orderByClause = StringJoiner(", ", "ORDER BY ", "").setEmptyValue("") - for Sort(col, dir, nullOrder) <- sorts.reverse do - orderByClause.add(col + " " + dir.sql + " " + nullOrder.sql) - - val finalSj = StringJoiner(" ") - val whereClauseStr = whereClause.toString - if whereClauseStr.nonEmpty then finalSj.add(whereClauseStr) - val orderByClauseStr = orderByClause.toString - if orderByClauseStr.nonEmpty then finalSj.add(orderByClauseStr) - for l <- limit do finalSj.add("LIMIT " + l) - for o <- offset do finalSj.add("OFFSET " + o) - - val fragWriter: FragWriter = (ps, startingPos) => - validFrags.foldLeft(startingPos)((pos, frag) => - frag.writer.write(ps, pos) - ) + val seek = Seek(column, seekDirection, value, columnSort, nullOrder, codec) + new Spec(prefix, predicates, limit, offset, sorts, seeks :+ seek) - Frag(finalSj.toString, allParams.result(), fragWriter) - end build end Spec object Spec: def apply[E]: Spec[E] = - new Spec(Nil, None, None, Nil) + new Spec(None, Vector.empty, None, None, Vector.empty, Vector.empty) diff --git a/magnum/src/main/scala/com/augustnagro/magnum/SpecImpl.scala b/magnum/src/main/scala/com/augustnagro/magnum/SpecImpl.scala new file mode 100644 index 0000000..14c40af --- /dev/null +++ b/magnum/src/main/scala/com/augustnagro/magnum/SpecImpl.scala @@ -0,0 +1,88 @@ +package com.augustnagro.magnum + +import java.util.StringJoiner + +private trait SpecImpl: + def sortSql(sort: Sort): String = + val dir = sort.direction match + case SortOrder.Default => "" + case SortOrder.Asc => " ASC" + case SortOrder.Desc => " DESC" + case _ => throw UnsupportedOperationException() + val nullOrder = sort.nullOrder match + case NullOrder.Default => 
"" + case NullOrder.First => " NULLS FIRST" + case NullOrder.Last => " NULLS LAST" + case _ => throw UnsupportedOperationException() + sort.column + dir + nullOrder + + def offsetLimitSql(offset: Option[Long], limit: Option[Int]): Option[String] = + (offset, limit) match + case (Some(o), Some(l)) => Some(s"OFFSET $o LIMIT $l") + case (Some(o), None) => Some(s"OFFSET $o") + case (None, Some(l)) => Some(s"LIMIT $l") + case (None, None) => None + + def seekSql(seek: Seek): String = + val seekDir = seek.seekDirection match + case SeekDir.Gt => ">" + case SeekDir.Lt => "<" + case _ => throw UnsupportedOperationException() + s"${seek.column} $seekDir ?" + + def findAll[E: DbCodec](spec: Spec[E], tableNameSql: String)(using + DbCon + ): Vector[E] = + val whereClause = StringJoiner(" AND ", "WHERE ", "").setEmptyValue("") + + val allParams = Vector.newBuilder[Any] + + val tableNameLiteral = SqlLiteral(tableNameSql) + val prefixFrag = spec.prefix.getOrElse(sql"SELECT * FROM $tableNameLiteral") + allParams ++= prefixFrag.params + + val seekPredicates = spec.seeks.map(seek => + val codec = seek.codec.asInstanceOf[DbCodec[Any]] + Frag( + seekSql(seek), + Vector(seek.value), + (ps, pos) => + codec.writeSingle(seek.value, ps, pos) + pos + codec.cols.length + ) + ) + + val whereFrags = + (spec.predicates ++ seekPredicates).filter(_.sqlString.nonEmpty) + for frag <- whereFrags do + whereClause.add("(" + frag.sqlString + ")") + allParams ++= frag.params + + val seekSorts = + spec.seeks.map(seek => Sort(seek.column, seek.columnSort, seek.nullOrder)) + val orderByClause = + StringJoiner(", ", "ORDER BY ", "").setEmptyValue("") + for sort <- spec.sorts ++ seekSorts do orderByClause.add(sortSql(sort)) + + val finalSj = StringJoiner(" ") + if prefixFrag.sqlString.nonEmpty then finalSj.add(prefixFrag.sqlString) + val whereClauseStr = whereClause.toString + if whereClauseStr.nonEmpty then finalSj.add(whereClauseStr) + val orderByClauseStr = orderByClause.toString + if 
orderByClauseStr.nonEmpty then finalSj.add(orderByClauseStr) + + for offsetLimit <- offsetLimitSql(spec.offset, spec.limit) do + finalSj.add(offsetLimit) + + val allFrags = prefixFrag +: whereFrags + val fragWriter: FragWriter = (ps, startingPos) => + allFrags.foldLeft(startingPos)((pos, frag) => frag.writer.write(ps, pos)) + + Frag(finalSj.toString, allParams.result(), fragWriter) + .query[E] + .run() + end findAll +end SpecImpl + +private object SpecImpl: + object Default extends SpecImpl diff --git a/magnum/src/main/scala/com/augustnagro/magnum/SqliteDbType.scala b/magnum/src/main/scala/com/augustnagro/magnum/SqliteDbType.scala index d893734..ae51902 100644 --- a/magnum/src/main/scala/com/augustnagro/magnum/SqliteDbType.scala +++ b/magnum/src/main/scala/com/augustnagro/magnum/SqliteDbType.scala @@ -9,6 +9,17 @@ import scala.util.{Failure, Success, Using} object SqliteDbType extends DbType: + private val specImpl = new SpecImpl: + override def offsetLimitSql( + offset: Option[Long], + limit: Option[Int] + ): Option[String] = + (offset, limit) match + case (Some(o), Some(l)) => Some(s"LIMIT $o, $l") + case (Some(o), None) => Some(s"LIMIT $o, ${Long.MaxValue}") + case (None, Some(l)) => Some(s"LIMIT $l") + case (None, None) => None + def buildRepoDefaults[EC, E, ID]( tableNameSql: String, eElemNames: Seq[String], @@ -43,17 +54,18 @@ object SqliteDbType extends DbType: val insertGenKeys = eElemNamesSql.toArray val countSql = s"SELECT count(*) FROM $tableNameSql" - val countQuery = Frag(countSql).query[Long] + val countQuery = Frag(countSql, Vector.empty, FragWriter.empty).query[Long] val existsByIdSql = s"SELECT 1 FROM $tableNameSql WHERE $idName = ${idCodec.queryRepr}" val findAllSql = s"SELECT * FROM $tableNameSql" - val findAllQuery = Frag(findAllSql).query[E] + val findAllQuery = Frag(findAllSql, Vector.empty, FragWriter.empty).query[E] val findByIdSql = s"SELECT * FROM $tableNameSql WHERE $idName = ${idCodec.queryRepr}" val deleteByIdSql = s"DELETE FROM 
$tableNameSql WHERE $idName = ${idCodec.queryRepr}" val truncateSql = s"DELETE FROM $tableNameSql" - val truncateUpdate = Frag(truncateSql).update + val truncateUpdate = + Frag(truncateSql, Vector.empty, FragWriter.empty).update val insertSql = s"INSERT INTO $tableNameSql $ecInsertKeys VALUES (${ecCodec.queryRepr})" val updateSql = @@ -76,10 +88,7 @@ object SqliteDbType extends DbType: def findAll(using DbCon): Vector[E] = findAllQuery.run() def findAll(spec: Spec[E])(using DbCon): Vector[E] = - val f = spec.build - Frag(s"SELECT * FROM $tableNameSql ${f.sqlString}", f.params, f.writer) - .query[E] - .run() + specImpl.findAll(spec, tableNameSql) def findById(id: ID)(using DbCon): Option[E] = Frag(findByIdSql, IArray(id), idWriter(id)) diff --git a/magnum/src/main/scala/com/augustnagro/magnum/UUIDCodec.scala b/magnum/src/main/scala/com/augustnagro/magnum/UUIDCodec.scala index 9df8783..e391e46 100644 --- a/magnum/src/main/scala/com/augustnagro/magnum/UUIDCodec.scala +++ b/magnum/src/main/scala/com/augustnagro/magnum/UUIDCodec.scala @@ -8,8 +8,8 @@ object UUIDCodec: def queryRepr: String = "?" 
val cols: IArray[Int] = IArray(Types.VARCHAR) def readSingle(rs: ResultSet, pos: Int): UUID = - rs.getString(pos) match - case null => null - case uuidStr => UUID.fromString(uuidStr) + UUID.fromString(rs.getString(pos)) + def readSingleOption(rs: ResultSet, pos: Int): Option[UUID] = + Option(rs.getString(pos)).map(UUID.fromString) def writeSingle(entity: UUID, ps: PreparedStatement, pos: Int): Unit = ps.setString(pos, entity.toString) diff --git a/magnum/src/main/scala/com/augustnagro/magnum/Update.scala b/magnum/src/main/scala/com/augustnagro/magnum/Update.scala index fd007e7..47b8d5c 100644 --- a/magnum/src/main/scala/com/augustnagro/magnum/Update.scala +++ b/magnum/src/main/scala/com/augustnagro/magnum/Update.scala @@ -4,7 +4,7 @@ import java.util.concurrent.TimeUnit import scala.concurrent.duration.FiniteDuration import scala.util.{Failure, Success, Using} -case class Update(frag: Frag): +class Update private[magnum] (val frag: Frag): /** Exactly like [[java.sql.PreparedStatement]].executeUpdate */ def run()(using con: DbCon): Int = handleQuery(frag.sqlString, frag.params): diff --git a/magnum/src/main/scala/com/augustnagro/magnum/util.scala b/magnum/src/main/scala/com/augustnagro/magnum/util.scala index 575e036..7354905 100644 --- a/magnum/src/main/scala/com/augustnagro/magnum/util.scala +++ b/magnum/src/main/scala/com/augustnagro/magnum/util.scala @@ -4,9 +4,10 @@ import com.augustnagro.magnum.SqlException import java.lang.System.Logger.Level import java.sql.{Connection, PreparedStatement, ResultSet, Statement} +import java.util.StringJoiner import java.util.concurrent.TimeUnit import javax.sql.DataSource -import scala.collection.mutable.ReusableBuilder +import scala.collection.mutable as m import scala.util.{Failure, Success, Try, Using, boundary} import scala.deriving.Mirror import scala.compiletime.{ @@ -79,18 +80,33 @@ private def sqlImpl(sc: Expr[StringContext], args: Expr[Seq[Any]])(using case _ => true } + val flattenedParamExprs = + 
flattenParamExprs(paramExprs, '{ Vector.newBuilder[Any] }) + val queryExpr = '{ $sc.s($interpolatedVarargs: _*) } val exprParams = Expr.ofSeq(paramExprs) '{ val argValues = $exprParams + val flattenedParams = $flattenedParamExprs val writer: FragWriter = (ps: PreparedStatement, pos: Int) => { ${ sqlWriter('{ ps }, '{ pos }, '{ argValues }, paramExprs, '{ 0 }) } } - Frag($queryExpr, argValues, writer) + Frag($queryExpr, flattenedParams, writer) } end sqlImpl +private def flattenParamExprs( + paramExprs: Seq[Expr[Any]], + res: Expr[m.Builder[Any, Vector[Any]]] +)(using q: Quotes): Expr[Seq[Any]] = + paramExprs match + case '{ $arg: Frag } +: tail => + flattenParamExprs(tail, '{ $res ++= $arg.params }) + case arg +: tail => + flattenParamExprs(tail, '{ $res += $arg }) + case Seq() => '{ $res.result() } + private def sqlWriter( psExpr: Expr[PreparedStatement], posExpr: Expr[Int], diff --git a/magnum/src/test/resources/clickhouse-car.sql b/magnum/src/test/resources/clickhouse-car.sql deleted file mode 100644 index be00be3..0000000 --- a/magnum/src/test/resources/clickhouse-car.sql +++ /dev/null @@ -1,38 +0,0 @@ -drop table if exists car; - -CREATE TABLE car ( - model String NOT NULL, - id UUID NOT NULL, - top_speed Int32 NOT NULL, - created DateTime NOT NULL, - vin Nullable(Int32), - color Enum('Red', 'Green', 'Blue') -) -ENGINE = MergeTree() -ORDER BY created; - -INSERT INTO car (model, id, top_speed, created, vin, color) VALUES -( - 'McLaren Senna', - toUUID('a88a32f1-1e4a-41b9-9fb0-e9a8aba2428a'), - 208, - toDateTime('2023-03-05 02:26:00'), - 123, - 'Red' -), -( - 'Ferrari F8 Tributo', - toUUID('e4895170-5b54-4e3b-b857-b95d45d3550c'), - 212, - toDateTime('2023-03-05 02:27:00'), - 124, - 'Green' -), -( - 'Aston Martin Superleggera', - toUUID('460798da-917d-442f-a987-a7e6528ddf17'), - 211, - toDateTime('2023-03-05 02:28:00'), - null, - 'Blue' -); \ No newline at end of file diff --git a/magnum/src/test/resources/clickhouse-person.sql 
b/magnum/src/test/resources/clickhouse-person.sql deleted file mode 100644 index e1fe050..0000000 --- a/magnum/src/test/resources/clickhouse-person.sql +++ /dev/null @@ -1,78 +0,0 @@ -drop table if exists person; - -create table person ( - id UUID not null default generateUUIDv4(), - first_name Nullable(String), - last_name String not null, - is_admin Bool not null, - created DateTime not null default now(), - social_id Nullable(UUID) -) -engine = MergeTree() -order by created; - -insert into person values -( - toUUID('3b1bc33b-ecc9-45c7-b866-04c60d31d687'), - 'George', - 'Washington', - true, - toDateTime('2023-03-05 02:26:00'), - toUUID('d06443a6-3efb-46c4-a66a-a80a8a9a5388') -), -( - toUUID('12970806-606d-42ff-bb9c-3187bbd360dd'), - 'Alexander', - 'Hamilton', - true, - toDateTime('2023-03-05 02:27:00'), - toUUID('529b6c6d-7228-4da5-81d7-13b706f78ddb') -), -( - toUUID('834a2bd2-6842-424f-97e0-fe5ed02c3653'), - 'John', - 'Adams', - true, - toDateTime('2023-03-05 02:28:00'), - null -), -( - toUUID('60492bb2-fe02-4d02-9c6d-ae03fa6f2243'), - 'Benjamin', - 'Franklin', - true, - toDateTime('2023-03-05 02:29:00'), - null -), -( - toUUID('2244eef4-b581-4305-824f-efe8f70e6bb7'), - 'John', - 'Jay', - true, - toDateTime('2023-03-05 02:30:00'), - null -), -( - toUUID('fb3c479a-0521-4f06-b6ad-218db867518c'), - 'Thomas', - 'Jefferson', - true, - toDateTime('2023-03-05 02:31:00'), - null -), -( - toUUID('8fefe1d8-20eb-44a0-84da-d328334e1e11'), - 'James', - 'Madison', - true, - toDateTime('2023-03-05 02:32:00'), - null -), -( - toUUID('8a2d842a-0d03-463c-8c34-43b38120f9e4'), - null, - 'Nagro', - false, - toDateTime('2023-03-05 02:33:00'), - null -); \ No newline at end of file diff --git a/magnum/src/test/resources/clickhouse/big-dec.sql b/magnum/src/test/resources/clickhouse/big-dec.sql new file mode 100644 index 0000000..7472139 --- /dev/null +++ b/magnum/src/test/resources/clickhouse/big-dec.sql @@ -0,0 +1,12 @@ +drop table if exists big_dec; + +create table big_dec ( + id 
Int64 NOT NULL, + my_big_dec Nullable(Int256) +) +ENGINE = MergeTree() +ORDER BY id; + +insert into big_dec values +(1, 123), +(2, null); \ No newline at end of file diff --git a/magnum/src/test/resources/clickhouse/car.sql b/magnum/src/test/resources/clickhouse/car.sql new file mode 100644 index 0000000..97eff01 --- /dev/null +++ b/magnum/src/test/resources/clickhouse/car.sql @@ -0,0 +1,17 @@ +drop table if exists car; + +CREATE TABLE car ( + model String NOT NULL, + id Int64 NOT NULL, + top_speed Int32 NOT NULL, + vin Nullable(Int32), + color Enum('Red', 'Green', 'Blue'), + created DateTime NOT NULL +) +ENGINE = MergeTree() +ORDER BY created; + +INSERT INTO car (model, id, top_speed, vin, color, created) VALUES +('McLaren Senna', 1, 208, 123, 'Red', toDateTime('2024-11-24 22:17:30', 'UTC')), +('Ferrari F8 Tributo', 2, 212, 124, 'Green', toDateTime('2024-11-24 22:17:31', 'UTC')), +('Aston Martin Superleggera', 3, 211, null, 'Blue', toDateTime('2024-11-24 22:17:32', 'UTC')); \ No newline at end of file diff --git a/magnum/src/test/resources/clickhouse/no-id.sql b/magnum/src/test/resources/clickhouse/no-id.sql new file mode 100644 index 0000000..160dc70 --- /dev/null +++ b/magnum/src/test/resources/clickhouse/no-id.sql @@ -0,0 +1,14 @@ +drop table if exists no_id; + +CREATE TABLE no_id ( + created_at DateTime NOT NULL, + user_name String NOT NULL, + user_action String NOT NULL +) +ENGINE = MergeTree() +ORDER BY created_at; + +INSERT INTO no_id VALUES +(timestamp '1997-08-15', 'Josh', 'clicked a button'), +(timestamp '1997-08-16', 'Danny', 'opened a toaster'), +(timestamp '1997-08-17', 'Greg', 'ran some QA tests'); \ No newline at end of file diff --git a/magnum/src/test/resources/clickhouse/person.sql b/magnum/src/test/resources/clickhouse/person.sql new file mode 100644 index 0000000..73fc901 --- /dev/null +++ b/magnum/src/test/resources/clickhouse/person.sql @@ -0,0 +1,22 @@ +drop table if exists person; + +create table person ( + id Int64 not null, + first_name 
Nullable(String), + last_name String not null, + is_admin Bool not null, + created DateTime not null, + social_id Nullable(UUID) +) +engine = MergeTree() +order by created; + +insert into person values +(1, 'George', 'Washington', true, toDateTime('2023-03-05 02:26:00'), toUUID('d06443a6-3efb-46c4-a66a-a80a8a9a5388')), +(2, 'Alexander', 'Hamilton', true, toDateTime('2023-03-05 02:27:00'), toUUID('529b6c6d-7228-4da5-81d7-13b706f78ddb')), +(3, 'John', 'Adams', true, toDateTime('2023-03-05 02:28:00'), null), +(4, 'Benjamin', 'Franklin', true, toDateTime('2023-03-05 02:29:00'), null), +(5, 'John', 'Jay', true, toDateTime('2023-03-05 02:30:00'), null), +(6, 'Thomas', 'Jefferson', true, toDateTime('2023-03-05 02:31:00'), null), +(7, 'James', 'Madison', true, toDateTime('2023-03-05 02:32:00'), null), +(8, null, 'Nagro', false, toDateTime('2023-03-05 02:33:00'), null); \ No newline at end of file diff --git a/magnum/src/test/resources/h2-car.sql b/magnum/src/test/resources/h2-car.sql deleted file mode 100644 index a67032e..0000000 --- a/magnum/src/test/resources/h2-car.sql +++ /dev/null @@ -1,14 +0,0 @@ -drop table if exists car; - -create table car ( - model varchar(50) not null, - id bigint auto_increment primary key, - top_speed int, - vin int, - color enum('Red', 'Green', 'Blue') -); - -insert into car (model, top_speed, vin, color) values -('McLaren Senna', 208, 123, 'Red'), -('Ferrari F8 Tributo', 212, 124, 'Green'), -('Aston Martin Superleggera', 211, null, 'Blue'); diff --git a/magnum/src/test/resources/h2-person.sql b/magnum/src/test/resources/h2-person.sql deleted file mode 100644 index 3abf147..0000000 --- a/magnum/src/test/resources/h2-person.sql +++ /dev/null @@ -1,20 +0,0 @@ -drop table if exists person cascade; - -create table person ( - id bigint auto_increment primary key, - first_name varchar(50), - last_name varchar(50) not null, - is_admin boolean not null, - created timestamp default current_timestamp, - social_id UUID -); - -insert into person 
(first_name, last_name, is_admin, created, social_id) values -('George', 'Washington', true, now(), 'd06443a6-3efb-46c4-a66a-a80a8a9a5388'), -('Alexander', 'Hamilton', true, now(), '529b6c6d-7228-4da5-81d7-13b706f78ddb'), -('John', 'Adams', true, now(), null), -('Benjamin', 'Franklin', true, now(), null), -('John', 'Jay', true, now(), null), -('Thomas', 'Jefferson', true, now(), null), -('James', 'Madison', true, now(), null), -(null, 'Nagro', false, now(), null); diff --git a/magnum/src/test/resources/h2/big-dec.sql b/magnum/src/test/resources/h2/big-dec.sql new file mode 100644 index 0000000..7b7b072 --- /dev/null +++ b/magnum/src/test/resources/h2/big-dec.sql @@ -0,0 +1,10 @@ +drop table if exists big_dec cascade; + +create table big_dec ( + id int auto_increment primary key, + my_big_dec numeric +); + +insert into big_dec values +(1, 123), +(2, null); \ No newline at end of file diff --git a/magnum/src/test/resources/h2/car.sql b/magnum/src/test/resources/h2/car.sql new file mode 100644 index 0000000..d1c7eb7 --- /dev/null +++ b/magnum/src/test/resources/h2/car.sql @@ -0,0 +1,15 @@ +drop table if exists car; + +create table car ( + model varchar(50) not null, + id bigint auto_increment primary key, + top_speed int not null, + vin int, + color enum('Red', 'Green', 'Blue'), + created timestamp with time zone not null +); + +insert into car (model, top_speed, vin, color, created) values +('McLaren Senna', 208, 123, 'Red', '2024-11-24T22:17:30.000000000Z'), +('Ferrari F8 Tributo', 212, 124, 'Green', '2024-11-24T22:17:31.000000000Z'), +('Aston Martin Superleggera', 211, null, 'Blue', '2024-11-24T22:17:32.000000000Z'); diff --git a/magnum/src/test/resources/h2/my-user.sql b/magnum/src/test/resources/h2/my-user.sql new file mode 100644 index 0000000..5a6851c --- /dev/null +++ b/magnum/src/test/resources/h2/my-user.sql @@ -0,0 +1,11 @@ +drop table if exists my_user cascade; + +create table my_user ( + first_name text not null, + id bigint auto_increment primary key +); 
+ +insert into my_user (first_name) values +('George'), +('Alexander'), +('John'); diff --git a/magnum/src/test/resources/h2/no-id.sql b/magnum/src/test/resources/h2/no-id.sql new file mode 100644 index 0000000..731a28d --- /dev/null +++ b/magnum/src/test/resources/h2/no-id.sql @@ -0,0 +1,12 @@ +drop table if exists no_id; + +create table no_id ( + created_at timestamp with time zone default now() not null, + user_name varchar not null, + user_action varchar not null +); + +insert into no_id values +(timestamp '1997-08-15', 'Josh', 'clicked a button'), +(timestamp '1997-08-16', 'Danny', 'opened a toaster'), +(timestamp '1997-08-17', 'Greg', 'ran some QA tests'); diff --git a/magnum/src/test/resources/h2/person.sql b/magnum/src/test/resources/h2/person.sql new file mode 100644 index 0000000..33a6afb --- /dev/null +++ b/magnum/src/test/resources/h2/person.sql @@ -0,0 +1,20 @@ +drop table if exists person cascade; + +create table person ( + id bigint primary key, + first_name varchar(50), + last_name varchar(50) not null, + is_admin boolean not null, + created timestamp with time zone, + social_id UUID +); + +insert into person (id, first_name, last_name, is_admin, created, social_id) values +(1, 'George', 'Washington', true, now(), 'd06443a6-3efb-46c4-a66a-a80a8a9a5388'), +(2, 'Alexander', 'Hamilton', true, now(), '529b6c6d-7228-4da5-81d7-13b706f78ddb'), +(3, 'John', 'Adams', true, now(), null), +(4, 'Benjamin', 'Franklin', true, now(), null), +(5, 'John', 'Jay', true, now(), null), +(6, 'Thomas', 'Jefferson', true, now(), null), +(7, 'James', 'Madison', true, now(), null), +(8, null, 'Nagro', false, now(), null); diff --git a/magnum/src/test/resources/mysql-car.sql b/magnum/src/test/resources/mysql-car.sql deleted file mode 100644 index a67032e..0000000 --- a/magnum/src/test/resources/mysql-car.sql +++ /dev/null @@ -1,14 +0,0 @@ -drop table if exists car; - -create table car ( - model varchar(50) not null, - id bigint auto_increment primary key, - top_speed int, - 
vin int, - color enum('Red', 'Green', 'Blue') -); - -insert into car (model, top_speed, vin, color) values -('McLaren Senna', 208, 123, 'Red'), -('Ferrari F8 Tributo', 212, 124, 'Green'), -('Aston Martin Superleggera', 211, null, 'Blue'); diff --git a/magnum/src/test/resources/mysql-person.sql b/magnum/src/test/resources/mysql-person.sql deleted file mode 100644 index a469a91..0000000 --- a/magnum/src/test/resources/mysql-person.sql +++ /dev/null @@ -1,20 +0,0 @@ -drop table if exists person cascade; - -create table person ( - id bigint auto_increment primary key, - first_name varchar(50), - last_name varchar(50) not null, - is_admin boolean not null, - created timestamp default current_timestamp, - social_id varchar(36) -); - -insert into person (first_name, last_name, is_admin, created, social_id) values -('George', 'Washington', true, now(), 'd06443a6-3efb-46c4-a66a-a80a8a9a5388'), -('Alexander', 'Hamilton', true, now(), '529b6c6d-7228-4da5-81d7-13b706f78ddb'), -('John', 'Adams', true, now(), null), -('Benjamin', 'Franklin', true, now(), null), -('John', 'Jay', true, now(), null), -('Thomas', 'Jefferson', true, now(), null), -('James', 'Madison', true, now(), null), -(null, 'Nagro', false, now(), null); diff --git a/magnum/src/test/resources/pg-bigdec.sql b/magnum/src/test/resources/mysql/big-dec.sql similarity index 100% rename from magnum/src/test/resources/pg-bigdec.sql rename to magnum/src/test/resources/mysql/big-dec.sql diff --git a/magnum/src/test/resources/mysql/car.sql b/magnum/src/test/resources/mysql/car.sql new file mode 100644 index 0000000..3b64140 --- /dev/null +++ b/magnum/src/test/resources/mysql/car.sql @@ -0,0 +1,15 @@ +drop table if exists car; + +create table car ( + model varchar(50) not null, + id bigint primary key, + top_speed int not null, + vin int, + color enum('Red', 'Green', 'Blue'), + created datetime not null +); + +insert into car (model, id, top_speed, vin, color, created) values +('McLaren Senna', 1, 208, 123, 'Red', 
'2024-11-24 22:17:30'), +('Ferrari F8 Tributo', 2, 212, 124, 'Green', '2024-11-24 22:17:31'), +('Aston Martin Superleggera', 3, 211, null, 'Blue', '2024-11-24 22:17:32'); diff --git a/magnum/src/test/resources/mysql/my-user.sql b/magnum/src/test/resources/mysql/my-user.sql new file mode 100644 index 0000000..f5fdfa0 --- /dev/null +++ b/magnum/src/test/resources/mysql/my-user.sql @@ -0,0 +1,11 @@ +drop table if exists my_user cascade; + +create table my_user ( + first_name varchar(200) not null, + id bigint auto_increment primary key +); + +insert into my_user (first_name) values +('George'), +('Alexander'), +('John'); diff --git a/magnum/src/test/resources/mysql/no-id.sql b/magnum/src/test/resources/mysql/no-id.sql new file mode 100644 index 0000000..175aff1 --- /dev/null +++ b/magnum/src/test/resources/mysql/no-id.sql @@ -0,0 +1,12 @@ +drop table if exists no_id; + +create table no_id ( + created_at datetime not null default now(), + user_name varchar(200) not null, + user_action varchar(200) not null +); + +insert into no_id values +('1997-08-15', 'Josh', 'clicked a button'), +('1997-08-16', 'Danny', 'opened a toaster'), +('1997-08-17', 'Greg', 'ran some QA tests'); diff --git a/magnum/src/test/resources/mysql/person.sql b/magnum/src/test/resources/mysql/person.sql new file mode 100644 index 0000000..60a5420 --- /dev/null +++ b/magnum/src/test/resources/mysql/person.sql @@ -0,0 +1,20 @@ +drop table if exists person cascade; + +create table person ( + id bigint primary key, + first_name varchar(50), + last_name varchar(50) not null, + is_admin boolean not null, + created datetime not null, + social_id varchar(36) +); + +insert into person (id, first_name, last_name, is_admin, created, social_id) values +(1, 'George', 'Washington', true, now(), 'd06443a6-3efb-46c4-a66a-a80a8a9a5388'), +(2, 'Alexander', 'Hamilton', true, now(), '529b6c6d-7228-4da5-81d7-13b706f78ddb'), +(3, 'John', 'Adams', true, now(), null), +(4, 'Benjamin', 'Franklin', true, now(), null), +(5, 
'John', 'Jay', true, now(), null), +(6, 'Thomas', 'Jefferson', true, now(), null), +(7, 'James', 'Madison', true, now(), null), +(8, null, 'Nagro', false, now(), null); diff --git a/magnum/src/test/resources/pg-car.sql b/magnum/src/test/resources/pg-car.sql deleted file mode 100644 index 60066a4..0000000 --- a/magnum/src/test/resources/pg-car.sql +++ /dev/null @@ -1,17 +0,0 @@ -drop table if exists car; -drop type if exists color; - -create type Color as enum ('Red', 'Green', 'Blue'); - -create table car ( - model varchar(50) not null, - id bigint primary key generated always as identity, - top_speed int not null, - vin int, - color Color not null -); - -insert into car (model, top_speed, vin, color) values -('McLaren Senna', 208, 123, 'Red'), -('Ferrari F8 Tributo', 212, 124, 'Green'), -('Aston Martin Superleggera', 211, null, 'Blue'); diff --git a/magnum/src/test/resources/pg-person.sql b/magnum/src/test/resources/pg-person.sql deleted file mode 100644 index cdcbc12..0000000 --- a/magnum/src/test/resources/pg-person.sql +++ /dev/null @@ -1,20 +0,0 @@ -drop table if exists person cascade; - -create table person ( - id bigint primary key generated always as identity, - first_name varchar(50), - last_name varchar(50) not null, - is_admin boolean not null, - created timestamptz not null default now(), - social_id UUID -); - -insert into person (first_name, last_name, is_admin, created, social_id) values -('George', 'Washington', true, now(), 'd06443a6-3efb-46c4-a66a-a80a8a9a5388'), -('Alexander', 'Hamilton', true, now(), '529b6c6d-7228-4da5-81d7-13b706f78ddb'), -('John', 'Adams', true, now(), null), -('Benjamin', 'Franklin', true, now(), null), -('John', 'Jay', true, now(), null), -('Thomas', 'Jefferson', true, now(), null), -('James', 'Madison', true, now(), null), -(null, 'Nagro', false, timestamp '1997-08-12', null); diff --git a/magnum/src/test/resources/pg/big-dec.sql b/magnum/src/test/resources/pg/big-dec.sql new file mode 100644 index 0000000..c09a8fc --- 
/dev/null +++ b/magnum/src/test/resources/pg/big-dec.sql @@ -0,0 +1,10 @@ +drop table if exists big_dec cascade; + +create table big_dec ( + id int primary key, + my_big_dec numeric +); + +insert into big_dec values +(1, 123), +(2, null); \ No newline at end of file diff --git a/magnum/src/test/resources/pg/car.sql b/magnum/src/test/resources/pg/car.sql new file mode 100644 index 0000000..9a41162 --- /dev/null +++ b/magnum/src/test/resources/pg/car.sql @@ -0,0 +1,15 @@ +DROP TABLE IF EXISTS car; + +CREATE TABLE car ( + model VARCHAR(50) NOT NULL, + id bigint PRIMARY KEY, + top_speed INT NOT NULL, + vin INT, + color TEXT NOT NULL CHECK (color IN ('Red', 'Green', 'Blue')), + created TIMESTAMP WITH TIME ZONE NOT NULL +); + +INSERT INTO car (model, id, top_speed, vin, color, created) VALUES +('McLaren Senna', 1, 208, 123, 'Red', '2024-11-24T22:17:30.000000000Z'::timestamptz), +('Ferrari F8 Tributo', 2, 212, 124, 'Green', '2024-11-24T22:17:31.000000000Z'::timestamptz), +('Aston Martin Superleggera', 3, 211, null, 'Blue', '2024-11-24T22:17:32.000000000Z'::timestamptz); diff --git a/magnum/src/test/resources/pg/my-user.sql b/magnum/src/test/resources/pg/my-user.sql new file mode 100644 index 0000000..a04ed42 --- /dev/null +++ b/magnum/src/test/resources/pg/my-user.sql @@ -0,0 +1,11 @@ +drop table if exists my_user cascade; + +create table my_user ( + first_name text not null, + id bigint primary key generated always as identity +); + +insert into my_user (first_name) values +('George'), +('Alexander'), +('John'); diff --git a/magnum/src/test/resources/pg-no-id.sql b/magnum/src/test/resources/pg/no-id.sql similarity index 100% rename from magnum/src/test/resources/pg-no-id.sql rename to magnum/src/test/resources/pg/no-id.sql diff --git a/magnum/src/test/resources/pg/person.sql b/magnum/src/test/resources/pg/person.sql new file mode 100644 index 0000000..2929af6 --- /dev/null +++ b/magnum/src/test/resources/pg/person.sql @@ -0,0 +1,20 @@ +drop table if exists person 
cascade; + +create table person ( + id bigint primary key, + first_name varchar(50), + last_name varchar(50) not null, + is_admin boolean not null, + created timestamptz not null, + social_id UUID +); + +insert into person (id, first_name, last_name, is_admin, created, social_id) values +(1, 'George', 'Washington', true, now(), 'd06443a6-3efb-46c4-a66a-a80a8a9a5388'), +(2, 'Alexander', 'Hamilton', true, now(), '529b6c6d-7228-4da5-81d7-13b706f78ddb'), +(3, 'John', 'Adams', true, now(), null), +(4, 'Benjamin', 'Franklin', true, now(), null), +(5, 'John', 'Jay', true, now(), null), +(6, 'Thomas', 'Jefferson', true, now(), null), +(7, 'James', 'Madison', true, now(), null), +(8, null, 'Nagro', false, timestamp '1997-08-12', null); diff --git a/magnum/src/test/scala/ClickHouseTests.scala b/magnum/src/test/scala/ClickHouseTests.scala index f56a011..ee73b7a 100644 --- a/magnum/src/test/scala/ClickHouseTests.scala +++ b/magnum/src/test/scala/ClickHouseTests.scala @@ -1,300 +1,18 @@ import com.augustnagro.magnum.* -import com.dimafeng.testcontainers.munit.fixtures.TestContainersFixtures -import com.dimafeng.testcontainers.{ - ClickHouseContainer, - ContainerDef, - JdbcDatabaseContainer -} -import munit.{AnyFixture, FunSuite, Location, TestOptions} import com.clickhouse.jdbc.ClickHouseDataSource +import com.dimafeng.testcontainers.ClickHouseContainer +import com.dimafeng.testcontainers.munit.fixtures.TestContainersFixtures +import munit.{AnyFixture, FunSuite, Location} import org.testcontainers.utility.DockerImageName +import shared.* import java.nio.file.{Files, Path} -import java.sql.Connection -import java.time.{OffsetDateTime, ZoneOffset} -import java.util.{Properties, UUID} -import javax.sql.DataSource +import java.util.UUID import scala.util.Using class ClickHouseTests extends FunSuite, TestContainersFixtures: - /* - Immutable Repo Tests - */ - - enum Color derives DbCodec: - case Red, Green, Blue - - @Table(ClickhouseDbType, SqlNameMapper.CamelToSnakeCase) - case 
class Car( - model: String, - @Id id: UUID, - topSpeed: Int, - created: OffsetDateTime, - @SqlName("vin") vinNumber: Option[Int], - color: Color - ) derives DbCodec - - val carRepo = ImmutableRepo[Car, UUID] - val car = TableInfo[Car, Car, UUID] - - val allCars = Vector( - Car( - "McLaren Senna", - UUID.fromString("a88a32f1-1e4a-41b9-9fb0-e9a8aba2428a"), - 208, - OffsetDateTime.of(2023, 3, 5, 2, 26, 0, 0, ZoneOffset.UTC), - Some(123), - Color.Red - ), - Car( - "Ferrari F8 Tributo", - UUID.fromString("e4895170-5b54-4e3b-b857-b95d45d3550c"), - 212, - OffsetDateTime.of(2023, 3, 5, 2, 27, 0, 0, ZoneOffset.UTC), - Some(124), - Color.Green - ), - Car( - "Aston Martin Superleggera", - UUID.fromString("460798da-917d-442f-a987-a7e6528ddf17"), - 211, - OffsetDateTime.of(2023, 3, 5, 2, 28, 0, 0, ZoneOffset.UTC), - None, - Color.Blue - ) - ) - - test("count"): - connect(ds()): - assertEquals(carRepo.count, 3L) - - test("existsById"): - connect(ds()): - assert(carRepo.existsById(allCars.head.id)) - assert(!carRepo.existsById(UUID.randomUUID)) - - test("findAll"): - val cars = connect(ds()): - carRepo.findAll - assertEquals(cars, allCars) - - test("findAll spec"): - connect(ds()): - val topSpeed = 211 - val spec = Spec[Car] - .where(sql"${car.topSpeed} > $topSpeed") - assertEquals(carRepo.findAll(spec), Vector(allCars(1))) - - test("findById"): - connect(ds()): - assertEquals(carRepo.findById(allCars.last.id).get, allCars.last) - assertEquals(carRepo.findById(UUID.randomUUID), None) - - test("findAllByIds"): - intercept[UnsupportedOperationException]: - connect(ds()): - assertEquals( - carRepo.findAllById(Vector(allCars(0).id, allCars(1).id)).size, - 2 - ) - - test("repeatable read transaction"): - transact(ds(), withRepeatableRead): - assertEquals(carRepo.count, 3L) - - private def withRepeatableRead(con: Connection): Unit = - con.setTransactionIsolation(Connection.TRANSACTION_REPEATABLE_READ) - - test("select query"): - connect(ds()): - val minSpeed: Int = 210 - val query = - 
sql"select ${car.all} from $car where ${car.topSpeed} > $minSpeed" - .query[Car] - assertNoDiff( - query.frag.sqlString, - "select model, id, top_speed, created, vin, color from car where top_speed > ?" - ) - assertEquals(query.frag.params, Vector(minSpeed)) - assertEquals(query.run(), allCars.tail) - - test("select query with aliasing"): - connect(ds()): - val minSpeed = 210 - val cAlias = car.alias("c") - val query = - sql"select ${cAlias.all} from $cAlias where ${cAlias.topSpeed} > $minSpeed" - .query[Car] - assertNoDiff( - query.frag.sqlString, - "select c.model, c.id, c.top_speed, c.created, c.vin, c.color from car c where c.top_speed > ?" - ) - assertEquals(query.frag.params, Vector(minSpeed)) - assertEquals(query.run(), allCars.tail) - - test("select via option"): - connect(ds()): - val vin = Some(124) - val cars = - sql"select * from car where vin = $vin" - .query[Car] - .run() - assertEquals(cars, allCars.filter(_.vinNumber == vin)) - - test("tuple select"): - connect(ds()): - val tuples = sql"select model, color from car where id = ${allCars(1).id}" - .query[(String, Color)] - .run() - assertEquals(tuples, Vector(allCars(1).model -> allCars(1).color)) - - test("reads null int as None and not Some(0)"): - connect(ds()): - assertEquals(carRepo.findAll.last.vinNumber, None) - - /* - Repo Tests - */ - - @Table(ClickhouseDbType, SqlNameMapper.CamelToSnakeCase) - case class Person( - id: UUID, - firstName: Option[String], - lastName: String, - isAdmin: Boolean, - created: OffsetDateTime, - socialId: Option[UUID] - ) derives DbCodec - - val personRepo = Repo[Person, Person, UUID] - val person = TableInfo[Person, Person, UUID] - - test("delete"): - connect(ds()): - val p = personRepo.findAll.head - personRepo.delete(p) - assertEquals(personRepo.findById(p.id), None) - - test("delete invalid"): - connect(ds()): - val p = personRepo.findAll.head.copy(id = UUID.randomUUID) - personRepo.delete(p) - assertEquals(8L, personRepo.count) - - test("deleteById"): - 
connect(ds()): - val p = personRepo.findAll.head - personRepo.deleteById(p.id) - personRepo.deleteById(UUID.randomUUID) - assertEquals(personRepo.count, 7L) - - test("truncate"): - connect(ds()): - personRepo.truncate() - assertEquals(personRepo.count, 0L) - - test("insert"): - connect(ds()): - personRepo.insert( - Person( - id = UUID.randomUUID, - firstName = Some("John"), - lastName = "Smith", - isAdmin = false, - created = OffsetDateTime.now, - socialId = Some(UUID.randomUUID()) - ) - ) - personRepo.insert( - Person( - id = UUID.randomUUID, - firstName = None, - lastName = "Prince", - isAdmin = true, - created = OffsetDateTime.now, - socialId = None - ) - ) - assertEquals(personRepo.count, 10L) - assert(personRepo.findAll.exists(_.lastName == "Prince")) - - test("insertAll"): - connect(ds()): - personRepo.insertAll( - Vector( - Person( - id = UUID.randomUUID, - firstName = Some("John"), - lastName = "Smith", - isAdmin = false, - created = OffsetDateTime.now, - socialId = Some(UUID.randomUUID()) - ), - Person( - id = UUID.randomUUID, - firstName = None, - lastName = "Prince", - isAdmin = true, - created = OffsetDateTime.now, - socialId = None - ) - ) - ) - assertEquals(personRepo.count, 10L) - - test("insertReturning"): - connect(ds()): - val id = UUID.randomUUID - val person = personRepo.insertReturning( - Person( - id = id, - firstName = Some("John"), - lastName = "Smith", - isAdmin = false, - created = OffsetDateTime.now, - socialId = Some(UUID.randomUUID()) - ) - ) - assertEquals(personRepo.count, 9L) - assertEquals(personRepo.findById(id).get.firstName, person.firstName) - - test("insertAllReturning"): - connect(ds()): - val ps = Vector( - Person( - id = UUID.randomUUID, - firstName = Some("John"), - lastName = "Smith", - isAdmin = false, - created = OffsetDateTime.now, - socialId = Some(UUID.randomUUID()) - ), - Person( - id = UUID.randomUUID, - firstName = None, - lastName = "Prince", - isAdmin = true, - created = OffsetDateTime.now, - socialId = None - ) 
- ) - val people = personRepo.insertAllReturning(ps) - assertEquals(people, ps) - assertEquals(personRepo.count, 10L) - - test("insert invalid"): - intercept[SqlException]: - connect(ds()): - val invalidP = Person( - id = UUID.randomUUID, - firstName = None, - lastName = null, - isAdmin = false, - created = OffsetDateTime.now, - socialId = None - ) - personRepo.insert(invalidP) + sharedTests(this, ClickhouseDbType, xa) test("only allows EC =:= E"): intercept[IllegalArgumentException]: @@ -303,93 +21,6 @@ class ClickHouseTests extends FunSuite, TestContainersFixtures: case class User(id: UUID, name: String) derives DbCodec val repo = Repo[UserCreator, User, UUID] - test("update"): - intercept[UnsupportedOperationException]: - connect(ds()): - val p = personRepo.findAll.head - val updated = p.copy(firstName = None) - personRepo.update(updated) - assertEquals(personRepo.findById(p.id).get, updated) - - test("insertAll"): - connect(ds()): - val newPeople = Vector( - Person( - id = UUID.randomUUID, - firstName = Some("John"), - lastName = "Smith", - isAdmin = false, - created = OffsetDateTime.now, - socialId = Some(UUID.randomUUID()) - ), - Person( - id = UUID.randomUUID, - firstName = None, - lastName = "Prince", - isAdmin = true, - created = OffsetDateTime.now, - socialId = None - ) - ) - personRepo.insertAll(newPeople) - assertEquals(personRepo.count, 10L) - assert(personRepo.findAll.exists(_.lastName == "Smith")) - - test("updateAll"): - intercept[UnsupportedOperationException]: - connect(ds()): - val allPeople = personRepo.findAll - val newPeople = Vector( - allPeople(0).copy(lastName = "Peterson"), - allPeople(1).copy(lastName = "Moreno") - ) - personRepo.updateAll(newPeople) - assertEquals(personRepo.findById(allPeople(0).id).get, newPeople(0)) - assertEquals(personRepo.findById(allPeople(1).id).get, newPeople(1)) - - test("transact"): - val count = transact(ds())(personRepo.count) - assertEquals(count, 8L) - - test("custom insert"): - connect(ds()): - val p = 
Person( - id = UUID.randomUUID, - firstName = Some("John"), - lastName = "Smith", - isAdmin = false, - created = OffsetDateTime.now, - socialId = Some(UUID.randomUUID()) - ) - val update = - sql"insert into $person ${person.insertColumns} values ($p)".update - assertNoDiff( - update.frag.sqlString, - "insert into person (id, first_name, last_name, is_admin, created, social_id) values (?, ?, ?, ?, ?, ?)" - ) - val rowsInserted = update.run() - assertEquals(rowsInserted, 1) - assertEquals(personRepo.count, 9L) - val fetched = personRepo.findById(p.id).get - assertEquals(fetched.firstName, p.firstName) - assertEquals(fetched.lastName, p.lastName) - assertEquals(fetched.isAdmin, p.isAdmin) - - test("custom update"): - connect(ds()): - val p = personRepo.findAll.head - val newIsAdmin = true - val update = - sql"update $person set ${person.isAdmin} = $newIsAdmin where ${person.id} = ${p.id}".update - - assertNoDiff( - update.frag.sqlString, - "update person set is_admin = ? where id = ?" - ) - val rowsUpdated = update.run() - assertEquals(rowsUpdated, 1) - assertEquals(personRepo.findById(p.id).get.isAdmin, true) - val clickHouseContainer = ForAllContainerFixture( ClickHouseContainer .Def(dockerImageName = @@ -398,38 +29,24 @@ class ClickHouseTests extends FunSuite, TestContainersFixtures: .createContainer() ) - test("embed Frag into Frag"): - def findPersonCnt(filter: Frag, limit: Long = 1)(using DbCon): Int = - val offsetFrag = sql"OFFSET 0" - val limitFrag = sql"LIMIT $limit" - sql"SELECT count(*) FROM person WHERE $filter $limitFrag $offsetFrag" - .query[Int] - .run() - .head - val isAdminFrag = sql"is_admin = true" - connect(ds()): - val johnCnt = findPersonCnt(sql"$isAdminFrag AND first_name = 'John'", 2) - assertEquals(johnCnt, 2) - override def munitFixtures: Seq[AnyFixture[_]] = super.munitFixtures :+ clickHouseContainer - def ds(): DataSource = + def xa(): Transactor = val clickHouse = clickHouseContainer() val ds = ClickHouseDataSource(clickHouse.jdbcUrl) - 
val carSql = Files.readString( - Path.of(getClass.getResource("/clickhouse-car.sql").toURI) - ) - val personSql = Files.readString( - Path.of(getClass.getResource("/clickhouse-person.sql").toURI) - ) + val tableDDLs = Vector( + "clickhouse/car.sql", + "clickhouse/no-id.sql", + "clickhouse/person.sql", + "clickhouse/big-dec.sql" + ).map(p => Files.readString(Path.of(getClass.getResource(p).toURI))) Using .Manager(use => val con = use(ds.getConnection) val stmt = use(con.createStatement) - stmt.execute(carSql) - stmt.execute(personSql) + for ddl <- tableDDLs do stmt.execute(ddl) ) .get - ds + Transactor(ds) end ClickHouseTests diff --git a/magnum/src/test/scala/H2Tests.scala b/magnum/src/test/scala/H2Tests.scala index abe77dd..cda514b 100644 --- a/magnum/src/test/scala/H2Tests.scala +++ b/magnum/src/test/scala/H2Tests.scala @@ -1,414 +1,35 @@ import com.augustnagro.magnum.* import munit.FunSuite import org.h2.jdbcx.JdbcDataSource +import shared.* import java.nio.file.{Files, Path} -import java.sql.{Connection, DriverManager} -import java.time.OffsetDateTime -import java.util.UUID -import javax.sql.DataSource -import scala.util.Properties.propOrNone import scala.util.Using import scala.util.Using.Manager class H2Tests extends FunSuite: - enum Color derives DbCodec: - case Red, Green, Blue - - @Table(H2DbType, SqlNameMapper.CamelToSnakeCase) - case class Car( - model: String, - @Id id: Long, - topSpeed: Int, - @SqlName("vin") vinNumber: Option[Int], - color: Color - ) derives DbCodec - - val carRepo = ImmutableRepo[Car, Long] - val car = TableInfo[Car, Car, Long] - - val allCars = Vector( - Car("McLaren Senna", 1L, 208, Some(123), Color.Red), - Car("Ferrari F8 Tributo", 2L, 212, Some(124), Color.Green), - Car("Aston Martin Superleggera", 3L, 211, None, Color.Blue) - ) - - test("count"): - val count = connect(ds()): - carRepo.count - assertEquals(count, 3L) - - test("existsById"): - connect(ds()): - assert(carRepo.existsById(3L)) - assert(!carRepo.existsById(4L)) - - 
test("findAll"): - connect(ds()): - assertEquals(carRepo.findAll, allCars) - - test("findAll spec"): - connect(ds()): - val topSpeed = 211 - val spec = Spec[Car] - .where(sql"${car.topSpeed} > $topSpeed") - assertEquals(carRepo.findAll(spec), Vector(allCars(1))) - - test("findById"): - connect(ds()): - assertEquals(carRepo.findById(3L).get, allCars.last) - assertEquals(carRepo.findById(4L), None) - - test("findAllByIds"): - connect(ds()): - assertEquals( - carRepo.findAllById(Vector(1L, 3L)).map(_.id), - Vector(1L, 3L) - ) - - test("repeatable read transaction"): - transact(ds(), withRepeatableRead): - assertEquals(carRepo.count, 3L) - - private def withRepeatableRead(con: Connection): Unit = - con.setTransactionIsolation(Connection.TRANSACTION_REPEATABLE_READ) - - test("select query"): - connect(ds()): - val minSpeed: Int = 210 - val query = - sql"select ${car.all} from $car where ${car.topSpeed} > $minSpeed" - .query[Car] - assertNoDiff( - query.frag.sqlString, - "select model, id, top_speed, vin, color from car where top_speed > ?" - ) - assertEquals(query.frag.params, Vector(minSpeed)) - assertEquals(query.run(), allCars.tail) - - test("select query with aliasing"): - connect(ds()): - val minSpeed = 210 - val cAlias = car.alias("c") - val query = - sql"select ${cAlias.all} from $cAlias where ${cAlias.topSpeed} > $minSpeed" - .query[Car] - assertNoDiff( - query.frag.sqlString, - "select c.model, c.id, c.top_speed, c.vin, c.color from car c where c.top_speed > ?" 
- ) - assertEquals(query.frag.params, Vector(minSpeed)) - assertEquals(query.run(), allCars.tail) - - test("select via option"): - connect(ds()): - val vin = Some(124) - val cars = - sql"select * from car where vin = $vin" - .query[Car] - .run() - assertEquals(cars, allCars.filter(_.vinNumber == vin)) - - test("tuple select"): - connect(ds()): - val tuples = sql"select model, color from car where id = 2" - .query[(String, Color)] - .run() - assertEquals(tuples, Vector(allCars(1).model -> allCars(1).color)) - - test("reads null int as None and not Some(0)"): - connect(ds()): - assertEquals(carRepo.findById(3L).get.vinNumber, None) - - case class PersonCreator( - firstName: Option[String], - lastName: String, - isAdmin: Boolean, - socialId: Option[UUID] - ) derives DbCodec - - @Table(H2DbType, SqlNameMapper.CamelToSnakeCase) - case class Person( - id: Long, - firstName: Option[String], - lastName: String, - isAdmin: Boolean, - created: OffsetDateTime, - socialId: Option[UUID] - ) derives DbCodec - - val personRepo = Repo[PersonCreator, Person, Long] - val person = TableInfo[PersonCreator, Person, Long] - - test("delete"): - connect(ds()): - val p = personRepo.findById(1L).get - personRepo.delete(p) - assertEquals(personRepo.findById(1L), None) - - test("delete invalid"): - connect(ds()): - personRepo.delete(Person(23L, None, "", false, OffsetDateTime.now, None)) - assertEquals(8L, personRepo.count) - - test("deleteById"): - connect(ds()): - personRepo.deleteById(1L) - personRepo.deleteById(2L) - personRepo.deleteById(1L) - assertEquals(personRepo.findAll.size, 6) - - test("deleteAll"): - connect(ds()): - val p1 = personRepo.findById(1L).get - val p2 = p1.copy(id = 2L) - val p3 = p1.copy(id = 99L) - assertEquals( - personRepo.deleteAll(Vector(p1, p2, p3)), - BatchUpdateResult.Success(2) - ) - assertEquals(6L, personRepo.count) - - test("deleteAllById"): - connect(ds()): - assertEquals( - personRepo.deleteAllById(Vector(1L, 2L, 1L)), - BatchUpdateResult.Success(2) - ) 
- assertEquals(6L, personRepo.count) - - test("truncate"): - connect(ds()): - personRepo.truncate() - assertEquals(personRepo.count, 0L) - - test("insert"): - connect(ds()): - personRepo.insert( - PersonCreator( - firstName = Some("John"), - lastName = "Smith", - isAdmin = false, - socialId = Some(UUID.randomUUID()) - ) - ) - personRepo.insert( - PersonCreator( - firstName = None, - lastName = "Prince", - isAdmin = true, - socialId = None - ) - ) - assertEquals(personRepo.count, 10L) - assertEquals(personRepo.findById(9L).get.lastName, "Smith") - - test("insertReturning"): - connect(ds()): - val person = personRepo.insertReturning( - PersonCreator( - firstName = Some("John"), - lastName = "Smith", - isAdmin = false, - socialId = Some(UUID.randomUUID()) - ) - ) - assertEquals(person.id, 9L) - assertEquals(person.lastName, "Smith") - - test("insertAllReturning"): - connect(ds()): - val newPc = Vector( - PersonCreator( - firstName = Some("Chandler"), - lastName = "Johnsored", - isAdmin = true, - socialId = Some(UUID.randomUUID()) - ), - PersonCreator( - firstName = None, - lastName = "Odysseus", - isAdmin = false, - socialId = None - ), - PersonCreator( - firstName = Some("Jorge"), - lastName = "Masvidal", - isAdmin = true, - socialId = None - ) - ) - val people = personRepo.insertAllReturning(newPc) - assertEquals(personRepo.count, 11L) - assertEquals(people.size, 3) - assertEquals(people.last.lastName, newPc.last.lastName) - - test("insert invalid"): - intercept[SqlException]: - connect(ds()): - val invalidP = PersonCreator(None, null, false, None) - personRepo.insert(invalidP) - - test("update"): - connect(ds()): - val p = personRepo.findById(1L).get - val updated = p.copy(firstName = None, isAdmin = false) - personRepo.update(updated) - assertEquals(personRepo.findById(1L).get, updated) - - test("update invalid"): - intercept[SqlException]: - connect(ds()): - val p = personRepo.findById(1L).get - val updated = p.copy(lastName = null) - personRepo.update(updated) - 
- test("insertAll"): - connect(ds()): - val newPeople = Vector( - PersonCreator( - firstName = Some("Chandler"), - lastName = "Johnsored", - isAdmin = true, - socialId = Some(UUID.randomUUID()) - ), - PersonCreator( - firstName = None, - lastName = "Odysseus", - isAdmin = false, - socialId = None - ), - PersonCreator( - firstName = Some("Jorge"), - lastName = "Masvidal", - isAdmin = true, - socialId = None - ) - ) - personRepo.insertAll(newPeople) - assertEquals(personRepo.count, 11L) - assertEquals( - personRepo.findById(11L).get.lastName, - newPeople.last.lastName - ) - - test("updateAll"): - connect(ds()): - val newPeople = Vector( - personRepo.findById(1L).get.copy(lastName = "Peterson"), - personRepo.findById(2L).get.copy(lastName = "Moreno") - ) - assertEquals( - personRepo.updateAll(newPeople), - BatchUpdateResult.Success(2) - ) - assertEquals(personRepo.findById(1L).get, newPeople(0)) - assertEquals(personRepo.findById(2L).get, newPeople(1)) - - test("transact"): - val count = transact(ds()): - val p = PersonCreator( - firstName = Some("Chandler"), - lastName = "Brown", - isAdmin = false, - socialId = Some(UUID.randomUUID()) - ) - personRepo.insert(p) - personRepo.count - assertEquals(count, 9L) - - test("transact failed"): - val dataSource = ds() - val p = PersonCreator( - firstName = Some("Chandler"), - lastName = "Brown", - isAdmin = false, - socialId = None - ) - try - transact(dataSource): - personRepo.insert(p) - throw RuntimeException() - fail("should not reach") - catch - case _: Exception => - transact(dataSource): - assertEquals(personRepo.count, 8L) - - test("custom insert"): - connect(ds()): - val p = PersonCreator( - firstName = Some("Chandler"), - lastName = "Brown", - isAdmin = false, - socialId = None - ) - val update = - sql"insert into $person ${person.insertColumns} values ($p)".update - assertNoDiff( - update.frag.sqlString, - "insert into person (first_name, last_name, is_admin, social_id) values (?, ?, ?, ?)" - ) - val rowsInserted = 
update.run() - assertEquals(rowsInserted, 1) - assertEquals(personRepo.count, 9L) - val fetched = personRepo.findAll.last - assertEquals(fetched.firstName, p.firstName) - assertEquals(fetched.lastName, p.lastName) - assertEquals(fetched.isAdmin, p.isAdmin) - - test("custom update"): - connect(ds()): - val p = personRepo.insertReturning( - PersonCreator( - firstName = Some("Chandler"), - lastName = "Brown", - isAdmin = false, - socialId = Some(UUID.randomUUID()) - ) - ) - val newIsAdmin = true - val update = - sql"update $person set ${person.isAdmin} = $newIsAdmin where ${person.id} = ${p.id}".update - assertNoDiff( - update.frag.sqlString, - "update person set is_admin = ? where id = ?" - ) - val rowsUpdated = update.run() - assertEquals(rowsUpdated, 1) - assertEquals(personRepo.findById(p.id).get.isAdmin, true) - - test("embed Frag into Frag"): - def findPersonCnt(filter: Frag, limit: Long = 1)(using DbCon): Int = - val offsetFrag = sql"OFFSET 0" - val limitFrag = sql"LIMIT $limit" - sql"SELECT count(*) FROM person WHERE $filter $limitFrag $offsetFrag" - .query[Int] - .run() - .head - val isAdminFrag = sql"is_admin = true" - connect(ds()): - val johnCnt = findPersonCnt(sql"$isAdminFrag AND first_name = 'John'", 2) - assertEquals(johnCnt, 2) + sharedTests(this, H2DbType, xa) lazy val h2DbPath = Files.createTempDirectory(null).toAbsolutePath - def ds(): DataSource = + def xa(): Transactor = val ds = JdbcDataSource() ds.setURL("jdbc:h2:" + h2DbPath) ds.setUser("sa") ds.setPassword("") - val carSql = - Files.readString(Path.of(getClass.getResource("/h2-car.sql").toURI)) - val personSql = - Files.readString(Path.of(getClass.getResource("/h2-person.sql").toURI)) + val tableDDLs = Vector( + "/h2/car.sql", + "/h2/person.sql", + "/h2/my-user.sql", + "/h2/no-id.sql", + "/h2/big-dec.sql" + ).map(p => Files.readString(Path.of(getClass.getResource(p).toURI))) Manager(use => val con = use(ds.getConnection) val stmt = use(con.createStatement) - stmt.execute(carSql) - 
stmt.execute(personSql) + for ddl <- tableDDLs do stmt.execute(ddl) ) - ds + Transactor(ds) end H2Tests diff --git a/magnum/src/test/scala/MySqlTests.scala b/magnum/src/test/scala/MySqlTests.scala index ab06e84..676c6f6 100644 --- a/magnum/src/test/scala/MySqlTests.scala +++ b/magnum/src/test/scala/MySqlTests.scala @@ -1,417 +1,19 @@ import com.augustnagro.magnum.* +import com.augustnagro.magnum.UUIDCodec.VarCharUUIDCodec +import com.dimafeng.testcontainers.MySQLContainer import com.dimafeng.testcontainers.munit.fixtures.TestContainersFixtures -import com.dimafeng.testcontainers.{ - ContainerDef, - JdbcDatabaseContainer, - MySQLContainer -} import com.mysql.cj.jdbc.MysqlDataSource -import munit.{AnyFixture, FunSuite, Location, TestOptions} +import munit.{AnyFixture, FunSuite, Location} import org.testcontainers.utility.DockerImageName +import shared.* import java.nio.file.{Files, Path} -import java.sql.Connection -import java.time.OffsetDateTime -import java.util.UUID -import javax.sql.DataSource import scala.util.Using import scala.util.Using.Manager -import com.augustnagro.magnum.UUIDCodec.VarCharUUIDCodec class MySqlTests extends FunSuite, TestContainersFixtures: - enum Color derives DbCodec: - case Red, Green, Blue - - @Table(MySqlDbType, SqlNameMapper.CamelToSnakeCase) - case class Car( - model: String, - @Id id: Long, - topSpeed: Int, - @SqlName("vin") vinNumber: Option[Int], - color: Color - ) derives DbCodec - - val carRepo = ImmutableRepo[Car, Long] - val car = TableInfo[Car, Car, Long] - - val allCars = Vector( - Car("McLaren Senna", 1L, 208, Some(123), Color.Red), - Car("Ferrari F8 Tributo", 2L, 212, Some(124), Color.Green), - Car("Aston Martin Superleggera", 3L, 211, None, Color.Blue) - ) - - test("count"): - connect(ds()): - assertEquals(carRepo.count, 3L) - - test("existsById"): - connect(ds()): - assert(carRepo.existsById(3L)) - assert(!carRepo.existsById(4L)) - - test("findAll"): - connect(ds()): - assertEquals(carRepo.findAll, allCars) - - 
test("findAll spec"): - connect(ds()): - val topSpeed = 211 - val spec = Spec[Car] - .where(sql"${car.topSpeed} > $topSpeed") - assertEquals(carRepo.findAll(spec), Vector(allCars(1))) - - test("findAll spec with ordering"): - connect(ds()): - val topSpeed = 211 - val spec = Spec[Car] - .where(sql"${car.topSpeed} > $topSpeed") - .orderBy(s"id", SortOrder.Asc) - assertEquals(carRepo.findAll(spec), Vector(allCars(1))) - - test("findById"): - connect(ds()): - assertEquals(carRepo.findById(3L).get, allCars.last) - assertEquals(carRepo.findById(4L), None) - - test("findAllByIds"): - intercept[UnsupportedOperationException]: - connect(ds()): - assertEquals( - carRepo.findAllById(Vector(1L, 3L)).map(_.id), - Vector(1L, 3L) - ) - - test("repeatable read transaction"): - transact(ds(), withRepeatableRead): - assertEquals(carRepo.count, 3L) - - private def withRepeatableRead(con: Connection): Unit = - con.setTransactionIsolation(Connection.TRANSACTION_REPEATABLE_READ) - - test("select query"): - connect(ds()): - val minSpeed: Int = 210 - val query = - sql"select ${car.all} from $car where ${car.topSpeed} > $minSpeed" - .query[Car] - assertNoDiff( - query.frag.sqlString, - "select model, id, top_speed, vin, color from car where top_speed > ?" - ) - assertEquals(query.frag.params, Vector(minSpeed)) - assertEquals(query.run(), allCars.tail) - - test("select query with aliasing"): - connect(ds()): - val minSpeed = 210 - val cAlias = car.alias("c") - val query = - sql"select ${cAlias.all} from $cAlias where ${cAlias.topSpeed} > $minSpeed" - .query[Car] - assertNoDiff( - query.frag.sqlString, - "select c.model, c.id, c.top_speed, c.vin, c.color from car c where c.top_speed > ?" 
- ) - assertEquals(query.frag.params, Vector(minSpeed)) - assertEquals(query.run(), allCars.tail) - - test("select via option"): - connect(ds()): - val vin = Some(124) - val cars = - sql"select * from car where vin = $vin" - .query[Car] - .run() - assertEquals(cars, allCars.filter(_.vinNumber == vin)) - - test("tuple select"): - connect(ds()): - val tuples = sql"select model, color from car where id = 2" - .query[(String, Color)] - .run() - assertEquals(tuples, Vector(allCars(1).model -> allCars(1).color)) - - test("reads null int as None and not Some(0)"): - connect(ds()): - assertEquals(carRepo.findById(3L).get.vinNumber, None) - - /* - Repo Tests - */ - case class PersonCreator( - firstName: Option[String], - lastName: String, - isAdmin: Boolean, - socialId: Option[UUID] - ) derives DbCodec - - @Table(MySqlDbType, SqlNameMapper.CamelToSnakeCase) - case class Person( - id: Long, - firstName: Option[String], - lastName: String, - isAdmin: Boolean, - created: OffsetDateTime, - socialId: Option[UUID] - ) derives DbCodec - - val personRepo = Repo[PersonCreator, Person, Long] - val person = TableInfo[PersonCreator, Person, Long] - - test("delete"): - connect(ds()): - val p = personRepo.findById(1L).get - personRepo.delete(p) - assertEquals(personRepo.findById(1L), None) - - test("delete invalid"): - connect(ds()): - personRepo.delete(Person(23L, None, "", false, OffsetDateTime.now, None)) - assertEquals(8L, personRepo.count) - - test("deleteById"): - connect(ds()): - personRepo.deleteById(1L) - personRepo.deleteById(2L) - personRepo.deleteById(1L) - assertEquals(personRepo.findAll.size, 6) - - test("deleteAll"): - connect(ds()): - val p1 = personRepo.findById(1L).get - val p2 = p1.copy(id = 2L) - val p3 = p1.copy(id = 99L) - assertEquals( - personRepo.deleteAll(Vector(p1, p2, p3)), - BatchUpdateResult.Success(2) - ) - assertEquals(6L, personRepo.count) - - test("deleteAllById"): - connect(ds()): - assertEquals( - personRepo.deleteAllById(Vector(1L, 2L, 1L)), - 
BatchUpdateResult.Success(2) - ) - assertEquals(6L, personRepo.count) - - test("truncate"): - connect(ds()): - personRepo.truncate() - assertEquals(personRepo.count, 0L) - - test("insert"): - connect(ds()): - personRepo.insert( - PersonCreator( - firstName = Some("John"), - lastName = "Smith", - isAdmin = false, - socialId = Some(UUID.randomUUID()) - ) - ) - personRepo.insert( - PersonCreator( - firstName = None, - lastName = "Prince", - isAdmin = true, - socialId = None - ) - ) - - assertEquals(personRepo.count, 10L) - assertEquals(personRepo.findById(9L).get.lastName, "Smith") - - test("insertReturning"): - connect(ds()): - val person = personRepo.insertReturning( - PersonCreator( - firstName = Some("John"), - lastName = "Smith", - isAdmin = false, - socialId = Some(UUID.randomUUID()) - ) - ) - assertEquals(person.id, 9L) - assertEquals(person.lastName, "Smith") - - test("insertAllReturning"): - connect(ds()): - val newPc = Vector( - PersonCreator( - firstName = Some("Chandler"), - lastName = "Johnsored", - isAdmin = true, - socialId = Some(UUID.randomUUID()) - ), - PersonCreator( - firstName = None, - lastName = "Odysseus", - isAdmin = false, - socialId = None - ), - PersonCreator( - firstName = Some("Jorge"), - lastName = "Masvidal", - isAdmin = true, - socialId = None - ) - ) - val people = personRepo.insertAllReturning(newPc) - assertEquals(personRepo.count, 11L) - assertEquals(people.size, 3) - assertEquals(people.last.lastName, newPc.last.lastName) - - test("insert invalid"): - intercept[SqlException]: - connect(ds()): - val invalidP = PersonCreator(None, null, false, None) - personRepo.insert(invalidP) - - test("update"): - connect(ds()): - val p = personRepo.findById(1L).get - val updated = p.copy(firstName = None, isAdmin = false) - personRepo.update(updated) - assertEquals(personRepo.findById(1L).get, updated) - - test("update invalid"): - intercept[SqlException]: - connect(ds()): - val p = personRepo.findById(1L).get - val updated = p.copy(lastName = 
null) - personRepo.update(updated) - - test("insertAll"): - connect(ds()): - val newPeople = Vector( - PersonCreator( - firstName = Some("Chandler"), - lastName = "Johnsored", - isAdmin = true, - socialId = Some(UUID.randomUUID()) - ), - PersonCreator( - firstName = None, - lastName = "Odysseus", - isAdmin = false, - socialId = None - ), - PersonCreator( - firstName = Some("Jorge"), - lastName = "Masvidal", - isAdmin = true, - socialId = None - ) - ) - - personRepo.insertAll(newPeople) - assertEquals(personRepo.count, 11L) - assertEquals( - personRepo.findById(11L).get.lastName, - newPeople.last.lastName - ) - - test("updateAll"): - connect(ds()): - val newPeople = Vector( - personRepo.findById(1L).get.copy(lastName = "Peterson"), - personRepo.findById(2L).get.copy(lastName = "Moreno") - ) - assertEquals( - personRepo.updateAll(newPeople), - BatchUpdateResult.Success(2) - ) - assertEquals(personRepo.findById(1L).get, newPeople(0)) - assertEquals(personRepo.findById(2L).get, newPeople(1)) - - test("transact"): - val count = transact(ds()): - val p = PersonCreator( - firstName = Some("Chandler"), - lastName = "Brown", - isAdmin = false, - socialId = Some(UUID.randomUUID()) - ) - personRepo.insert(p) - personRepo.count - assertEquals(count, 9L) - - test("transact failed"): - val dataSource = ds() - val p = PersonCreator( - firstName = Some("Chandler"), - lastName = "Brown", - isAdmin = false, - socialId = Some(UUID.randomUUID()) - ) - try - transact(dataSource): - personRepo.insert(p) - throw RuntimeException() - fail("should not reach") - catch - case _: Exception => - transact(dataSource): - assertEquals(personRepo.count, 8L) - - test("custom insert"): - connect(ds()): - val p = PersonCreator( - firstName = Some("Chandler"), - lastName = "Brown", - isAdmin = false, - socialId = Some(UUID.randomUUID()) - ) - val update = - sql"insert into $person ${person.insertColumns} values ($p)".update - assertNoDiff( - update.frag.sqlString, - "insert into person (first_name, 
last_name, is_admin, social_id) values (?, ?, ?, ?)" - ) - val rowsInserted = update.run() - assertEquals(rowsInserted, 1) - assertEquals(personRepo.count, 9L) - val fetched = personRepo.findAll.last - assertEquals(fetched.firstName, p.firstName) - assertEquals(fetched.lastName, p.lastName) - assertEquals(fetched.isAdmin, p.isAdmin) - - test("custom update"): - connect(ds()): - val p = personRepo.insertReturning( - PersonCreator( - firstName = Some("Chandler"), - lastName = "Brown", - isAdmin = false, - socialId = Some(UUID.randomUUID()) - ) - ) - val newIsAdmin = true - val update = - sql"update $person set ${person.isAdmin} = $newIsAdmin where ${person.id} = ${p.id}".update - - assertNoDiff( - update.frag.sqlString, - "update person set is_admin = ? where id = ?" - ) - val rowsUpdated = update.run() - assertEquals(rowsUpdated, 1) - assertEquals(personRepo.findById(p.id).get.isAdmin, true) - - test("embed Frag into Frag"): - def findPersonCnt(filter: Frag, limit: Long = 1)(using DbCon): Int = - val offsetFrag = sql"OFFSET 0" - val limitFrag = sql"LIMIT $limit" - sql"SELECT count(*) FROM person WHERE $filter $limitFrag $offsetFrag" - .query[Int] - .run() - .head - val isAdminFrag = sql"is_admin = true" - connect(ds()): - val johnCnt = findPersonCnt(sql"$isAdminFrag AND first_name = 'John'", 2) - assertEquals(johnCnt, 2) + sharedTests(this, MySqlDbType, xa) val mySqlContainer = ForAllContainerFixture( MySQLContainer @@ -422,21 +24,26 @@ class MySqlTests extends FunSuite, TestContainersFixtures: override def munitFixtures: Seq[AnyFixture[_]] = super.munitFixtures :+ mySqlContainer - def ds(): DataSource = + def xa(): Transactor = val mySql = mySqlContainer() val ds = MysqlDataSource() ds.setURL(mySql.jdbcUrl) ds.setUser(mySql.username) ds.setPassword(mySql.password) ds.setAllowMultiQueries(true) - val carSql = - Files.readString(Path.of(getClass.getResource("/mysql-car.sql").toURI)) - val personSql = - 
Files.readString(Path.of(getClass.getResource("/mysql-person.sql").toURI)) + ds.setServerTimezone("UTC") + val tableDDLs = Vector( + "/mysql/car.sql", + "/mysql/person.sql", + "/mysql/my-user.sql", + "/mysql/no-id.sql", + "/mysql/big-dec.sql" + ).map(p => Files.readString(Path.of(getClass.getResource(p).toURI))) Manager(use => val con = use(ds.getConnection) - use(con.prepareStatement(carSql)).execute() - use(con.prepareStatement(personSql)).execute() + val stmt = use(con.createStatement()) + for ddl <- tableDDLs do stmt.execute(ddl) ).get - ds + Transactor(ds) + end xa end MySqlTests diff --git a/magnum/src/test/scala/OracleTests.scala b/magnum/src/test/scala/OracleTests.scala index 3215667..ccd71b2 100644 --- a/magnum/src/test/scala/OracleTests.scala +++ b/magnum/src/test/scala/OracleTests.scala @@ -3,403 +3,19 @@ import com.augustnagro.magnum.UUIDCodec.VarCharUUIDCodec import com.dimafeng.testcontainers.OracleContainer import com.dimafeng.testcontainers.munit.fixtures.TestContainersFixtures import munit.{AnyFixture, FunSuite} -import org.testcontainers.utility.DockerImageName import oracle.jdbc.datasource.impl.OracleDataSource +import org.testcontainers.utility.DockerImageName +import shared.* -import java.nio.file.{Files, Path} -import java.sql.Connection -import java.time.OffsetDateTime -import java.util.UUID -import javax.sql.DataSource +import java.sql.Statement import scala.util.Using class OracleTests extends FunSuite, TestContainersFixtures: - /* - Immutable Repo Tests - */ - - enum Color derives DbCodec: - case Red, Green, Blue - - @Table(OracleDbType, SqlNameMapper.CamelToSnakeCase) - case class Car( - model: String, - @Id id: Long, - topSpeed: Int, - @SqlName("vin") vinNumber: Option[Int], - color: Color - ) derives DbCodec - - val carRepo = ImmutableRepo[Car, Long] - val car = TableInfo[Car, Car, Long] - - val allCars = Vector( - Car("McLaren Senna", 1L, 208, Some(123), Color.Red), - Car("Ferrari F8 Tributo", 2L, 212, Some(124), Color.Green), - 
Car("Aston Martin Superleggera", 3L, 211, None, Color.Blue) - ) - - test("count"): - connect(ds()): - assertEquals(carRepo.count, 3L) - - test("existsById"): - connect(ds()): - assert(carRepo.existsById(3L)) - assert(!carRepo.existsById(4L)) - - test("findAll"): - val cars = connect(ds()): - carRepo.findAll - assertEquals(cars, allCars) - - test("findAll spec"): - connect(ds()): - val topSpeed = 211 - val spec = Spec[Car] - .where(sql"${car.topSpeed} > $topSpeed") - assertEquals(carRepo.findAll(spec), Vector(allCars(1))) - - test("findById"): - connect(ds()): - assertEquals(carRepo.findById(3L).get, allCars.last) - assertEquals(carRepo.findById(4L), None) - - test("findAllByIds"): - intercept[UnsupportedOperationException]: - connect(ds()): - assertEquals( - carRepo.findAllById(Vector(1L, 3L)).map(_.id), - Vector(1L, 3L) - ) - - test("serializable transaction"): - transact(ds(), withSerializable): - assertEquals(carRepo.count, 3L) - - private def withSerializable(con: Connection): Unit = - con.setTransactionIsolation(Connection.TRANSACTION_SERIALIZABLE) - - test("select query"): - connect(ds()): - val minSpeed: Int = 210 - val query = - sql"select ${car.all} from $car where ${car.topSpeed} > $minSpeed" - .query[Car] - assertNoDiff( - query.frag.sqlString, - "select model, id, top_speed, vin, color from car where top_speed > ?" - ) - assertEquals(query.frag.params, Vector(minSpeed)) - assertEquals(query.run(), allCars.tail) - - test("select query with aliasing"): - connect(ds()): - val minSpeed = 210 - val cAlias = car.alias("c") - val query = - sql"select ${cAlias.all} from $cAlias where ${cAlias.topSpeed} > $minSpeed" - .query[Car] - assertNoDiff( - query.frag.sqlString, - "select c.model, c.id, c.top_speed, c.vin, c.color from car c where c.top_speed > ?" 
- ) - assertEquals(query.frag.params, Vector(minSpeed)) - assertEquals(query.run(), allCars.tail) - - test("select via option"): - connect(ds()): - val vin = Some(124) - val cars = - sql"select * from car where vin = $vin" - .query[Car] - .run() - assertEquals(cars, allCars.filter(_.vinNumber == vin)) + given DbCodec[Boolean] = + DbCodec[String].biMap(_ == "Y", b => if b then "Y" else "N") - test("tuple select"): - connect(ds()): - val tuples = sql"select model, color from car where id = 2" - .query[(String, Color)] - .run() - assertEquals(tuples, Vector(allCars(1).model -> allCars(1).color)) - - test("reads null int as None and not Some(0)"): - connect(ds()): - assertEquals(carRepo.findById(3L).get.vinNumber, None) - - /* - Repo Tests - */ - case class PersonCreator( - firstName: Option[String], - lastName: String, - isAdmin: String, - socialId: Option[UUID] - ) derives DbCodec - - @Table(OracleDbType, SqlNameMapper.CamelToSnakeCase) - case class Person( - id: Long, - firstName: Option[String], - lastName: String, - isAdmin: String, - created: OffsetDateTime, - socialId: Option[UUID] - ) derives DbCodec - - val personRepo = Repo[PersonCreator, Person, Long] - val person = TableInfo[PersonCreator, Person, Long] - - test("delete"): - connect(ds()): - val p = personRepo.findById(1L).get - personRepo.delete(p) - assertEquals(personRepo.findById(1L), None) - - test("delete invalid"): - connect(ds()): - personRepo.delete(Person(23L, None, "", "N", OffsetDateTime.now, None)) - assertEquals(8L, personRepo.count) - - test("deleteById"): - connect(ds()): - personRepo.deleteById(1L) - personRepo.deleteById(2L) - personRepo.deleteById(1L) - assertEquals(personRepo.findAll.size, 6) - - test("deleteAll"): - connect(ds()): - val p1 = personRepo.findById(1L).get - val p2 = p1.copy(id = 2L) - val p3 = p1.copy(id = 99L) - assertEquals( - personRepo.deleteAll(Vector(p1, p2, p3)), - BatchUpdateResult.Success(2) - ) - assertEquals(6L, personRepo.count) - - test("deleteAllById"): - 
connect(ds()): - assertEquals( - personRepo.deleteAllById(Vector(1L, 2L, 1L)), - BatchUpdateResult.Success(2) - ) - assertEquals(6L, personRepo.count) - - test("truncate"): - connect(ds()): - personRepo.truncate() - assertEquals(personRepo.count, 0L) - - test("insert"): - connect(ds()): - personRepo.insert( - PersonCreator( - firstName = Some("John"), - lastName = "Smith", - isAdmin = "N", - socialId = Some(UUID.randomUUID()) - ) - ) - personRepo.insert( - PersonCreator( - firstName = None, - lastName = "Prince", - isAdmin = "Y", - socialId = None - ) - ) - assertEquals(personRepo.count, 10L) - assertEquals(personRepo.findById(9L).get.lastName, "Smith") - - test("insertReturning"): - connect(ds()): - val person = personRepo.insertReturning( - PersonCreator( - firstName = Some("John"), - lastName = "Smith", - isAdmin = "N", - socialId = Some(UUID.randomUUID()) - ) - ) - assertEquals(person.id, 9L) - assertEquals(person.lastName, "Smith") - - test("insertAllReturning"): - connect(ds()): - val newPc = Vector( - PersonCreator( - firstName = Some("Chandler"), - lastName = "Johnsored", - isAdmin = "Y", - socialId = Some(UUID.randomUUID()) - ), - PersonCreator( - firstName = None, - lastName = "Odysseus", - isAdmin = "N", - socialId = None - ), - PersonCreator( - firstName = Some("Jorge"), - lastName = "Masvidal", - isAdmin = "Y", - socialId = None - ) - ) - val people = personRepo.insertAllReturning(newPc) - assertEquals(personRepo.count, 11L) - assertEquals(people.size, 3) - assertEquals(people.last.lastName, newPc.last.lastName) - - test("insert invalid"): - intercept[SqlException]: - connect(ds()): - val invalidP = PersonCreator(None, null, "N", None) - personRepo.insert(invalidP) - - test("update"): - connect(ds()): - val p = personRepo.findById(1L).get - val updated = p.copy(firstName = None) - personRepo.update(updated) - assertEquals(personRepo.findById(1L).get, updated) - - test("update invalid"): - intercept[SqlException]: - connect(ds()): - val p = 
personRepo.findById(1L).get - val updated = p.copy(lastName = null) - personRepo.update(updated) - - test("insertAll"): - connect(ds()): - val newPeople = Vector( - PersonCreator( - firstName = Some("Chandler"), - lastName = "Johnsored", - isAdmin = "Y", - socialId = Some(UUID.randomUUID()) - ), - PersonCreator( - firstName = None, - lastName = "Odysseus", - isAdmin = "N", - socialId = None - ), - PersonCreator( - firstName = Some("Jorge"), - lastName = "Masvidal", - isAdmin = "Y", - socialId = None - ) - ) - personRepo.insertAll(newPeople) - assertEquals(personRepo.count, 11L) - assertEquals( - personRepo.findById(11L).get.lastName, - newPeople.last.lastName - ) - - test("updateAll"): - connect(ds()): - val newPeople = Vector( - personRepo.findById(1L).get.copy(lastName = "Peterson"), - personRepo.findById(2L).get.copy(lastName = "Moreno") - ) - assertEquals( - personRepo.updateAll(newPeople), - BatchUpdateResult.Success(2) - ) - assertEquals(personRepo.findById(1L).get, newPeople(0)) - assertEquals(personRepo.findById(2L).get, newPeople(1)) - - test("transact"): - val count = transact(ds()): - val p = PersonCreator( - firstName = Some("Chandler"), - lastName = "Brown", - isAdmin = "N", - socialId = Some(UUID.randomUUID()) - ) - personRepo.insert(p) - personRepo.count - assertEquals(count, 9L) - - test("transact failed"): - val dataSource = ds() - val p = PersonCreator( - firstName = Some("Chandler"), - lastName = "Brown", - isAdmin = "N", - socialId = Some(UUID.randomUUID()) - ) - try - transact(dataSource): - personRepo.insert(p) - throw RuntimeException() - fail("should not reach") - catch - case _: Exception => - transact(dataSource): - assertEquals(personRepo.count, 8L) - - test("custom insert"): - connect(ds()): - val p = PersonCreator( - firstName = Some("Chandler"), - lastName = "Brown", - isAdmin = "N", - socialId = Some(UUID.randomUUID()) - ) - val update = - sql"insert into $person ${person.insertColumns} values ($p)".update - assertNoDiff( - 
update.frag.sqlString, - "insert into person (first_name, last_name, is_admin, social_id) values (?, ?, ?, ?)" - ) - val rowsInserted = update.run() - assertEquals(rowsInserted, 1) - assertEquals(personRepo.count, 9L) - val fetched = personRepo.findAll.find(_.firstName == p.firstName).get - assertEquals(fetched.lastName, p.lastName) - assertEquals(fetched.isAdmin, p.isAdmin) - - test("custom update"): - connect(ds()): - val p = personRepo.insertReturning( - PersonCreator( - firstName = Some("Chandler"), - lastName = "Brown", - isAdmin = "N", - socialId = Some(UUID.randomUUID()) - ) - ) - val newIsAdmin = "Y" - val update = - sql"update $person set ${person.isAdmin} = $newIsAdmin where ${person.id} = ${p.id}".update - assertNoDiff( - update.frag.sqlString, - "update person set is_admin = ? where id = ?" - ) - val rowsUpdated = update.run() - assertEquals(rowsUpdated, 1) - assertEquals(personRepo.findById(p.id).get.isAdmin, "Y") - - test("embed Frag into Frag"): - def findPersonCnt(filter: Frag, limit: Long = 1)(using DbCon): Int = - val offsetFrag = sql"OFFSET 0 ROWS" - val limitFrag = sql"FETCH NEXT $limit ROWS ONLY" - sql"SELECT count(*) FROM person WHERE $filter $offsetFrag $limitFrag" - .query[Int] - .run() - .head - val isAdminFrag = sql"is_admin = 'Y'" - connect(ds()): - val johnCnt = findPersonCnt(sql"$isAdminFrag AND first_name = 'John'", 2) - assertEquals(johnCnt, 2) + sharedTests(this, OracleDbType, xa) val oracleContainer = ForAllContainerFixture( OracleContainer @@ -414,7 +30,7 @@ class OracleTests extends FunSuite, TestContainersFixtures: override def munitFixtures: Seq[AnyFixture[_]] = super.munitFixtures :+ oracleContainer - def ds(): DataSource = + def xa(): Transactor = val oracle = oracleContainer() val ds = OracleDataSource() ds.setURL(oracle.jdbcUrl) @@ -431,70 +47,118 @@ class OracleTests extends FunSuite, TestContainersFixtures: stmt.execute( """create table car ( | model varchar2(50) not null, - | id number generated always as identity, - | 
top_speed number, + | id number primary key, + | top_speed number not null, | vin number, - | color varchar2(50) not null check (color in ('Red', 'Green', 'Blue')) + | color varchar2(50) not null check (color in ('Red', 'Green', 'Blue')), + | created timestamp not null |)""".stripMargin ) stmt.execute( - """insert into car (model, top_speed, vin, color) - |values ('McLaren Senna', 208, 123, 'Red')""".stripMargin + """insert into car (model, id, top_speed, vin, color, created) + |values ('McLaren Senna', 1, 208, 123, 'Red', timestamp '2024-11-24 22:17:30')""".stripMargin ) stmt.execute( - """insert into car (model, top_speed, vin, color) - |values ('Ferrari F8 Tributo', 212, 124, 'Green')""".stripMargin + """insert into car (model, id, top_speed, vin, color, created) + |values ('Ferrari F8 Tributo', 2, 212, 124, 'Green', timestamp '2024-11-24 22:17:31')""".stripMargin ) stmt.execute( - """insert into car (model, top_speed, vin, color) - |values ('Aston Martin Superleggera', 211, null, 'Blue')""".stripMargin + """insert into car (model, id, top_speed, vin, color, created) + |values ('Aston Martin Superleggera', 3, 211, null, 'Blue', timestamp '2024-11-24 22:17:32')""".stripMargin ) try stmt.execute("drop table person") catch case _ => () stmt.execute( """create table person ( - | id number generated always as identity, + | id number primary key, | first_name varchar2(50), | last_name varchar2(50) not null, | is_admin varchar2(1) not null, - | created timestamp default current_timestamp, + | created timestamp not null, | social_id varchar2(36) |)""".stripMargin ) stmt.execute( - """insert into person (first_name, last_name, is_admin, created, social_id) values - |('George', 'Washington', 'Y', current_timestamp, 'd06443a6-3efb-46c4-a66a-a80a8a9a5388')""".stripMargin + """insert into person (id, first_name, last_name, is_admin, created, social_id) values + |(1, 'George', 'Washington', 'Y', current_timestamp, 'd06443a6-3efb-46c4-a66a-a80a8a9a5388')""".stripMargin + ) + 
stmt.execute( + """insert into person (id, first_name, last_name, is_admin, created, social_id) values + |(2, 'Alexander', 'Hamilton', 'Y', current_timestamp, '529b6c6d-7228-4da5-81d7-13b706f78ddb')""".stripMargin + ) + stmt.execute( + """insert into person (id, first_name, last_name, is_admin, created, social_id) values + |(3, 'John', 'Adams', 'Y', current_timestamp, null)""".stripMargin + ) + stmt.execute( + """insert into person (id, first_name, last_name, is_admin, created, social_id) values + |(4, 'Benjamin', 'Franklin', 'Y', current_timestamp, null)""".stripMargin ) stmt.execute( - """insert into person (first_name, last_name, is_admin, created, social_id) values - |('Alexander', 'Hamilton', 'Y', current_timestamp, '529b6c6d-7228-4da5-81d7-13b706f78ddb')""".stripMargin + """insert into person (id, first_name, last_name, is_admin, created, social_id) values + |(5, 'John', 'Jay', 'Y', current_timestamp, null)""".stripMargin ) stmt.execute( - """insert into person (first_name, last_name, is_admin, created, social_id) values - |('John', 'Adams', 'Y', current_timestamp, null)""".stripMargin + """insert into person (id, first_name, last_name, is_admin, created, social_id) values + |(6, 'Thomas', 'Jefferson', 'Y', current_timestamp, null)""".stripMargin ) stmt.execute( - """insert into person (first_name, last_name, is_admin, created, social_id) values - |('Benjamin', 'Franklin', 'Y', current_timestamp, null)""".stripMargin + """insert into person (id, first_name, last_name, is_admin, created, social_id) values + |(7, 'James', 'Madison', 'Y', current_timestamp, null)""".stripMargin ) stmt.execute( - """insert into person (first_name, last_name, is_admin, created, social_id) values - |('John', 'Jay', 'Y', current_timestamp, null)""".stripMargin + """insert into person (id, first_name, last_name, is_admin, created, social_id) values + |(8, null, 'Nagro', 'N', current_timestamp, null)""".stripMargin ) + try stmt.execute("drop table my_user") + catch case _ => () + 
stmt.execute( + """create table my_user ( + | first_name varchar2(200) not null, + | id number generated always as identity, + | primary key (id) + |) + |""".stripMargin + ) + stmt.execute("""insert into my_user (first_name) values ('George')""") + stmt.execute( + """insert into my_user (first_name) values ('Alexander')""" + ) + stmt.execute("""insert into my_user (first_name) values ('John')""") + try stmt.execute("drop table no_id") + catch case _ => () stmt.execute( - """insert into person (first_name, last_name, is_admin, created, social_id) values - |('Thomas', 'Jefferson', 'Y', current_timestamp, null)""".stripMargin + """create table no_id ( + | created_at timestamp not null, + | user_name varchar2(200) not null, + | user_action varchar2(200) not null + |) + |""".stripMargin ) stmt.execute( - """insert into person (first_name, last_name, is_admin, created, social_id) values - |('James', 'Madison', 'Y', current_timestamp, null)""".stripMargin + """insert into no_id (created_at, user_name, user_action) values + |(timestamp '1997-08-15 00:00:00', 'Josh', 'clicked a button')""".stripMargin ) stmt.execute( - """insert into person (first_name, last_name, is_admin, created, social_id) values - |(null, 'Nagro', 'N', current_timestamp, null)""".stripMargin + """insert into no_id (created_at, user_name, user_action) values + |(timestamp '1997-08-16 00:00:00', 'Danny', 'opened a toaster')""".stripMargin + ) + stmt.execute( + """insert into no_id (created_at, user_name, user_action) values + |(timestamp '1997-08-17 00:00:00', 'Greg', 'ran some QA tests')""".stripMargin + ) + try stmt.execute("drop table big_dec") + catch case _ => () + stmt.execute( + """create table big_dec ( + | id number primary key, + | my_big_dec numeric + |)""".stripMargin ) + stmt.execute("insert into big_dec (id, my_big_dec) values (1, 123)") + stmt.execute("insert into big_dec (id, my_big_dec) values (2, null)") ) .get - ds - end ds + Transactor(ds) + end xa end OracleTests diff --git 
a/magnum/src/test/scala/PgTests.scala b/magnum/src/test/scala/PgTests.scala index b17816b..1dffb8d 100644 --- a/magnum/src/test/scala/PgTests.scala +++ b/magnum/src/test/scala/PgTests.scala @@ -1,491 +1,18 @@ import com.augustnagro.magnum.* +import com.dimafeng.testcontainers.PostgreSQLContainer import com.dimafeng.testcontainers.munit.fixtures.TestContainersFixtures -import com.dimafeng.testcontainers.{ - ContainerDef, - JdbcDatabaseContainer, - PostgreSQLContainer -} -import munit.{AnyFixture, FunSuite, Location, TestOptions} +import munit.{AnyFixture, FunSuite, Location} import org.postgresql.ds.PGSimpleDataSource import org.testcontainers.utility.DockerImageName +import shared.* import java.nio.file.{Files, Path} -import java.sql.Connection -import java.time.OffsetDateTime -import java.util.UUID -import javax.sql.DataSource import scala.util.Using import scala.util.Using.Manager class PgTests extends FunSuite, TestContainersFixtures: - /* - Immutable Repo Tests - */ - - enum Color derives DbCodec: - case Red, Green, Blue - - @Table(PostgresDbType, SqlNameMapper.CamelToSnakeCase) - case class Car( - model: String, - @Id id: Long, - topSpeed: Int, - @SqlName("vin") vinNumber: Option[Int], - color: Color - ) derives DbCodec - - val carRepo = ImmutableRepo[Car, Long] - val car = TableInfo[Car, Car, Long] - - val allCars = Vector( - Car("McLaren Senna", 1L, 208, Some(123), Color.Red), - Car("Ferrari F8 Tributo", 2L, 212, Some(124), Color.Green), - Car("Aston Martin Superleggera", 3L, 211, None, Color.Blue) - ) - - test("count"): - connect(ds()): - assertEquals(carRepo.count, 3L) - - test("existsById"): - connect(ds()): - assert(carRepo.existsById(3L)) - assert(!carRepo.existsById(4L)) - - test("findAll"): - val cars = connect(ds()): - carRepo.findAll - assertEquals(cars, allCars) - - test("findAll spec"): - connect(ds()): - val topSpeed = 211 - val spec = Spec[Car] - .where(sql"${car.topSpeed} > $topSpeed") - assertEquals(carRepo.findAll(spec), Vector(allCars(1))) - 
- test("findAll spec with multiple conditions"): - connect(ds()): - val topSpeed = 211 - val model = "Ferrari F8 Tributo" - val spec = Spec[Car] - .where(sql"${car.topSpeed} > $topSpeed") - .where(sql"${car.model} = $model") - assertEquals(carRepo.findAll(spec), Vector(allCars(1))) - - test("findById"): - connect(ds()): - assertEquals(carRepo.findById(3L).get, allCars.last) - assertEquals(carRepo.findById(4L), None) - - test("findAllByIds"): - connect(ds()): - assertEquals( - carRepo.findAllById(Vector(1L, 3L)).map(_.id), - Vector(1L, 3L) - ) - - test("repeatable read transaction"): - transact(ds(), withRepeatableRead): - assertEquals(carRepo.count, 3L) - - private def withRepeatableRead(con: Connection): Unit = - con.setTransactionIsolation(Connection.TRANSACTION_REPEATABLE_READ) - - test("select query"): - connect(ds()): - val minSpeed: Int = 210 - val query = - sql"select ${car.all} from $car where ${car.topSpeed} > $minSpeed" - .query[Car] - assertNoDiff( - query.frag.sqlString, - "select model, id, top_speed, vin, color from car where top_speed > ?" - ) - assertEquals(query.frag.params, Vector(minSpeed)) - assertEquals(query.run(), allCars.tail) - - test("select query with aliasing"): - connect(ds()): - val minSpeed = 210 - val cAlias = car.alias("c") - val query = - sql"select ${cAlias.all} from $cAlias where ${cAlias.topSpeed} > $minSpeed" - .query[Car] - assertNoDiff( - query.frag.sqlString, - "select c.model, c.id, c.top_speed, c.vin, c.color from car c where c.top_speed > ?" 
- ) - assertEquals(query.frag.params, Vector(minSpeed)) - assertEquals(query.run(), allCars.tail) - - test("select via option"): - connect(ds()): - val vin = Some(124) - val cars = - sql"select * from car where vin = $vin" - .query[Car] - .run() - assertEquals(cars, allCars.filter(_.vinNumber == vin)) - - test("tuple select"): - connect(ds()): - val tuples = sql"select model, color from car where id = 2" - .query[(String, Color)] - .run() - assertEquals(tuples, Vector(allCars(1).model -> allCars(1).color)) - - test("reads null int as None and not Some(0)"): - connect(ds()): - assertEquals(carRepo.findById(3L).get.vinNumber, None) - - /* - Repo Tests - */ - case class PersonCreator( - firstName: Option[String], - lastName: String, - isAdmin: Boolean, - socialId: Option[UUID] - ) derives DbCodec - - @Table(PostgresDbType, SqlNameMapper.CamelToSnakeCase) - case class Person( - id: Long, - firstName: Option[String], - lastName: String, - isAdmin: Boolean, - created: OffsetDateTime, - socialId: Option[UUID] - ) derives DbCodec - - val personRepo = Repo[PersonCreator, Person, Long] - val person = TableInfo[PersonCreator, Person, Long] - - test("delete"): - connect(ds()): - val p = personRepo.findById(1L).get - personRepo.delete(p) - assertEquals(personRepo.findById(1L), None) - - test("delete invalid"): - connect(ds()): - personRepo.delete(Person(23L, None, "", false, OffsetDateTime.now, None)) - assertEquals(8L, personRepo.count) - - test("deleteById"): - connect(ds()): - personRepo.deleteById(1L) - personRepo.deleteById(2L) - personRepo.deleteById(1L) - assertEquals(personRepo.findAll.size, 6) - - test("deleteAll"): - connect(ds()): - val p1 = personRepo.findById(1L).get - val p2 = p1.copy(id = 2L) - val p3 = p1.copy(id = 99L) - assertEquals( - personRepo.deleteAll(Vector(p1, p2, p3)), - BatchUpdateResult.Success(2) - ) - assertEquals(6L, personRepo.count) - - test("deleteAllById"): - connect(ds()): - assertEquals( - personRepo.deleteAllById(Vector(1L, 2L, 1L)), - 
BatchUpdateResult.Success(2) - ) - assertEquals(6L, personRepo.count) - - test("truncate"): - connect(ds()): - personRepo.truncate() - assertEquals(personRepo.count, 0L) - - test("insert"): - connect(ds()): - personRepo.insert( - PersonCreator( - firstName = Some("John"), - lastName = "Smith", - isAdmin = false, - socialId = Some(UUID.randomUUID()) - ) - ) - personRepo.insert( - PersonCreator( - firstName = None, - lastName = "Prince", - isAdmin = true, - socialId = None - ) - ) - assertEquals(personRepo.count, 10L) - assertEquals(personRepo.findById(9L).get.lastName, "Smith") - - test("insertReturning"): - connect(ds()): - val person = personRepo.insertReturning( - PersonCreator( - firstName = Some("John"), - lastName = "Smith", - isAdmin = false, - socialId = None - ) - ) - assertEquals(person.id, 9L) - assertEquals(person.lastName, "Smith") - - test("insertAllReturning"): - connect(ds()): - val newPc = Vector( - PersonCreator( - firstName = Some("Chandler"), - lastName = "Johnsored", - isAdmin = true, - socialId = Some(UUID.randomUUID()) - ), - PersonCreator( - firstName = None, - lastName = "Odysseus", - isAdmin = false, - socialId = None - ), - PersonCreator( - firstName = Some("Jorge"), - lastName = "Masvidal", - isAdmin = true, - socialId = None - ) - ) - val people = personRepo.insertAllReturning(newPc) - assertEquals(personRepo.count, 11L) - assertEquals(people.size, 3) - assertEquals(people.last.lastName, newPc.last.lastName) - - test("insert invalid"): - intercept[SqlException]: - connect(ds()): - val invalidP = PersonCreator(None, null, false, None) - personRepo.insert(invalidP) - - test("update"): - connect(ds()): - val p = personRepo.findById(1L).get - val updated = p.copy(firstName = None, isAdmin = false) - personRepo.update(updated) - assertEquals(personRepo.findById(1L).get, updated) - - test("update invalid"): - intercept[SqlException]: - connect(ds()): - val p = personRepo.findById(1L).get - val updated = p.copy(lastName = null) - 
personRepo.update(updated) - - test("insertAll"): - connect(ds()): - val newPeople = Vector( - PersonCreator( - firstName = Some("Chandler"), - lastName = "Johnsored", - isAdmin = true, - socialId = Some(UUID.randomUUID()) - ), - PersonCreator( - firstName = None, - lastName = "Odysseus", - isAdmin = false, - socialId = None - ), - PersonCreator( - firstName = Some("Jorge"), - lastName = "Masvidal", - isAdmin = true, - socialId = None - ) - ) - personRepo.insertAll(newPeople) - assertEquals(personRepo.count, 11L) - assertEquals( - personRepo.findById(11L).get.lastName, - newPeople.last.lastName - ) - - test("updateAll"): - connect(ds()): - val newPeople = Vector( - personRepo.findById(1L).get.copy(lastName = "Peterson"), - personRepo.findById(2L).get.copy(lastName = "Moreno") - ) - assertEquals( - personRepo.updateAll(newPeople), - BatchUpdateResult.Success(2) - ) - assertEquals(personRepo.findById(1L).get, newPeople(0)) - assertEquals(personRepo.findById(2L).get, newPeople(1)) - - test("transact"): - val count = transact(ds()): - val p = PersonCreator( - firstName = Some("Chandler"), - lastName = "Brown", - isAdmin = false, - socialId = None - ) - personRepo.insert(p) - personRepo.count - assertEquals(count, 9L) - - test("transact failed"): - val dataSource = ds() - val p = PersonCreator( - firstName = Some("Chandler"), - lastName = "Brown", - isAdmin = false, - socialId = None - ) - try - transact(dataSource): - personRepo.insert(p) - throw RuntimeException() - fail("should not reach") - catch - case _: Exception => - transact(dataSource): - assertEquals(personRepo.count, 8L) - - test("custom insert"): - connect(ds()): - val p = PersonCreator( - firstName = Some("Chandler"), - lastName = "Brown", - isAdmin = false, - socialId = None - ) - val update = - sql"insert into $person ${person.insertColumns} values ($p)".update - assertNoDiff( - update.frag.sqlString, - "insert into person (first_name, last_name, is_admin, social_id) values (?, ?, ?, ?)" - ) - val 
rowsInserted = update.run() - assertEquals(rowsInserted, 1) - assertEquals(personRepo.count, 9L) - val fetched = personRepo.findAll.last - assertEquals(fetched.firstName, p.firstName) - assertEquals(fetched.lastName, p.lastName) - assertEquals(fetched.isAdmin, p.isAdmin) - - test("custom update"): - connect(ds()): - val p = personRepo.insertReturning( - PersonCreator( - firstName = Some("Chandler"), - lastName = "Brown", - isAdmin = false, - socialId = Some(UUID.randomUUID()) - ) - ) - val newIsAdmin = true - val update = - sql"update $person set ${person.isAdmin} = $newIsAdmin where ${person.id} = ${p.id}".update - assertNoDiff( - update.frag.sqlString, - "update person set is_admin = ? where id = ?" - ) - val rowsUpdated = update.run() - assertEquals(rowsUpdated, 1) - assertEquals(personRepo.findById(p.id).get.isAdmin, true) - - test("custom returning a single column"): - connect(ds()): - val returningQuery = - sql"insert into person (first_name, last_name, is_admin) values ('Arton', 'Senna', true) RETURNING id" - .returning[Long] - val personId = returningQuery.run().head - assertEquals(personId, 9L) - - test("custom returning multiple columns"): - connect(ds()): - val returningQuery = - sql"""insert into person (first_name, last_name, is_admin) values - ('Arton', 'Senna', true), - ('Demo', 'User', false) - RETURNING id, created""" - .returning[(Long, OffsetDateTime)] - val cols = returningQuery.run() - assertEquals(cols.map(_._1), Vector(9L, 10L)) - - test("custom returning with no rows updated"): - connect(ds()): - val statement = - sql"update person set first_name = 'xxx' where id = 12345 returning id" - .returning[Long] - val personIds = statement.run() - assert(personIds.isEmpty) - - @SqlName("person") - @Table(PostgresDbType, SqlNameMapper.CamelToSnakeCase) - case class CustomPerson( - id: Long, - firstName: Option[String], - lastName: String, - isAdmin: Boolean, - created: OffsetDateTime, - socialId: Option[UUID] - ) derives DbCodec - - val 
customPersonRepo = Repo[PersonCreator, CustomPerson, Long] - - test("count with manual table name"): - val count = connect(ds()): - customPersonRepo.count - assertEquals(count, 8L) - - test(".query iterator"): - connect(ds()): - Using.Manager(implicit use => - val it = sql"SELECT * FROM person".query[Person].iterator() - assertEquals(it.map(_.id).size, 8) - ) - - test(".returning iterator"): - connect(ds()): - Using.Manager(implicit use => - val it = sql"UPDATE person set is_admin = false RETURNING first_name" - .returning[String] - .iterator() - assertEquals(it.size, 8) - ) - - test("embed Frag into Frag"): - def findPersonCnt(filter: Frag, limit: Long = 1)(using DbCon): Int = - val offsetFrag = sql"OFFSET 0" - val limitFrag = sql"LIMIT $limit" - sql"SELECT count(*) FROM person WHERE $filter $limitFrag $offsetFrag" - .query[Int] - .run() - .head - val isAdminFrag = sql"is_admin = true" - connect(ds()): - val johnCnt = findPersonCnt(sql"$isAdminFrag AND first_name = 'John'", 2) - assertEquals(johnCnt, 2) - - @Table(PostgresDbType, SqlNameMapper.CamelToSnakeCase) - case class BigDec(id: Int, myBigDec: Option[BigDecimal]) derives DbCodec - - val bigDecRepo = Repo[BigDec, BigDec, Int] - - test("option of bigdecimal"): - connect(ds()): - val bigDec1 = bigDecRepo.findById(1).get - assertEquals(bigDec1.myBigDec, Some(BigDecimal(123))) - val bigDec2 = bigDecRepo.findById(2).get - assertEquals(bigDec2.myBigDec, None) + sharedTests(this, PostgresDbType, xa) val pgContainer = ForAllContainerFixture( PostgreSQLContainer @@ -493,47 +20,28 @@ class PgTests extends FunSuite, TestContainersFixtures: .createContainer() ) - @Table(PostgresDbType, SqlNameMapper.CamelToSnakeCase) - case class NoId( - createdAt: OffsetDateTime, - userName: String, - userAction: String - ) derives DbCodec - - val noIdRepo = Repo[NoId, NoId, Null]() - - test("insert NoId entities"): - connect(ds()): - val entity = NoId(OffsetDateTime.now, "Dan", "Fishing") - noIdRepo.insert(entity) - 
assert(noIdRepo.findAll.exists(_.userName == "Dan")) - override def munitFixtures: Seq[AnyFixture[_]] = super.munitFixtures :+ pgContainer - def ds(): DataSource = + def xa(): Transactor = val ds = PGSimpleDataSource() val pg = pgContainer() ds.setUrl(pg.jdbcUrl) ds.setUser(pg.username) ds.setPassword(pg.password) - val carSql = - Files.readString(Path.of(getClass.getResource("/pg-car.sql").toURI)) - val personSql = - Files.readString(Path.of(getClass.getResource("/pg-person.sql").toURI)) - val bigDecSql = - Files.readString(Path.of(getClass.getResource("/pg-bigdec.sql").toURI)) - val noIdSql = - Files.readString(Path.of(getClass.getResource("/pg-no-id.sql").toURI)) + val tableDDLs = Vector( + "/pg/car.sql", + "/pg/person.sql", + "/pg/my-user.sql", + "/pg/no-id.sql", + "/pg/big-dec.sql" + ).map(p => Files.readString(Path.of(getClass.getResource(p).toURI))) Manager(use => val con = use(ds.getConnection) val stmt = use(con.createStatement) - stmt.execute(carSql) - stmt.execute(personSql) - stmt.execute(bigDecSql) - stmt.execute(noIdSql) + for ddl <- tableDDLs do stmt.execute(ddl) ).get - ds - end ds + Transactor(ds) + end xa end PgTests diff --git a/magnum/src/test/scala/SpecTests.scala b/magnum/src/test/scala/SpecTests.scala deleted file mode 100644 index e07545e..0000000 --- a/magnum/src/test/scala/SpecTests.scala +++ /dev/null @@ -1,116 +0,0 @@ -import com.augustnagro.magnum.* -import munit.FunSuite - -class SpecTests extends FunSuite: - - case class User(id: Long, name: String, age: opaques.Age) derives DbCodec - - test("select all"): - assertEquals(Spec[User].build.sqlString, "") - - test("empty predicate"): - assertEquals(Spec[User].where(sql"").build.sqlString, "") - - test("predicate having param at end"): - val age = 3 - val frag = Spec[User] - .where(sql"age > $age") - .build - assertEquals(frag.sqlString, "WHERE (age > ?)") - assertEquals(frag.params, Vector(age)) - - test("predicate having param at start"): - val age = 3 - val spec = Spec[User] - 
.where(sql"$age < age") - .build - assertEquals(spec.sqlString, "WHERE (? < age)") - assertEquals(spec.params, Vector(age)) - - test("AND in where predicate"): - val name = "AUGUST" - val age = 3 - val spec = Spec[User] - .where(sql"age > $age AND $name = upper(name)") - .build - assertEquals( - spec.sqlString, - "WHERE (age > ? AND ? = upper(name))" - ) - assertEquals(spec.params, Vector(age, name)) - - test("multiple where predicates"): - val name = "AUGUST" - val age = 3 - val spec = Spec[User] - .where(sql"age > $age") - .where(sql"$name = upper(name)") - .build - assertEquals( - spec.sqlString, - "WHERE (age > ?) AND (? = upper(name))" - ) - assertEquals(spec.params, Vector(age, name)) - - test("orderBy"): - val spec = Spec[User] - .orderBy("name", SortOrder.Asc, NullOrder.Last) - .build - assertEquals(spec.sqlString, "ORDER BY name ASC NULLS LAST") - assertEquals(spec.params, Vector.empty) - - test("limit"): - val spec = Spec[User] - .limit(99) - .build - assertEquals(spec.sqlString, "LIMIT 99") - assertEquals(spec.params, Vector.empty) - - test("offset"): - val spec = Spec[User] - .offset(100) - .build - assertEquals(spec.sqlString, "OFFSET 100") - assertEquals(spec.params, Vector.empty) - - test("seek"): - val age = 3 - val spec = Spec[User] - .seek("age", SeekDir.Gt, age, SortOrder.Asc) - .build - assertEquals( - spec.sqlString, - "WHERE (age > ?) ORDER BY age ASC NULLS LAST" - ) - assertEquals(spec.params, Vector(age)) - - test("seek multiple"): - val age = 3 - val name = "John" - val spec = Spec[User] - .seek("age", SeekDir.Gt, age, SortOrder.Asc) - .seek("name", SeekDir.Lt, name, SortOrder.Desc, NullOrder.First) - .build - assertEquals( - spec.sqlString, - "WHERE (age > ?) AND (name < ?) 
ORDER BY age ASC NULLS LAST, name DESC NULLS FIRST" - ) - assertEquals(spec.params, Vector(age, name)) - - test("everything"): - val idOpt = Option.empty[Long] - val age = 3 - val name = "John" - val spec = Spec[User] - .where(idOpt.map(id => sql"id = $id").getOrElse(sql"")) - .where(sql"age > $age") - .orderBy("age", SortOrder.Asc, NullOrder.Last) - .limit(10) - .seek("name", SeekDir.Lt, name, SortOrder.Desc) - .build - assertEquals( - spec.sqlString, - "WHERE (age > ?) AND (name < ?) ORDER BY age ASC NULLS LAST, name DESC NULLS LAST LIMIT 10" - ) - assertEquals(spec.params, Vector(age, name)) -end SpecTests diff --git a/magnum/src/test/scala/SqliteTests.scala b/magnum/src/test/scala/SqliteTests.scala index 9f57672..667a99c 100644 --- a/magnum/src/test/scala/SqliteTests.scala +++ b/magnum/src/test/scala/SqliteTests.scala @@ -2,430 +2,33 @@ import com.augustnagro.magnum.* import com.augustnagro.magnum.UUIDCodec.VarCharUUIDCodec import munit.FunSuite import org.sqlite.SQLiteDataSource +import shared.* -import java.nio.file.{Files, Path} -import java.sql.Connection -import java.time.{LocalDateTime, OffsetDateTime} +import java.nio.file.Files +import java.time.OffsetDateTime import java.util.UUID -import javax.sql.DataSource import scala.util.Using import scala.util.Using.Manager class SqliteTests extends FunSuite: - /* - Immutable Repo Tests - */ + given DbCodec[OffsetDateTime] = + DbCodec[String].biMap(OffsetDateTime.parse, _.toString) - enum Color derives DbCodec: - case Red, Green, Blue + given DbCodec[UUID] = + DbCodec[String].biMap(UUID.fromString, _.toString) - @Table(SqliteDbType, SqlNameMapper.CamelToSnakeCase) - case class Car( - model: String, - @Id id: Long, - topSpeed: Int, - @SqlName("vin") vinNumber: Option[Int], - color: Color - ) derives DbCodec + given DbCodec[Boolean] = + DbCodec[Int].biMap(_ != 0, b => if b then 1 else 0) - val carRepo = ImmutableRepo[Car, Long] - val car = TableInfo[Car, Car, Long] + given DbCodec[BigDecimal] = + 
DbCodec[String].biMap(BigDecimal.apply, _.toString()) - val allCars = Vector( - Car("McLaren Senna", 1L, 208, Some(123), Color.Red), - Car("Ferrari F8 Tributo", 2L, 212, Some(124), Color.Green), - Car("Aston Martin Superleggera", 3L, 211, None, Color.Blue) - ) - - test("count"): - val count = connect(ds()): - carRepo.count - assertEquals(count, 3L) - - test("existsById"): - connect(ds()): - assert(carRepo.existsById(3L)) - assert(!carRepo.existsById(4L)) - - test("findAll"): - val cars = connect(ds()): - carRepo.findAll - assertEquals(cars, allCars) - - test("findAll spec"): - connect(ds()): - val topSpeed = 211 - val spec = Spec[Car] - .where(sql"${car.topSpeed} > $topSpeed") - assertEquals(carRepo.findAll(spec), Vector(allCars(1))) - - test("findById"): - connect(ds()): - assertEquals(carRepo.findById(3L).get, allCars.last) - assertEquals(carRepo.findById(4L), None) - - test("findAllByIds"): - intercept[UnsupportedOperationException]: - connect(ds()): - assertEquals( - carRepo.findAllById(Vector(1L, 3L)).map(_.id), - Vector(1L, 3L) - ) - - test("repeatable read transaction"): - transact(ds(), withRepeatableRead): - assertEquals(carRepo.count, 3L) - - private def withRepeatableRead(con: Connection): Unit = - con.setTransactionIsolation(Connection.TRANSACTION_REPEATABLE_READ) - - test("select query"): - connect(ds()): - val minSpeed: Int = 210 - val query = - sql"select ${car.all} from $car where ${car.topSpeed} > $minSpeed" - .query[Car] - assertNoDiff( - query.frag.sqlString, - "select model, id, top_speed, vin, color from car where top_speed > ?" 
- ) - assertEquals(query.frag.params, Vector(minSpeed)) - assertEquals(query.run(), allCars.tail) - - test("select query with aliasing"): - connect(ds()): - val minSpeed = 210 - val cAlias = car.alias("c") - val query = - sql"select ${cAlias.all} from $cAlias where ${cAlias.topSpeed} > $minSpeed" - .query[Car] - assertNoDiff( - query.frag.sqlString, - "select c.model, c.id, c.top_speed, c.vin, c.color from car c where c.top_speed > ?" - ) - assertEquals(query.frag.params, Vector(minSpeed)) - assertEquals(query.run(), allCars.tail) - - test("select via option"): - connect(ds()): - val vin = Some(124) - val cars = - sql"select * from car where vin = $vin" - .query[Car] - .run() - assertEquals(cars, allCars.filter(_.vinNumber == vin)) - - test("tuple select"): - connect(ds()): - val tuples = sql"select model, color from car where id = 2" - .query[(String, Color)] - .run() - assertEquals(tuples, Vector(allCars(1).model -> allCars(1).color)) - - test("reads null int as None and not Some(0)"): - connect(ds()): - assertEquals(carRepo.findById(3L).get.vinNumber, None) - - /* - Repo Tests - */ - case class PersonCreator( - firstName: Option[String], - lastName: String, - isAdmin: Boolean, - socialId: Option[UUID] - ) derives DbCodec - - @Table(SqliteDbType, SqlNameMapper.CamelToSnakeCase) - case class Person( - id: Long, - firstName: Option[String], - lastName: String, - isAdmin: Boolean, - created: String, - socialId: Option[UUID] - ) derives DbCodec - - val personRepo = Repo[PersonCreator, Person, Long] - val person = TableInfo[PersonCreator, Person, Long] - - test("delete"): - connect(ds()): - val p = personRepo.findById(1L).get - personRepo.delete(p) - assertEquals(personRepo.findById(1L), None) - - test("delete invalid"): - connect(ds()): - personRepo.delete( - Person(23L, None, "", false, LocalDateTime.now.toString, None) - ) - assertEquals(8L, personRepo.count) - - test("deleteById"): - connect(ds()): - personRepo.deleteById(1L) - personRepo.deleteById(2L) - 
personRepo.deleteById(1L) - assertEquals(personRepo.findAll.size, 6) - - test("deleteAll"): - connect(ds()): - val p1 = personRepo.findById(1L).get - val p2 = p1.copy(id = 2L) - val p3 = p1.copy(id = 99L) - personRepo.deleteAll(Vector(p1, p2, p3)) - assertEquals(6L, personRepo.count) - - test("deleteAllById"): - connect(ds()): - personRepo.deleteAllById(Vector(1L, 2L, 1L)) - assertEquals(6L, personRepo.count) - - test("truncate"): - connect(ds()): - personRepo.truncate() - assertEquals(personRepo.count, 0L) - - test("insert"): - connect(ds()): - personRepo.insert( - PersonCreator( - firstName = Some("John"), - lastName = "Smith", - isAdmin = false, - socialId = Some(UUID.randomUUID()) - ) - ) - personRepo.insert( - PersonCreator( - firstName = None, - lastName = "Prince", - isAdmin = true, - socialId = None - ) - ) - assertEquals(personRepo.count, 10L) - assertEquals(personRepo.findById(9L).get.lastName, "Smith") - - test("insertReturning"): - connect(ds()): - val person = personRepo.insertReturning( - PersonCreator( - firstName = Some("John"), - lastName = "Smith", - isAdmin = false, - socialId = Some(UUID.randomUUID()) - ) - ) - assertEquals(person.id, 9L) - assertEquals(person.lastName, "Smith") - - test("insertAllReturning"): - intercept[UnsupportedOperationException]: - connect(ds()): - val newPc = Vector( - PersonCreator( - firstName = Some("Chandler"), - lastName = "Johnsored", - isAdmin = true, - socialId = Some(UUID.randomUUID()) - ), - PersonCreator( - firstName = None, - lastName = "Odysseus", - isAdmin = false, - socialId = None - ), - PersonCreator( - firstName = Some("Jorge"), - lastName = "Masvidal", - isAdmin = true, - socialId = None - ) - ) - val people = personRepo.insertAllReturning(newPc) - assertEquals(personRepo.count, 11L) - println(people) - assertEquals(people.size, 3) - assertEquals(people.last.lastName, newPc.last.lastName) - - test("custom returning"): - connect(ds()): - val returningQuery = - sql"insert into person (first_name, 
last_name, is_admin) values ('Arton', 'Senna', true) RETURNING id" - .returning[Long] - val personId = returningQuery.run().head - assertEquals(personId, 9L) - - test("custom returning multiple columns"): - connect(ds()): - val returningQuery = - sql"""insert into person (first_name, last_name, is_admin) values - ('Arton', 'Senna', true), - ('Demo', 'User', false) - RETURNING id""" - .returning[Long] - val cols = returningQuery.run() - assertEquals(cols, Vector(9L, 10L)) - - test("custom returning with no rows updated"): - connect(ds()): - val statement = - sql"update person set first_name = 'xxx' where id = 12345 returning id" - .returning[Long] - val personIds = statement.run() - assert(personIds.isEmpty) - - test("insert invalid"): - intercept[SqlException]: - connect(ds()): - val invalidP = PersonCreator(None, null, false, None) - personRepo.insert(invalidP) - - test("update"): - connect(ds()): - val p = personRepo.findById(1L).get - val updated = p.copy(firstName = None) - personRepo.update(updated) - assertEquals(personRepo.findById(1L).get, updated) - - test("update invalid"): - intercept[SqlException]: - connect(ds()): - val p = personRepo.findById(1L).get - val updated = p.copy(lastName = null) - personRepo.update(updated) - - test("insertAll"): - connect(ds()): - val newPeople = Vector( - PersonCreator( - firstName = Some("Chandler"), - lastName = "Johnsored", - isAdmin = true, - socialId = Some(UUID.randomUUID()) - ), - PersonCreator( - firstName = None, - lastName = "Odysseus", - isAdmin = false, - socialId = None - ), - PersonCreator( - firstName = Some("Jorge"), - lastName = "Masvidal", - isAdmin = true, - socialId = None - ) - ) - personRepo.insertAll(newPeople) - assertEquals(personRepo.count, 11L) - assertEquals( - personRepo.findById(11L).get.lastName, - newPeople.last.lastName - ) - - test("updateAll"): - connect(ds()): - val newPeople = Vector( - personRepo.findById(1L).get.copy(lastName = "Peterson"), - personRepo.findById(2L).get.copy(lastName 
= "Moreno") - ) - personRepo.updateAll(newPeople) - assertEquals(personRepo.findById(1L).get, newPeople(0)) - assertEquals(personRepo.findById(2L).get, newPeople(1)) - - test("transact"): - val count = transact(ds()): - val p = PersonCreator( - firstName = Some("Chandler"), - lastName = "Brown", - isAdmin = false, - socialId = Some(UUID.randomUUID()) - ) - personRepo.insert(p) - personRepo.count - assertEquals(count, 9L) - - test("transact failed"): - val dataSource = ds() - val p = PersonCreator( - firstName = Some("Chandler"), - lastName = "Brown", - isAdmin = false, - socialId = Some(UUID.randomUUID()) - ) - try - transact(dataSource): - personRepo.insert(p) - throw RuntimeException() - fail("should not reach") - catch - case _: Exception => - transact(dataSource): - assertEquals(personRepo.count, 8L) - - test("custom insert"): - connect(ds()): - val p = PersonCreator( - firstName = Some("Chandler"), - lastName = "Brown", - isAdmin = false, - socialId = Some(UUID.randomUUID()) - ) - val update = - sql"insert into $person ${person.insertColumns} values ($p)".update - assertNoDiff( - update.frag.sqlString, - "insert into person (first_name, last_name, is_admin, social_id) values (?, ?, ?, ?)" - ) - val rowsInserted = update.run() - assertEquals(rowsInserted, 1) - assertEquals(personRepo.count, 9L) - val fetched = personRepo.findAll.last - assertEquals(fetched.firstName, p.firstName) - assertEquals(fetched.lastName, p.lastName) - assertEquals(fetched.isAdmin, p.isAdmin) - - test("custom update"): - connect(ds()): - val p = personRepo.insertReturning( - PersonCreator( - firstName = Some("Chandler"), - lastName = "Brown", - isAdmin = false, - socialId = Some(UUID.randomUUID()) - ) - ) - val newIsAdmin = true - val update = - sql"update $person set ${person.isAdmin} = $newIsAdmin where ${person.id} = ${p.id}".update - assertNoDiff( - update.frag.sqlString, - "update person set is_admin = ? where id = ?" 
- ) - val rowsUpdated = update.run() - assertEquals(rowsUpdated, 1) - assertEquals(personRepo.findById(p.id).get.isAdmin, true) - - test("embed Frag into Frag"): - def findPersonCnt(filter: Frag, limit: Long = 1)(using DbCon): Int = - val offsetFrag = sql"OFFSET 0" - val limitFrag = sql"LIMIT $limit" - sql"SELECT count(*) FROM person WHERE $filter $limitFrag $offsetFrag" - .query[Int] - .run() - .head - val isAdminFrag = sql"is_admin = true" - connect(ds()): - val johnCnt = findPersonCnt(sql"$isAdminFrag AND first_name = 'John'", 2) - assertEquals(johnCnt, 2) + sharedTests(this, SqliteDbType, xa) lazy val sqliteDbPath = Files.createTempFile(null, ".db").toAbsolutePath - def ds(): DataSource = + def xa(): Transactor = val ds = SQLiteDataSource() ds.setUrl("jdbc:sqlite:" + sqliteDbPath) Manager(use => @@ -436,16 +39,17 @@ class SqliteTests extends FunSuite: """create table car ( | model text not null, | id integer primary key, - | top_speed integer, + | top_speed integer not null, | vin integer, - | color text check (color in ('Red', 'Green', 'Blue')) not null + | color text check (color in ('Red', 'Green', 'Blue')) not null, + | created text not null |)""".stripMargin ) stmt.execute( - """insert into car (model, top_speed, vin, color) values - |('McLaren Senna', 208, 123, 'Red'), - |('Ferrari F8 Tributo', 212, 124, 'Green'), - |('Aston Martin Superleggera', 211, null, 'Blue')""".stripMargin + """insert into car (model, id, top_speed, vin, color, created) values + |('McLaren Senna', 1, 208, 123, 'Red', '2024-11-24T22:17:30.000000000Z'), + |('Ferrari F8 Tributo', 2, 212, 124, 'Green', '2024-11-24T22:17:31.000000000Z'), + |('Aston Martin Superleggera', 3, 211, null, 'Blue', '2024-11-24T22:17:32.000000000Z')""".stripMargin ) stmt.execute("drop table if exists person") stmt.execute( @@ -454,22 +58,61 @@ class SqliteTests extends FunSuite: | first_name text, | last_name text not null, | is_admin integer not null, - | created text default(datetime()), + | created text not 
null, | social_id varchar(36) |)""".stripMargin ) stmt.execute( - """insert into person (first_name, last_name, is_admin, created, social_id) values - |('George', 'Washington', true, datetime(), 'd06443a6-3efb-46c4-a66a-a80a8a9a5388'), - |('Alexander', 'Hamilton', true, datetime(), '529b6c6d-7228-4da5-81d7-13b706f78ddb'), - |('John', 'Adams', true, datetime(), null), - |('Benjamin', 'Franklin', true, datetime(), null), - |('John', 'Jay', true, datetime(), null), - |('Thomas', 'Jefferson', true, datetime(), null), - |('James', 'Madison', true, datetime(), null), - |(null, 'Nagro', false, datetime(), null)""".stripMargin + """insert into person (id, first_name, last_name, is_admin, created, social_id) values + |(1, 'George', 'Washington', true, '2024-11-24T22:17:30.000000000Z', 'd06443a6-3efb-46c4-a66a-a80a8a9a5388'), + |(2, 'Alexander', 'Hamilton', true, '2024-11-24T22:17:30.000000000Z', '529b6c6d-7228-4da5-81d7-13b706f78ddb'), + |(3, 'John', 'Adams', true, '2024-11-24T22:17:30.000000000Z', null), + |(4, 'Benjamin', 'Franklin', true, '2024-11-24T22:17:30.000000000Z', null), + |(5, 'John', 'Jay', true, '2024-11-24T22:17:30.000000000Z', null), + |(6, 'Thomas', 'Jefferson', true, '2024-11-24T22:17:30.000000000Z', null), + |(7, 'James', 'Madison', true, '2024-11-24T22:17:30.000000000Z', null), + |(8, null, 'Nagro', false, '2024-11-24T22:17:30.000000000Z', null)""".stripMargin + ) + stmt.execute("drop table if exists my_user") + stmt.execute( + """create table my_user ( + | first_name text not null, + | id integer primary key + |)""".stripMargin + ) + stmt.execute( + """insert into my_user (first_name) values + |('George'), + |('Alexander'), + |('John')""".stripMargin + ) + stmt.execute("drop table if exists no_id") + stmt.execute( + """create table no_id ( + | created_at text not null, + | user_name text not null, + | user_action text not null + |)""".stripMargin + ) + stmt.execute( + """insert into no_id values + |('2024-11-24T22:17:30.000000000Z', 'Josh', 'clicked a 
button'), + |('2024-11-24T22:17:30.000000000Z', 'Danny', 'opened a toaster'), + |('2024-11-24T22:17:30.000000000Z', 'Greg', 'ran some QA tests');""".stripMargin + ) + stmt.execute("drop table if exists big_dec") + stmt.execute( + """create table big_dec ( + | id integer primary key, + | my_big_dec text + |)""".stripMargin + ) + stmt.execute( + """insert into big_dec values + |(1, '123'), + |(2, null)""".stripMargin ) ).get - ds - end ds + Transactor(ds) + end xa end SqliteTests diff --git a/magnum/src/test/scala/shared/BigDecTests.scala b/magnum/src/test/scala/shared/BigDecTests.scala new file mode 100644 index 0000000..c4b4724 --- /dev/null +++ b/magnum/src/test/scala/shared/BigDecTests.scala @@ -0,0 +1,22 @@ +package shared + +import com.augustnagro.magnum.* +import munit.FunSuite + +def bigDecTests(suite: FunSuite, dbType: DbType, xa: () => Transactor)(using + munit.Location, + DbCodec[BigDecimal] +): Unit = + import suite.* + + @Table(dbType, SqlNameMapper.CamelToSnakeCase) + case class BigDec(id: Int, myBigDec: Option[BigDecimal]) derives DbCodec + + val bigDecRepo = Repo[BigDec, BigDec, Int] + + test("option of bigdecimal"): + connect(xa()): + val bigDec1 = bigDecRepo.findById(1).get + assert(bigDec1.myBigDec == Some(BigDecimal(123))) + val bigDec2 = bigDecRepo.findById(2).get + assert(bigDec2.myBigDec == None) diff --git a/magnum/src/test/scala/shared/Color.scala b/magnum/src/test/scala/shared/Color.scala new file mode 100644 index 0000000..6baf18a --- /dev/null +++ b/magnum/src/test/scala/shared/Color.scala @@ -0,0 +1,6 @@ +package shared + +import com.augustnagro.magnum.DbCodec + +enum Color derives DbCodec: + case Red, Green, Blue diff --git a/magnum/src/test/scala/shared/EmbeddedFragTests.scala b/magnum/src/test/scala/shared/EmbeddedFragTests.scala new file mode 100644 index 0000000..6ed015e --- /dev/null +++ b/magnum/src/test/scala/shared/EmbeddedFragTests.scala @@ -0,0 +1,24 @@ +package shared + +import com.augustnagro.magnum.* +import munit.FunSuite + 
+def embeddedFragTests(suite: FunSuite, dbType: DbType, xa: () => Transactor)( + using munit.Location +): Unit = + import suite.* + + test("embed Frag into Frag"): + def findPersonCnt(filter: Frag)(using DbCon): Int = + val x = sql"first_name IS NOT NULL" + sql"SELECT count(*) FROM person WHERE $filter AND $x" + .query[Int] + .run() + .head + val isAdminFrag = + if dbType == OracleDbType then sql"is_admin = 'Y'" + else sql"is_admin = true" + connect(xa()): + val johnCnt = + findPersonCnt(sql"$isAdminFrag AND first_name = 'John'") + assert(johnCnt == 2) diff --git a/magnum/src/test/scala/shared/EntityCreatorTests.scala b/magnum/src/test/scala/shared/EntityCreatorTests.scala new file mode 100644 index 0000000..095b8f2 --- /dev/null +++ b/magnum/src/test/scala/shared/EntityCreatorTests.scala @@ -0,0 +1,114 @@ +package shared + +import com.augustnagro.magnum.* +import munit.{FunSuite, Location} + +import scala.util.Using + +def entityCreatorTests(suite: FunSuite, dbType: DbType, xa: () => Transactor)( + using Location +): Unit = + import suite.* + if dbType == ClickhouseDbType then return + + case class MyUserCreator(firstName: String) derives DbCodec + + @Table(dbType, SqlNameMapper.CamelToSnakeCase) + case class MyUser(firstName: String, id: Long) derives DbCodec + + val userRepo = Repo[MyUserCreator, MyUser, Long] + val user = TableInfo[MyUserCreator, MyUser, Long] + + test("insert EntityCreator"): + connect(xa()): + userRepo.insert(MyUserCreator("Ash")) + userRepo.insert(MyUserCreator("Steve")) + assert(userRepo.count == 5L) + assert(userRepo.findAll.map(_.firstName).contains("Steve")) + + test("insertReturning EntityCreator"): + assume(dbType != MySqlDbType) + assume(dbType != SqliteDbType) + connect(xa()): + val user = userRepo.insertReturning(MyUserCreator("Ash")) + assert(user.firstName == "Ash") + + test("insertAllReturning EntityCreator"): + assume(dbType != MySqlDbType) + assume(dbType != SqliteDbType) + connect(xa()): + val newUsers = Vector( + 
MyUserCreator("Ash"), + MyUserCreator("Steve"), + MyUserCreator("Josh") + ) + val users = userRepo.insertAllReturning(newUsers) + assert(userRepo.count == 6L) + assert(users.size == 3) + assert(users.last.firstName == newUsers.last.firstName) + + test("insert invalid EntityCreator"): + intercept[SqlException]: + connect(xa()): + val invalidUser = MyUserCreator(null) + userRepo.insert(invalidUser) + + test("insertAll EntityCreator"): + connect(xa()): + val newUsers = Vector( + MyUserCreator("Ash"), + MyUserCreator("Steve"), + MyUserCreator("Josh") + ) + userRepo.insertAll(newUsers) + assert(userRepo.count == 6L) + assert( + userRepo.findAll.map(_.firstName).contains(newUsers.last.firstName) + ) + + test("custom insert EntityCreator"): + connect(xa()): + val u = MyUserCreator("Ash") + val update = + sql"insert into $user ${user.insertColumns} values ($u)".update + assertNoDiff( + update.frag.sqlString, + "insert into my_user (first_name) values (?)" + ) + val rowsInserted = update.run() + assert(rowsInserted == 1) + assert(userRepo.count == 4L) + assert(userRepo.findAll.exists(_.firstName == "Ash")) + + test("custom update EntityCreator"): + connect(xa()): + val u = userRepo.findAll.head + val newName = "Ash" + val update = + sql"update $user set ${user.firstName} = $newName where ${user.id} = ${u.id}".update + assertNoDiff( + update.frag.sqlString, + "update my_user set first_name = ? where id = ?" 
+ ) + val rowsUpdated = update.run() + assert(rowsUpdated == 1) + assert(userRepo.findAll.exists(_.firstName == "Ash")) + + test(".returning iterator"): + assume(dbType != MySqlDbType) + assume(dbType != SqliteDbType) + connect(xa()): + Using.Manager(implicit use => + val it = + if dbType == H2DbType then + sql"INSERT INTO $user ${user.insertColumns} VALUES ('Bob')" + .returningKeys[Long](user.id) + .iterator() + else + sql"INSERT INTO $user ${user.insertColumns} VALUES ('Bob') RETURNING ${user.id}" + .returning[Long] + .iterator() + assert(it.size == 1) + ) + +end entityCreatorTests diff --git a/magnum/src/test/scala/shared/ImmutableRepoTests.scala b/magnum/src/test/scala/shared/ImmutableRepoTests.scala new file mode 100644 index 0000000..fbe4022 --- /dev/null +++ b/magnum/src/test/scala/shared/ImmutableRepoTests.scala @@ -0,0 +1,149 @@ +package shared + +import com.augustnagro.magnum.* +import munit.FunSuite + +import java.sql.Connection +import java.time.{OffsetDateTime, ZoneOffset} +import scala.util.Using + +def immutableRepoTests(suite: FunSuite, dbType: DbType, xa: () => Transactor)( + using + munit.Location, + DbCodec[OffsetDateTime] +): Unit = + import suite.* + + @Table(dbType, SqlNameMapper.CamelToSnakeCase) + case class Car( + model: String, + @Id id: Long, + topSpeed: Int, + @SqlName("vin") vinNumber: Option[Int], + color: Color, + created: OffsetDateTime + ) derives DbCodec + + val carRepo = ImmutableRepo[Car, Long] + val car = TableInfo[Car, Car, Long] + + val allCars = Vector( + Car( + model = "McLaren Senna", + id = 1L, + topSpeed = 208, + vinNumber = Some(123), + color = Color.Red, + created = OffsetDateTime.parse("2024-11-24T22:17:30.000000000Z") + ), + Car( + model = "Ferrari F8 Tributo", + id = 2L, + topSpeed = 212, + vinNumber = Some(124), + color = Color.Green, + created = OffsetDateTime.parse("2024-11-24T22:17:31.000000000Z") + ), + Car( + model = "Aston Martin Superleggera", + id = 3L, + topSpeed = 211, + vinNumber = None, + color = 
Color.Blue, + created = OffsetDateTime.parse("2024-11-24T22:17:32.000000000Z") + ) + ) + + test("count"): + connect(xa()): + assert(carRepo.count == 3L) + + test("existsById"): + connect(xa()): + assert(carRepo.existsById(3L)) + assert(!carRepo.existsById(4L)) + + test("findAll"): + val cars = connect(xa()): + carRepo.findAll + assert(cars == allCars) + + test("findById"): + connect(xa()): + assert(carRepo.findById(3L).get == allCars.last) + assert(carRepo.findById(4L) == None) + + test("findAllByIds"): + assume(dbType != ClickhouseDbType) + assume(dbType != MySqlDbType) + assume(dbType != OracleDbType) + assume(dbType != SqliteDbType) + connect(xa()): + val ids = carRepo.findAllById(Vector(1L, 3L)).map(_.id) + assert(ids == Vector(1L, 3L)) + + test("serializable transaction"): + transact(xa().copy(connectionConfig = withSerializable)): + assert(carRepo.count == 3L) + + def withSerializable(con: Connection): Unit = + con.setTransactionIsolation(Connection.TRANSACTION_SERIALIZABLE) + + test("select query"): + connect(xa()): + val minSpeed: Int = 210 + val query = + sql"select ${car.all} from $car where ${car.topSpeed} > $minSpeed" + .query[Car] + assertNoDiff( + query.frag.sqlString, + "select model, id, top_speed, vin, color, created from car where top_speed > ?" + ) + assert(query.frag.params == Vector(minSpeed)) + assert(query.run() == allCars.tail) + + test("select query with aliasing"): + connect(xa()): + val minSpeed = 210 + val cAlias = car.alias("c") + val query = + sql"select ${cAlias.all} from $cAlias where ${cAlias.topSpeed} > $minSpeed" + .query[Car] + assertNoDiff( + query.frag.sqlString, + "select c.model, c.id, c.top_speed, c.vin, c.color, c.created from car c where c.top_speed > ?" 
+ ) + assert(query.frag.params == Vector(minSpeed)) + assert(query.run() == allCars.tail) + + test("select via option"): + connect(xa()): + val vin = Some(124) + val cars = + sql"select * from car where vin = $vin" + .query[Car] + .run() + assert(cars == allCars.filter(_.vinNumber == vin)) + + test("tuple select"): + connect(xa()): + val tuples = sql"select model, color from car where id = 2" + .query[(String, Color)] + .run() + assert(tuples == Vector(allCars(1).model -> allCars(1).color)) + + test("reads null int as None and not Some(0)"): + connect(xa()): + assert(carRepo.findById(3L).get.vinNumber == None) + + test("created timestamps should match"): + connect(xa()): + assert(carRepo.findAll.map(_.created) == allCars.map(_.created)) + + test(".query iterator"): + connect(xa()): + Using.Manager(implicit use => + val it = sql"SELECT * FROM car".query[Car].iterator() + assert(it.map(_.id).size == 3) + ) +end immutableRepoTests diff --git a/magnum/src/test/scala/shared/NoIdTests.scala b/magnum/src/test/scala/shared/NoIdTests.scala new file mode 100644 index 0000000..2e30cf9 --- /dev/null +++ b/magnum/src/test/scala/shared/NoIdTests.scala @@ -0,0 +1,28 @@ +package shared + +import com.augustnagro.magnum.* +import munit.FunSuite + +import java.time.OffsetDateTime + +def noIdTests(suite: FunSuite, dbType: DbType, xa: () => Transactor)(using + munit.Location, + DbCodec[OffsetDateTime] +): Unit = + import suite.* + + @Table(dbType, SqlNameMapper.CamelToSnakeCase) + case class NoId( + createdAt: OffsetDateTime, + userName: String, + userAction: String + ) derives DbCodec + + val noIdRepo = Repo[NoId, NoId, Null]() + + test("insert NoId entities"): + connect(xa()): + val entity = NoId(OffsetDateTime.now, "Dan", "Fishing") + noIdRepo.insert(entity) + assert(noIdRepo.findAll.exists(_.userName == "Dan")) +end noIdTests diff --git a/magnum/src/test/scala/shared/OptionalProductTests.scala b/magnum/src/test/scala/shared/OptionalProductTests.scala new file mode 100644 index 
0000000..9f55b76 --- /dev/null +++ b/magnum/src/test/scala/shared/OptionalProductTests.scala @@ -0,0 +1,35 @@ +package shared + +import com.augustnagro.magnum.* +import munit.{FunSuite, Location} + +import java.time.OffsetDateTime + +def optionalProductTests( + suite: FunSuite, + dbType: DbType, + xa: () => Transactor +)(using Location, DbCodec[BigDecimal], DbCodec[OffsetDateTime]): Unit = + import suite.* + + @Table(dbType, SqlNameMapper.CamelToSnakeCase) + case class Car( + model: String, + @Id id: Long, + topSpeed: Int, + @SqlName("vin") vinNumber: Option[Int], + color: Color, + created: OffsetDateTime + ) derives DbCodec + + @Table(dbType, SqlNameMapper.CamelToSnakeCase) + case class BigDec(id: Int, myBigDec: Option[BigDecimal]) derives DbCodec + + test("left join with optional product type"): + assume(dbType != ClickhouseDbType) + connect(xa()): + val res = sql"select * from car c left join big_dec bd on bd.id = c.id" + .query[(Car, Option[BigDec])] + .run() + assert(res.exists((_, bigDec) => bigDec.isEmpty)) +end optionalProductTests diff --git a/magnum/src/test/scala/shared/RepoTests.scala b/magnum/src/test/scala/shared/RepoTests.scala new file mode 100644 index 0000000..0079997 --- /dev/null +++ b/magnum/src/test/scala/shared/RepoTests.scala @@ -0,0 +1,389 @@ +package shared + +import com.augustnagro.magnum.* +import munit.FunSuite + +import java.time.OffsetDateTime +import java.util.UUID + +def repoTests(suite: FunSuite, dbType: DbType, xa: () => Transactor)(using + munit.Location, + DbCodec[UUID], + DbCodec[Boolean], + DbCodec[OffsetDateTime] +): Unit = + import suite.* + + @Table(dbType, SqlNameMapper.CamelToSnakeCase) + case class Person( + id: Long, + firstName: Option[String], + lastName: String, + isAdmin: Boolean, + created: OffsetDateTime, + socialId: Option[UUID] + ) derives DbCodec + + val personRepo = Repo[Person, Person, Long] + val person = TableInfo[Person, Person, Long] + + test("delete"): + connect(xa()): + val p = 
personRepo.findById(1L).get + personRepo.delete(p) + assert(personRepo.findById(1L) == None) + + test("delete invalid"): + connect(xa()): + personRepo.delete( + Person(999L, None, "", false, OffsetDateTime.now, None) + ) + assert(8L == personRepo.count) + + test("deleteById"): + connect(xa()): + personRepo.deleteById(1L) + personRepo.deleteById(2L) + personRepo.deleteById(1L) + assert(personRepo.findAll.size == 6) + + test("deleteAll"): + connect(xa()): + val p1 = personRepo.findById(1L).get + val p2 = p1.copy(id = 2L) + val p3 = p1.copy(id = 999L) + val expectedRowsUpdate = dbType match + case ClickhouseDbType => 3 + case _ => 2 + val res = personRepo.deleteAll(Vector(p1, p2, p3)) + assert(res == BatchUpdateResult.Success(expectedRowsUpdate)) + assert(6L == personRepo.count) + + test("deleteAllById"): + connect(xa()): + val expectedRowsUpdate = dbType match + case ClickhouseDbType => 3 + case _ => 2 + val res = personRepo.deleteAllById(Vector(1L, 2L, 1L)) + assert(res == BatchUpdateResult.Success(expectedRowsUpdate)) + assert(6L == personRepo.count) + + test("truncate"): + connect(xa()): + personRepo.truncate() + assert(personRepo.count == 0L) + + test("insert"): + connect(xa()): + personRepo.insert( + Person( + id = 9L, + firstName = Some("John"), + lastName = "Smith", + isAdmin = false, + socialId = Some(UUID.randomUUID), + created = OffsetDateTime.now + ) + ) + personRepo.insert( + Person( + id = 10L, + firstName = None, + lastName = "Prince", + isAdmin = true, + socialId = None, + created = OffsetDateTime.now + ) + ) + assert(personRepo.count == 10L) + assert(personRepo.findAll.map(_.lastName).contains("Smith")) + + test("insertReturning"): + assume(dbType != MySqlDbType) + connect(xa()): + val person = personRepo.insertReturning( + Person( + id = 9L, + firstName = Some("John"), + lastName = "Smith", + isAdmin = false, + socialId = None, + created = OffsetDateTime.now + ) + ) + assert(person.lastName == "Smith") + + test("insertAllReturning"): + assume(dbType 
!= MySqlDbType) + assume(dbType != SqliteDbType) + connect(xa()): + val newPc = Vector( + Person( + id = 9L, + firstName = Some("Chandler"), + lastName = "Johnsored", + isAdmin = true, + socialId = Some(UUID.randomUUID()), + created = OffsetDateTime.now + ), + Person( + id = 10L, + firstName = None, + lastName = "Odysseus", + isAdmin = false, + socialId = None, + created = OffsetDateTime.now + ), + Person( + id = 11L, + firstName = Some("Jorge"), + lastName = "Masvidal", + isAdmin = true, + socialId = None, + created = OffsetDateTime.now + ) + ) + val people = personRepo.insertAllReturning(newPc) + assert(personRepo.count == 11L) + assert(people.size == 3) + assert(people.last.lastName == newPc.last.lastName) + + test("insert invalid"): + intercept[SqlException]: + connect(xa()): + val invalidP = + Person(9L, None, null, false, OffsetDateTime.now, None) + personRepo.insert(invalidP) + + test("update"): + assume(dbType != ClickhouseDbType) + connect(xa()): + val p = personRepo.findById(1L).get + val updated = p.copy(firstName = None, isAdmin = false) + personRepo.update(updated) + assert(personRepo.findById(1L).get == updated) + + test("update invalid"): + assume(dbType != ClickhouseDbType) + intercept[SqlException]: + connect(xa()): + val p = personRepo.findById(1L).get + val updated = p.copy(lastName = null) + personRepo.update(updated) + + test("insertAll"): + connect(xa()): + val newPeople = Vector( + Person( + id = 9L, + firstName = Some("Chandler"), + lastName = "Johnsored", + isAdmin = true, + socialId = Some(UUID.randomUUID()), + created = OffsetDateTime.now + ), + Person( + id = 10L, + firstName = None, + lastName = "Odysseus", + isAdmin = false, + socialId = None, + created = OffsetDateTime.now + ), + Person( + id = 11L, + firstName = Some("Jorge"), + lastName = "Masvidal", + isAdmin = true, + socialId = None, + created = OffsetDateTime.now + ) + ) + personRepo.insertAll(newPeople) + assert(personRepo.count == 11L) + assert( + 
personRepo.findAll.map(_.lastName).contains(newPeople.last.lastName) + ) + + test("updateAll"): + assume(dbType != ClickhouseDbType) + connect(xa()): + val newPeople = Vector( + personRepo.findById(1L).get.copy(lastName = "Peterson"), + personRepo.findById(2L).get.copy(lastName = "Moreno") + ) + val res = personRepo.updateAll(newPeople) + assert(res == BatchUpdateResult.Success(2)) + assert(personRepo.findById(1L).get == newPeople(0)) + assert(personRepo.findById(2L).get == newPeople(1)) + + test("transact"): + assume(dbType != ClickhouseDbType) + val count = transact(xa()): + val p = Person( + id = 9L, + firstName = Some("Chandler"), + lastName = "Brown", + isAdmin = false, + created = OffsetDateTime.now, + socialId = None + ) + personRepo.insert(p) + personRepo.count + assert(count == 9L) + + test("transact failed"): + assume(dbType != ClickhouseDbType) + val dataSource = xa() + val p = Person( + id = 9L, + firstName = Some("Chandler"), + lastName = "Brown", + isAdmin = false, + created = OffsetDateTime.now, + socialId = None + ) + try + transact(dataSource): + personRepo.insert(p) + throw RuntimeException() + fail("should not reach") + catch + case _: Exception => + transact(dataSource): + assert(personRepo.count == 8L) + + test("custom insert"): + connect(xa()): + val p = Person( + id = 9L, + firstName = Some("Chandler"), + lastName = "Brown", + isAdmin = false, + socialId = None, + created = OffsetDateTime.now + ) + val update = + sql"insert into $person ${person.insertColumns} values ($p)".update + assertNoDiff( + update.frag.sqlString, + "insert into person (id, first_name, last_name, is_admin, created, social_id) values (?, ?, ?, ?, ?, ?)" + ) + val rowsInserted = update.run() + assert(rowsInserted == 1) + assert(personRepo.count == 9L) + assert( + personRepo.findAll.exists(fetched => + fetched.firstName == p.firstName && + fetched.lastName == p.lastName && + fetched.isAdmin == p.isAdmin + ) + ) + + test("custom update"): + connect(xa()): + val p = Person( 
+ id = 9L, + firstName = Some("Chandler"), + lastName = "Brown", + isAdmin = false, + socialId = Some(UUID.randomUUID()), + created = OffsetDateTime.now + ) + personRepo.insert(p) + val newIsAdmin = true + val update = + sql"update $person set ${person.isAdmin} = $newIsAdmin where ${person.id} = ${p.id}".update + assertNoDiff( + update.frag.sqlString, + "update person set is_admin = ? where id = ?" + ) + val rowsUpdated = update.run() + assert(rowsUpdated == 1) + assert(personRepo.findById(p.id).get.isAdmin == true) + + test("custom returning a single column"): + assume(dbType != ClickhouseDbType) + assume(dbType != MySqlDbType) + assume(dbType != SqliteDbType) + connect(xa()): + val personId = + if dbType == H2DbType then + sql"""insert into person (id, first_name, last_name, created, is_admin) + values (9, 'Arton', 'Senna', now(), true) + """ + .returningKeys[Long]("id") + .run() + .head + else if dbType == OracleDbType then + sql"""insert into person (id, first_name, last_name, created, is_admin) + values (9, 'Arton', 'Senna', current_timestamp, 'Y')""" + .returningKeys[Long]("id") + .run() + .head + else + sql"""insert into person (id, first_name, last_name, created, is_admin) + values (9, 'Arton', 'Senna', now(), 'Y') RETURNING id + """.returning[Long].run().head + assert(personRepo.findById(personId).get.lastName == "Senna") + + test("custom returning multiple columns"): + assume(dbType != ClickhouseDbType) + assume(dbType != MySqlDbType) + assume(dbType != SqliteDbType) + assume(dbType != OracleDbType) + connect(xa()): + val cols = + if dbType == H2DbType then + sql"""insert into person (id, first_name, last_name, created, is_admin) values + (9, 'Arton', 'Senna', now(), true), + (10, 'Demo', 'User', now(), false) + """ + .returningKeys[(Long, OffsetDateTime)]( + person.id, + person.created + ) + .run() + else + sql"""insert into person (id, first_name, last_name, created, is_admin) values + (9, 'Arton', 'Senna', now(), true), + (10, 'Demo', 'User', now(), 
false) + RETURNING id, created + """.returning[(Long, OffsetDateTime)].run() + val newLastNames = + cols.map((id, _) => personRepo.findById(id).get.lastName) + assert(newLastNames == Vector("Senna", "User")) + + test("custom returning with no rows updated"): + assume(dbType != ClickhouseDbType) + assume(dbType != MySqlDbType) + assume(dbType != SqliteDbType) + connect(xa()): + val personIds = + if dbType == H2DbType || dbType == OracleDbType then + sql"update person set first_name = 'xxx' where last_name = 'Not Here'" + .returningKeys[Long](ColumnNames("id", IArray(person.id))) + .run() + else + sql"update person set first_name = 'xxx' where last_name = 'Not Here' returning id" + .returning[Long] + .run() + assert(personIds.isEmpty) + + test("returning non primary key column"): + assume(dbType != ClickhouseDbType) + assume(dbType != MySqlDbType) + assume(dbType != SqliteDbType) + connect(xa()): + val personFirstNames = + if dbType == H2DbType || dbType == OracleDbType then + sql"update person set last_name = 'xxx'" + .returningKeys[String](person.firstName) + .run() + else + sql"update person set last_name = 'xxx' returning first_name" + .returning[String] + .run() + + assert(personFirstNames.nonEmpty) +end repoTests diff --git a/magnum/src/test/scala/shared/SharedTests.scala b/magnum/src/test/scala/shared/SharedTests.scala new file mode 100644 index 0000000..dce13c9 --- /dev/null +++ b/magnum/src/test/scala/shared/SharedTests.scala @@ -0,0 +1,24 @@ +package shared + +import com.augustnagro.magnum.{DbCodec, DbType, Transactor} +import munit.{FunSuite, Location} + +import java.time.OffsetDateTime +import java.util.UUID + +def sharedTests(suite: FunSuite, dbType: DbType, xa: () => Transactor)(using + Location, + DbCodec[UUID], + DbCodec[Boolean], + DbCodec[OffsetDateTime], + DbCodec[BigDecimal] +): Unit = + immutableRepoTests(suite, dbType, xa) + repoTests(suite, dbType, xa) + entityCreatorTests(suite, dbType, xa) + specTests(suite, dbType, xa) + sqlNameTests(suite, 
dbType, xa) + noIdTests(suite, dbType, xa) + embeddedFragTests(suite, dbType, xa) + bigDecTests(suite, dbType, xa) + optionalProductTests(suite, dbType, xa) diff --git a/magnum/src/test/scala/shared/SpecTests.scala b/magnum/src/test/scala/shared/SpecTests.scala new file mode 100644 index 0000000..4ee7450 --- /dev/null +++ b/magnum/src/test/scala/shared/SpecTests.scala @@ -0,0 +1,158 @@ +package shared + +import com.augustnagro.magnum.* +import munit.FunSuite + +import java.time.{OffsetDateTime, ZoneOffset} + +opaque type CarId = Long +object CarId: + def apply(value: Long): CarId = value + extension (opaque: CarId) def value: Long = opaque + given DbCodec[CarId] = + DbCodec.LongCodec.biMap(CarId.apply, _.value) + +def specTests(suite: FunSuite, dbType: DbType, xa: () => Transactor)(using + munit.Location, + DbCodec[OffsetDateTime] +): Unit = + import suite.* + + @Table(dbType, SqlNameMapper.CamelToSnakeCase) + case class Car( + model: String, + @Id id: Long, + topSpeed: Int, + @SqlName("vin") vinNumber: Option[Int], + color: Color, + created: OffsetDateTime + ) derives DbCodec + + val carRepo = ImmutableRepo[Car, Long] + val car = TableInfo[Car, Car, Long] + + val allCars = Vector( + Car( + model = "McLaren Senna", + id = 1L, + topSpeed = 208, + vinNumber = Some(123), + color = Color.Red, + created = OffsetDateTime.parse("2024-11-24T22:17:30.000000000Z") + ), + Car( + model = "Ferrari F8 Tributo", + id = 2L, + topSpeed = 212, + vinNumber = Some(124), + color = Color.Green, + created = OffsetDateTime.parse("2024-11-24T22:17:31.000000000Z") + ), + Car( + model = "Aston Martin Superleggera", + id = 3L, + topSpeed = 211, + vinNumber = None, + color = Color.Blue, + created = OffsetDateTime.parse("2024-11-24T22:17:32.000000000Z") + ) + ) + + test("select all"): + transact(xa()): + val spec = Spec[Car] + assert(carRepo.findAll(spec) == allCars) + + test("empty predicate"): + transact(xa()): + val spec = Spec[Car].where(sql"") + assert(carRepo.findAll(spec) == allCars) + + 
test("predicate having param at end"): + transact(xa()): + val id = CarId(2L) + val spec = Spec[Car].where(sql"$id < id") + assert(carRepo.findAll(spec) == Vector(allCars.last)) + + test("AND in where predicate"): + transact(xa()): + val color = Color.Red + val model = "MCLAREN SENNA" + val spec = + Spec[Car].where(sql"color = $color AND $model = upper(model)") + assert(carRepo.findAll(spec) == Vector(allCars.head)) + + test("multiple where parameters"): + transact(xa()): + val color = Color.Red + val model = "MCLAREN SENNA" + val spec = Spec[Car] + .where(sql"color = $color") + .where(sql"$model = upper(model)") + assert(carRepo.findAll(spec) == Vector(allCars.head)) + + test("orderBy"): + transact(xa()): + val spec = Spec[Car].orderBy("top_speed") + assert(carRepo.findAll(spec) == allCars.sortBy(_.topSpeed)) + + test("orderBy null with sort order and null order"): + transact(xa()): + val spec = Spec[Car] + .orderBy("vin", SortOrder.Desc, NullOrder.First) + assert(carRepo.findAll(spec) == allCars.reverse) + + test("limit"): + transact(xa()): + val spec = Spec[Car].limit(2) + assert(carRepo.findAll(spec).size == 2) + + test("offset"): + transact(xa()): + val spec = Spec[Car].offset(1) + assert(carRepo.findAll(spec) == allCars.tail) + + test("seek"): + transact(xa()): + val spec = Spec[Car].seek("id", SeekDir.Gt, 2, SortOrder.Asc) + assert(carRepo.findAll(spec).size == 1) + + test("seek multiple"): + transact(xa()): + val spec = Spec[Car] + .seek("id", SeekDir.Lt, 3, SortOrder.Asc) + .seek("top_speed", SeekDir.Gt, 210, SortOrder.Asc) + assert(carRepo.findAll(spec) == Vector(allCars(1))) + + test("everything"): + transact(xa()): + val idOpt = Option.empty[CarId] + val speed = 210 + val spec = Spec[Car] + .where(idOpt.map(id => sql"id = $id").getOrElse(sql"")) + .where(sql"top_speed > $speed") + .orderBy("model", SortOrder.Desc) + .limit(1) + .seek("vin", SeekDir.Gt, 1, SortOrder.Asc, NullOrder.Last) + assert(carRepo.findAll(spec) == Vector(allCars(1))) + + 
test("prefix"): + transact(xa()): + val c = car.alias("c") + val color = Color.Red + val spec = Spec[Car] + .prefix(sql"SELECT ${c.all} FROM $c") + .where(sql"${c.color} = $color") + assert(carRepo.findAll(spec) == Vector(allCars.head)) + + test("prefix with embedded sql"): + transact(xa()): + val c = car.alias("c") + val color = Color.Red + val selectPart = sql"SELECT ${c.all}" + val fromPart = sql"FROM $c" + val spec = Spec[Car] + .prefix(sql"$selectPart $fromPart") + .where(sql"${c.color} = $color") + assert(carRepo.findAll(spec) == Vector(allCars.head)) +end specTests diff --git a/magnum/src/test/scala/shared/SqlNameTests.scala b/magnum/src/test/scala/shared/SqlNameTests.scala new file mode 100644 index 0000000..0eb4fbc --- /dev/null +++ b/magnum/src/test/scala/shared/SqlNameTests.scala @@ -0,0 +1,29 @@ +package shared + +import com.augustnagro.magnum.* +import munit.FunSuite + +import java.time.OffsetDateTime + +def sqlNameTests(suite: FunSuite, dbType: DbType, xa: () => Transactor)(using + munit.Location +): Unit = + import suite.* + + @SqlName("car") + @Table(dbType, SqlNameMapper.CamelToSnakeCase) + case class CustomCar( + model: String, + @Id id: Long, + topSpeed: Int, + @SqlName("vin") vinNumber: Option[Int], + color: Color, + created: OffsetDateTime + ) derives DbCodec + + val customCarRepo = Repo[CustomCar, CustomCar, Long] + + test("count with manual table name"): + val count = connect(xa())(customCarRepo.count) + assert(count == 3L) +end sqlNameTests