From c3fefe4f9f83b8a08c45f3c960bc16145cdff24e Mon Sep 17 00:00:00 2001 From: Erlend Hamnaberg Date: Tue, 11 Jul 2023 15:22:03 +0200 Subject: [PATCH 1/2] We can use offline zeta sql to verify queries * Analyze the query before rendering it to extract referenced tables * Bump to jdk11 --- .github/workflows/ci.yml | 46 ++-- build.sbt | 20 +- .../main/scala/no/nrk/bigquery/ZetaSql.scala | 214 ++++++++++++++++++ .../test/scala/no/nrk/bigquery/ZetaTest.scala | 88 +++++++ 4 files changed, 344 insertions(+), 24 deletions(-) create mode 100644 zetasql/src/main/scala/no/nrk/bigquery/ZetaSql.scala create mode 100644 zetasql/src/test/scala/no/nrk/bigquery/ZetaTest.scala diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 2d7e9f60..bc223507 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -24,7 +24,7 @@ jobs: matrix: os: [ubuntu-latest] scala: [2.13.11, 2.12.18, 3.3.0] - java: [temurin@8] + java: [temurin@11] project: [rootJVM] runs-on: ${{ matrix.os }} steps: @@ -33,21 +33,21 @@ jobs: with: fetch-depth: 0 - - name: Download Java (temurin@8) - id: download-java-temurin-8 - if: matrix.java == 'temurin@8' + - name: Download Java (temurin@11) + id: download-java-temurin-11 + if: matrix.java == 'temurin@11' uses: typelevel/download-java@v2 with: distribution: temurin - java-version: 8 + java-version: 11 - - name: Setup Java (temurin@8) - if: matrix.java == 'temurin@8' + - name: Setup Java (temurin@11) + if: matrix.java == 'temurin@11' uses: actions/setup-java@v3 with: distribution: jdkfile - java-version: 8 - jdkFile: ${{ steps.download-java-temurin-8.outputs.jdkFile }} + java-version: 11 + jdkFile: ${{ steps.download-java-temurin-11.outputs.jdkFile }} - name: Cache sbt uses: actions/cache@v3 @@ -65,7 +65,7 @@ jobs: run: sbt githubWorkflowCheck - name: Check formatting - if: matrix.java == 'temurin@8' && matrix.os == 'ubuntu-latest' + if: matrix.java == 'temurin@11' && matrix.os == 'ubuntu-latest' run: sbt 'project ${{ matrix.project }}' '++ ${{ matrix.scala }}' scalafmtCheckAll 'project /' scalafmtSbtCheck - name: Test @@ -75,23 +75,23 @@ jobs: run: sbt 'project ${{ matrix.project }}' '++ ${{ matrix.scala }}' test - name: Check binary compatibility - if: matrix.java == 'temurin@8' && matrix.os == 'ubuntu-latest' + if: matrix.java == 'temurin@11' && matrix.os == 'ubuntu-latest' env: MYGET_USERNAME: ${{ secrets.PLATTFORM_MYGET_ENTERPRISE_READ_ID }} MYGET_PASSWORD: ${{ secrets.PLATTFORM_MYGET_ENTERPRISE_READ_SECRET }} run: sbt 'project ${{ matrix.project }}' '++ ${{ matrix.scala }}' mimaReportBinaryIssues - name: Generate API documentation - if: matrix.java == 'temurin@8' && matrix.os == 'ubuntu-latest' + if: matrix.java == 'temurin@11' && matrix.os == 'ubuntu-latest' run: sbt 'project ${{ matrix.project }}' '++ ${{ matrix.scala }}' doc - name: Make target directories if: github.event_name != 'pull_request' && (startsWith(github.ref, 'refs/tags/v')) - run: mkdir -p target .js/target site/target prometheus/.jvm/target testing/.jvm/target .jvm/target .native/target core/.jvm/target project/target + run: mkdir -p target .js/target site/target prometheus/.jvm/target testing/.jvm/target .jvm/target .native/target core/.jvm/target zetasql/.jvm/target project/target - name: Compress target directories if: github.event_name != 'pull_request' && (startsWith(github.ref, 'refs/tags/v')) - run: tar cf targets.tar target .js/target site/target prometheus/.jvm/target testing/.jvm/target .jvm/target .native/target core/.jvm/target project/target + run: tar cf targets.tar target .js/target site/target prometheus/.jvm/target testing/.jvm/target .jvm/target .native/target core/.jvm/target zetasql/.jvm/target project/target - name: Upload target directories if: github.event_name != 'pull_request' && (startsWith(github.ref, 'refs/tags/v')) @@ -107,7 +107,7 @@ jobs: strategy: matrix: os: [ubuntu-latest] - java: [temurin@8] + java: [temurin@11] runs-on: ${{ matrix.os }} steps: - name: Checkout current branch (full) @@ -115,21 +115,21 @@ jobs: with: fetch-depth: 0 - - name: Download Java (temurin@8) - id: download-java-temurin-8 - if: matrix.java == 'temurin@8' + - name: Download Java (temurin@11) + id: download-java-temurin-11 + if: matrix.java == 'temurin@11' uses: typelevel/download-java@v2 with: distribution: temurin - java-version: 8 + java-version: 11 - - name: Setup Java (temurin@8) - if: matrix.java == 'temurin@8' + - name: Setup Java (temurin@11) + if: matrix.java == 'temurin@11' uses: actions/setup-java@v3 with: distribution: jdkfile - java-version: 8 - jdkFile: ${{ steps.download-java-temurin-8.outputs.jdkFile }} + java-version: 11 + jdkFile: ${{ steps.download-java-temurin-11.outputs.jdkFile }} - name: Cache sbt uses: actions/cache@v3 diff --git a/build.sbt b/build.sbt index 86bd7fbe..df0709d9 100644 --- a/build.sbt +++ b/build.sbt @@ -59,6 +59,7 @@ ThisBuild / tlVersionIntroduced := Map( "3" -> "0.1.1", "2.13" -> "0.1.0" ) +ThisBuild / githubWorkflowJavaVersions := Seq(JavaSpec.temurin("11")) val commonSettings = Seq( resolvers += "MyGet - datahub".at(s"https://nrk.myget.org/F/datahub/maven/"), @@ -94,7 +95,7 @@ val commonSettings = Seq( lazy val root = tlCrossRootProject .settings(name := "bigquery-scala") - .aggregate(core, testing, prometheus, docs) + .aggregate(core, testing, prometheus, zetasql, docs) .disablePlugins(TypelevelCiSigningPlugin, Sonatype, SbtGpg) lazy val core = crossProject(JVMPlatform) @@ -156,6 +157,23 @@ lazy val prometheus = crossProject(JVMPlatform) ) .disablePlugins(TypelevelCiSigningPlugin, Sonatype, SbtGpg) +lazy val zetasql = crossProject(JVMPlatform) + .withoutSuffixFor(JVMPlatform) + .crossType(CrossType.Pure) + .in(file("zetasql")) + .settings(commonSettings) + .dependsOn(core) + .settings( + name := "bigquery-zetasql", + tlMimaPreviousVersions := Set.empty, + libraryDependencies ++= Seq( + "com.google.zetasql.toolkit" % "zetasql-toolkit-bigquery" % "0.4.0", + "org.scalameta" %% "munit" % "0.7.29", + "org.typelevel" %% "munit-cats-effect-3" % "1.0.7" + ) + ) + .disablePlugins(TypelevelCiSigningPlugin, Sonatype, SbtGpg) + lazy val testing = crossProject(JVMPlatform) .withoutSuffixFor(JVMPlatform) .crossType(CrossType.Pure) diff --git a/zetasql/src/main/scala/no/nrk/bigquery/ZetaSql.scala b/zetasql/src/main/scala/no/nrk/bigquery/ZetaSql.scala new file mode 100644 index 00000000..c9906681 --- /dev/null +++ b/zetasql/src/main/scala/no/nrk/bigquery/ZetaSql.scala @@ -0,0 +1,214 @@ +package no.nrk.bigquery + +import cats.syntax.all._ +import cats.effect.IO +import no.nrk.bigquery.syntax._ +import com.google.zetasql.{ + AnalyzerOptions, + ParseLocationRange, + Parser, + SimpleColumn, + SimpleTable, + SqlException, + StructType, + Type, + TypeFactory +} +import com.google.zetasql.ZetaSQLType.TypeKind +import com.google.zetasql.resolvedast.ResolvedCreateStatementEnums.{CreateMode, CreateScope} +import com.google.zetasql.resolvedast.ResolvedNodes +import com.google.zetasql.toolkit.catalog.basic.BasicCatalogWrapper +import com.google.zetasql.toolkit.options.BigQueryLanguageOptions +import com.google.zetasql.parser.{ASTNodes, ParseTreeVisitor} +import com.google.zetasql.toolkit.{AnalysisException, AnalyzedStatement, ZetaSQLToolkitAnalyzer} + +import scala.collection.mutable.ListBuffer +import scala.jdk.CollectionConverters._ +import scala.jdk.OptionConverters._ + +object ZetaSql { + def parse(frag: BQSqlFrag): IO[Either[SqlException, BQSqlFrag]] = parseScript(frag).map(_.as(frag)) + + def parseScript(frag: BQSqlFrag): IO[Either[SqlException, ASTNodes.ASTScript]] = IO.interruptible { + val options = BigQueryLanguageOptions.get() + + try + Right(Parser.parseScript(frag.asString, options)) + catch { + case e: SqlException => Left(e) // only catch sql exception and let all others bubble up to IO + } + } + + def parseAndBuildAnalysableFragment( + query: String, + allTables: List[BQTableLike[Any]], + toFragment: BQTableLike[Any] => BQSqlFrag = _.unpartitioned.bqShow, + eqv: (BQTableId, BQTableId) => Boolean = _ == _): IO[BQSqlFrag] = { + + def evalFragments( + parsedTables: List[(BQTableId, ParseLocationRange)] + ): BQSqlFrag = { + val asString = query + val found = allTables + .flatMap(table => + parsedTables.flatMap { case (id, range) => if (eqv(table.tableId, id)) List(table -> range) else Nil }) + .distinct + val (rest, aggregate) = found.foldLeft((asString, BQSqlFrag.Empty)) { case ((input, agg), (t, loc)) => + val frag = agg ++ BQSqlFrag.Frag(input.substring(0, loc.start() - 1)) ++ toFragment(t) + val rest = input.substring(loc.end()) + rest -> frag + } + + aggregate ++ BQSqlFrag.Frag(rest) + } + + parseScript(BQSqlFrag.Frag(query)) + .flatMap(IO.fromEither) + .flatMap { script => + val list = script.getStatementListNode.getStatementList + if (list.size() != 1) { + IO.raiseError(new IllegalArgumentException("Expects only one statement")) + } else + IO { + val buffer = new ListBuffer[(BQTableId, ParseLocationRange)] + list.asScala.headOption.foreach(_.accept(new ParseTreeVisitor { + override def visit(node: ASTNodes.ASTTablePathExpression): Unit = + node.getPathExpr.getNames.forEach(ident => + BQTableId + .fromString(ident.getIdString) + .toOption + .foreach(id => buffer += (id -> ident.getParseLocationRange))) + })) + buffer.toList + } + } + .map(evalFragments) + } + + def queryFields(frag: BQSqlFrag): IO[List[BQField]] = + analyzeFirst(frag).map { res => + val builder = List.newBuilder[BQField] + + res + .flatMap(_.getResolvedStatement.toScala.toRight(new AnalysisException("No analysis found"))) + .foreach(tree => + tree.accept(new ResolvedNodes.Visitor { + override def visit(node: ResolvedNodes.ResolvedQueryStmt): Unit = + node.getOutputColumnList.asScala.foreach(col => + builder += ZetaSql.fromColumnNameAndType(col.getColumn.getName, col.getColumn.getType)) + })) + builder.result() + } + + def analyzeFirst(frag: BQSqlFrag): IO[Either[AnalysisException, AnalyzedStatement]] = IO.interruptible { + val tables = frag.allReferencedTables + val catalog = toCatalog(tables: _*) + val rendered = frag.asString + + val options = BigQueryLanguageOptions.get() + val analyzerOptions = new AnalyzerOptions + analyzerOptions.setLanguageOptions(options) + analyzerOptions.setPreserveColumnAliases(true) + + val analyser = new ZetaSQLToolkitAnalyzer(analyzerOptions) + val analyzed = analyser.analyzeStatements(rendered, catalog) + + if (analyzed.hasNext) + Right(analyzed.next()) + else Left(new AnalysisException("Unable to find any analyzed statements")) + } + + def toCatalog(tables: BQTableLike[Any]*): BasicCatalogWrapper = { + val catalog = new BasicCatalogWrapper() + tables.foreach(table => + catalog.register(toSimpleTable(table), CreateMode.CREATE_IF_NOT_EXISTS, CreateScope.CREATE_DEFAULT_SCOPE)) + catalog + } + + def fromColumnNameAndType(name: String, typ: Type): BQField = { + val kind = typ.getKind match { + case TypeKind.TYPE_BOOL => BQField.Type.BOOL + case TypeKind.TYPE_DATE => BQField.Type.DATE + case TypeKind.TYPE_DATETIME => BQField.Type.DATETIME + case TypeKind.TYPE_JSON => BQField.Type.JSON + case TypeKind.TYPE_BYTES => BQField.Type.BYTES + case TypeKind.TYPE_STRING => BQField.Type.STRING + case TypeKind.TYPE_BIGNUMERIC => BQField.Type.BIGNUMERIC + case TypeKind.TYPE_INT64 => BQField.Type.INT64 + case TypeKind.TYPE_INT32 => BQField.Type.INT64 + case TypeKind.TYPE_FLOAT => BQField.Type.FLOAT64 + case TypeKind.TYPE_DOUBLE => BQField.Type.FLOAT64 + case TypeKind.TYPE_TIMESTAMP => BQField.Type.TIMESTAMP + case TypeKind.TYPE_TIME => BQField.Type.TIME + case TypeKind.TYPE_GEOGRAPHY => BQField.Type.GEOGRAPHY + case TypeKind.TYPE_INTERVAL => BQField.Type.INTERVAL + case _ => throw new IllegalArgumentException(s"$name with type ${typ.debugString()} is not supported ") + } + + if (typ.isArray) { + val elem = fromColumnNameAndType(name, typ.asArray().getElementType) + elem.copy(mode = BQField.Mode.REPEATED) + } else if (typ.isStruct) { + BQField.struct(name, BQField.Mode.NULLABLE)( + typ + .asStruct() + .getFieldList + .asScala + .map(subField => fromColumnNameAndType(subField.getName, subField.getType)) + .toList: _*) + } else BQField(name, kind, BQField.Mode.NULLABLE) + } + + def toSimpleTable(table: BQTableLike[Any]): SimpleTable = { + def toType(field: BQField): Type = { + val isArray = field.mode == BQField.Mode.REPEATED + + val elemType = field.tpe match { + case BQField.Type.BOOL => TypeFactory.createSimpleType(TypeKind.TYPE_BOOL) + case BQField.Type.INT64 => TypeFactory.createSimpleType(TypeKind.TYPE_INT64) + case BQField.Type.FLOAT64 => TypeFactory.createSimpleType(TypeKind.TYPE_FLOAT) + case BQField.Type.NUMERIC => TypeFactory.createSimpleType(TypeKind.TYPE_NUMERIC) + case BQField.Type.BIGNUMERIC => TypeFactory.createSimpleType(TypeKind.TYPE_BIGNUMERIC) + case BQField.Type.STRING => TypeFactory.createSimpleType(TypeKind.TYPE_STRING) + case BQField.Type.BYTES => TypeFactory.createSimpleType(TypeKind.TYPE_BYTES) + case BQField.Type.STRUCT => + TypeFactory.createStructType( + field.subFields + .map(sub => new StructType.StructField(sub.name, toType(sub))) + .asJavaCollection + ) + case BQField.Type.ARRAY => + TypeFactory.createArrayType(toType(field.subFields.head)) + case BQField.Type.TIMESTAMP => TypeFactory.createSimpleType(TypeKind.TYPE_TIMESTAMP) + case BQField.Type.DATE => TypeFactory.createSimpleType(TypeKind.TYPE_DATE) + case BQField.Type.TIME => TypeFactory.createSimpleType(TypeKind.TYPE_TIME) + case BQField.Type.DATETIME => TypeFactory.createSimpleType(TypeKind.TYPE_DATETIME) + case BQField.Type.GEOGRAPHY => TypeFactory.createSimpleType(TypeKind.TYPE_GEOGRAPHY) + case BQField.Type.JSON => TypeFactory.createSimpleType(TypeKind.TYPE_JSON) + case BQField.Type.INTERVAL => TypeFactory.createSimpleType(TypeKind.TYPE_INTERVAL) + } + if (isArray) TypeFactory.createArrayType(elemType) else elemType + } + + def toSimpleField(field: BQField) = + new SimpleColumn(table.tableId.tableName, field.name, toType(field), false, true) + val simple = table match { + case BQTableRef(tableId, _, _) => + new SimpleTable(tableId.tableName) + + case tbl: BQTableDef.Table[_] => + new SimpleTable( + tbl.tableId.tableName, + new java.util.ArrayList(tbl.schema.fields.map(toSimpleField).asJavaCollection) + ) + case view: BQTableDef.ViewLike[_] => + new SimpleTable( + view.tableId.tableName, + new java.util.ArrayList(view.schema.fields.map(toSimpleField).asJavaCollection) + ) + } + simple.setFullName(table.tableId.asString) + simple + } + +} diff --git a/zetasql/src/test/scala/no/nrk/bigquery/ZetaTest.scala b/zetasql/src/test/scala/no/nrk/bigquery/ZetaTest.scala new file mode 100644 index 00000000..f9e17b94 --- /dev/null +++ b/zetasql/src/test/scala/no/nrk/bigquery/ZetaTest.scala @@ -0,0 +1,88 @@ +package no.nrk.bigquery + +import cats.effect.IO +import no.nrk.bigquery.syntax._ +import com.google.zetasql.toolkit.AnalysisException + +import java.time.LocalDate + +class ZetaTest extends munit.CatsEffectSuite { + private val table = BQTableDef.Table( + BQTableId.unsafeOf(BQDataset.unsafeOf(ProjectId("com-example"), "example"), "test"), + BQSchema.of( + BQField("partitionDate", BQField.Type.DATE, BQField.Mode.REQUIRED), + BQField("a", BQField.Type.STRING, BQField.Mode.REQUIRED), + BQField("b", BQField.Type.INT64, BQField.Mode.REQUIRED), + BQField("c", BQField.Type.INT64, BQField.Mode.REQUIRED), + BQField("d", BQField.Type.INT64, BQField.Mode.REQUIRED) + ), + BQPartitionType.DatePartitioned(Ident("partitionDate")) + ) + + test("parses select 1") { + ZetaSql.analyzeFirst(bqsql"select 1").map(_.isRight).assertEquals(true) + } + + test("fails to parse select from foo") { + ZetaSql.analyzeFirst(bqsql"select from foo").flatMap(IO.fromEither).intercept[AnalysisException] + } + + test("subset select from example") { + val date = LocalDate.of(2023, 1, 1) + + val query = bqsql"select partitionDate, a, b, c from ${table.assertPartition(date)}" + + val expected = table.schema.fields.dropRight(1).map(_.recursivelyNullable.withoutDescription) + ZetaSql.queryFields(query).assertEquals(expected) + } + + test("all fields should be selected from example") { + val date = LocalDate.of(2023, 1, 1) + + val query = bqsql"select partitionDate, a, b, c, d from ${table.assertPartition(date)}" + + val expected = table.schema.fields.map(_.recursivelyNullable.withoutDescription) + ZetaSql.queryFields(query).assertEquals(expected) + } + + test("CTE selections") { + val query = + bqsql"""|with data as ( + | select partitionDate, a, b, c from ${table.unpartitioned} + |), + | grouped as ( + | select partitionDate, a, b, COUNTIF(c is null) as nullableCs from data + | group by 1, 2, 3 + | ) + |select * from grouped + |""".stripMargin + + val expected = + (table.schema.fields.dropRight(2) ++ List(BQField("nullableCs", BQField.Type.INT64, BQField.Mode.NULLABLE))) + .map(_.recursivelyNullable.withoutDescription) + + ZetaSql.queryFields(query).assertEquals(expected) + } + + test("parse then build analysis") { + val query = + """|with data as ( + | select partitionDate, a, b, c from `com-example.example.test` + |), + | grouped as ( + | select partitionDate, a, b, COUNTIF(c is null) as nullableCs from data + | group by 1, 2, 3 + | ) + |select * from grouped + |""".stripMargin + + val expected = + (table.schema.fields.dropRight(2) ++ List(BQField("nullableCs", BQField.Type.INT64, BQField.Mode.NULLABLE))) + .map(_.recursivelyNullable.withoutDescription) + + ZetaSql + .parseAndBuildAnalysableFragment(query, List(table)) + .flatMap(ZetaSql.queryFields) + .assertEquals(expected) + } +} From 7d880a9b42f62f283ed011b416e611d820f80d63 Mon Sep 17 00:00:00 2001 From: Erlend Hamnaberg Date: Fri, 28 Jul 2023 09:44:01 +0200 Subject: [PATCH 2/2] use tagless final --- .../main/scala/no/nrk/bigquery/ZetaSql.scala | 70 +++++++++++-------- .../test/scala/no/nrk/bigquery/ZetaTest.scala | 16 +++-- 2 files changed, 48 insertions(+), 38 deletions(-) diff --git a/zetasql/src/main/scala/no/nrk/bigquery/ZetaSql.scala b/zetasql/src/main/scala/no/nrk/bigquery/ZetaSql.scala index c9906681..0bc68a10 100644 --- a/zetasql/src/main/scala/no/nrk/bigquery/ZetaSql.scala +++ b/zetasql/src/main/scala/no/nrk/bigquery/ZetaSql.scala @@ -1,7 +1,7 @@ package no.nrk.bigquery import cats.syntax.all._ -import cats.effect.IO +import cats.effect.Sync import no.nrk.bigquery.syntax._ import com.google.zetasql.{ AnalyzerOptions, @@ -26,24 +26,26 @@ import scala.collection.mutable.ListBuffer import scala.jdk.CollectionConverters._ import scala.jdk.OptionConverters._ -object ZetaSql { - def parse(frag: BQSqlFrag): IO[Either[SqlException, BQSqlFrag]] = parseScript(frag).map(_.as(frag)) +class ZetaSql[F[_]](implicit F: Sync[F]) { + import ZetaSql._ + def parse(frag: BQSqlFrag): F[Either[SqlException, BQSqlFrag]] = parseScript(frag).map(_.as(frag)) - def parseScript(frag: BQSqlFrag): IO[Either[SqlException, ASTNodes.ASTScript]] = IO.interruptible { - val options = BigQueryLanguageOptions.get() + def parseScript(frag: BQSqlFrag): F[Either[SqlException, ASTNodes.ASTScript]] = + F.interruptible { + val options = BigQueryLanguageOptions.get() - try - Right(Parser.parseScript(frag.asString, options)) - catch { - case e: SqlException => Left(e) // only catch sql exception and let all others bubble up to IO + try + Right(Parser.parseScript(frag.asString, options)) + catch { + case e: SqlException => Left(e) // only catch sql exception and let all others bubble up to IO + } } - } def parseAndBuildAnalysableFragment( query: String, allTables: List[BQTableLike[Any]], toFragment: BQTableLike[Any] => BQSqlFrag = _.unpartitioned.bqShow, - eqv: (BQTableId, BQTableId) => Boolean = _ == _): IO[BQSqlFrag] = { + eqv: (BQTableId, BQTableId) => Boolean = _ == _): F[BQSqlFrag] = { def evalFragments( parsedTables: List[(BQTableId, ParseLocationRange)] @@ -63,13 +65,14 @@ object ZetaSql { } parseScript(BQSqlFrag.Frag(query)) - .flatMap(IO.fromEither) + .flatMap(F.fromEither) .flatMap { script => val list = script.getStatementListNode.getStatementList if (list.size() != 1) { - IO.raiseError(new IllegalArgumentException("Expects only one statement")) + Sync[F].raiseError[List[(no.nrk.bigquery.BQTableId, com.google.zetasql.ParseLocationRange)]]( + new IllegalArgumentException("Expects only one statement")) } else - IO { + Sync[F].delay { val buffer = new ListBuffer[(BQTableId, ParseLocationRange)] list.asScala.headOption.foreach(_.accept(new ParseTreeVisitor { override def visit(node: ASTNodes.ASTTablePathExpression): Unit = @@ -85,7 +88,7 @@ object ZetaSql { .map(evalFragments) } - def queryFields(frag: BQSqlFrag): IO[List[BQField]] = + def queryFields(frag: BQSqlFrag): F[List[BQField]] = analyzeFirst(frag).map { res => val builder = List.newBuilder[BQField] @@ -95,28 +98,33 @@ object ZetaSql { tree.accept(new ResolvedNodes.Visitor { override def visit(node: ResolvedNodes.ResolvedQueryStmt): Unit = node.getOutputColumnList.asScala.foreach(col => - builder += ZetaSql.fromColumnNameAndType(col.getColumn.getName, col.getColumn.getType)) + builder += fromColumnNameAndType(col.getColumn.getName, col.getColumn.getType)) })) builder.result() } - def analyzeFirst(frag: BQSqlFrag): IO[Either[AnalysisException, AnalyzedStatement]] = IO.interruptible { - val tables = frag.allReferencedTables - val catalog = toCatalog(tables: _*) - val rendered = frag.asString + def analyzeFirst(frag: BQSqlFrag): F[Either[AnalysisException, AnalyzedStatement]] = + F.interruptible { + val tables = frag.allReferencedTables + val catalog = toCatalog(tables: _*) + val rendered = frag.asString - val options = BigQueryLanguageOptions.get() - val analyzerOptions = new AnalyzerOptions - analyzerOptions.setLanguageOptions(options) - analyzerOptions.setPreserveColumnAliases(true) + val options = BigQueryLanguageOptions.get() + val analyzerOptions = new AnalyzerOptions + analyzerOptions.setLanguageOptions(options) + analyzerOptions.setPreserveColumnAliases(true) - val analyser = new ZetaSQLToolkitAnalyzer(analyzerOptions) - val analyzed = analyser.analyzeStatements(rendered, catalog) + val analyser = new ZetaSQLToolkitAnalyzer(analyzerOptions) + val analyzed = analyser.analyzeStatements(rendered, catalog) - if (analyzed.hasNext) - Right(analyzed.next()) - else Left(new AnalysisException("Unable to find any analyzed statements")) - } + if (analyzed.hasNext) + Right(analyzed.next()) + else Left(new AnalysisException("Unable to find any analyzed statements")) + } +} + +object ZetaSql { + def apply[F[_]: Sync]: ZetaSql[F] = new ZetaSql[F] def toCatalog(tables: BQTableLike[Any]*): BasicCatalogWrapper = { val catalog = new BasicCatalogWrapper() @@ -192,6 +200,7 @@ object ZetaSql { def toSimpleField(field: BQField) = new SimpleColumn(table.tableId.tableName, field.name, toType(field), false, true) + val simple = table match { case BQTableRef(tableId, _, _) => new SimpleTable(tableId.tableName) @@ -210,5 +219,4 @@ object ZetaSql { simple.setFullName(table.tableId.asString) simple } - } diff --git a/zetasql/src/test/scala/no/nrk/bigquery/ZetaTest.scala b/zetasql/src/test/scala/no/nrk/bigquery/ZetaTest.scala index f9e17b94..2443f630 100644 --- a/zetasql/src/test/scala/no/nrk/bigquery/ZetaTest.scala +++ b/zetasql/src/test/scala/no/nrk/bigquery/ZetaTest.scala @@ -7,6 +7,8 @@ import com.google.zetasql.toolkit.AnalysisException import java.time.LocalDate class ZetaTest extends munit.CatsEffectSuite { + private val zetaSql = new ZetaSql[IO] + private val table = BQTableDef.Table( BQTableId.unsafeOf(BQDataset.unsafeOf(ProjectId("com-example"), "example"), "test"), BQSchema.of( @@ -20,11 +22,11 @@ class ZetaTest extends munit.CatsEffectSuite { ) test("parses select 1") { - ZetaSql.analyzeFirst(bqsql"select 1").map(_.isRight).assertEquals(true) + zetaSql.analyzeFirst(bqsql"select 1").map(_.isRight).assertEquals(true) } test("fails to parse select from foo") { - ZetaSql.analyzeFirst(bqsql"select from foo").flatMap(IO.fromEither).intercept[AnalysisException] + zetaSql.analyzeFirst(bqsql"select from foo").flatMap(IO.fromEither).intercept[AnalysisException] } test("subset select from example") { @@ -33,7 +35,7 @@ class ZetaTest extends munit.CatsEffectSuite { val query = bqsql"select partitionDate, a, b, c from ${table.assertPartition(date)}" val expected = table.schema.fields.dropRight(1).map(_.recursivelyNullable.withoutDescription) - ZetaSql.queryFields(query).assertEquals(expected) + zetaSql.queryFields(query).assertEquals(expected) } test("all fields should be selected from example") { @@ -42,7 +44,7 @@ class ZetaTest extends munit.CatsEffectSuite { val query = bqsql"select partitionDate, a, b, c, d from ${table.assertPartition(date)}" val expected = table.schema.fields.map(_.recursivelyNullable.withoutDescription) - ZetaSql.queryFields(query).assertEquals(expected) + zetaSql.queryFields(query).assertEquals(expected) } test("CTE selections") { @@ -61,7 +63,7 @@ class ZetaTest extends munit.CatsEffectSuite { (table.schema.fields.dropRight(2) ++ List(BQField("nullableCs", BQField.Type.INT64, BQField.Mode.NULLABLE))) .map(_.recursivelyNullable.withoutDescription) - ZetaSql.queryFields(query).assertEquals(expected) + zetaSql.queryFields(query).assertEquals(expected) } test("parse then build analysis") { @@ -80,9 +82,9 @@ class ZetaTest extends munit.CatsEffectSuite { (table.schema.fields.dropRight(2) ++ List(BQField("nullableCs", BQField.Type.INT64, BQField.Mode.NULLABLE))) .map(_.recursivelyNullable.withoutDescription) - ZetaSql + zetaSql .parseAndBuildAnalysableFragment(query, List(table)) - .flatMap(ZetaSql.queryFields) + .flatMap(zetaSql.queryFields) .assertEquals(expected) } }