diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 2d7e9f60..bc223507 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -24,7 +24,7 @@ jobs: matrix: os: [ubuntu-latest] scala: [2.13.11, 2.12.18, 3.3.0] - java: [temurin@8] + java: [temurin@11] project: [rootJVM] runs-on: ${{ matrix.os }} steps: @@ -33,21 +33,21 @@ jobs: with: fetch-depth: 0 - - name: Download Java (temurin@8) - id: download-java-temurin-8 - if: matrix.java == 'temurin@8' + - name: Download Java (temurin@11) + id: download-java-temurin-11 + if: matrix.java == 'temurin@11' uses: typelevel/download-java@v2 with: distribution: temurin - java-version: 8 + java-version: 11 - - name: Setup Java (temurin@8) - if: matrix.java == 'temurin@8' + - name: Setup Java (temurin@11) + if: matrix.java == 'temurin@11' uses: actions/setup-java@v3 with: distribution: jdkfile - java-version: 8 - jdkFile: ${{ steps.download-java-temurin-8.outputs.jdkFile }} + java-version: 11 + jdkFile: ${{ steps.download-java-temurin-11.outputs.jdkFile }} - name: Cache sbt uses: actions/cache@v3 @@ -65,7 +65,7 @@ jobs: run: sbt githubWorkflowCheck - name: Check formatting - if: matrix.java == 'temurin@8' && matrix.os == 'ubuntu-latest' + if: matrix.java == 'temurin@11' && matrix.os == 'ubuntu-latest' run: sbt 'project ${{ matrix.project }}' '++ ${{ matrix.scala }}' scalafmtCheckAll 'project /' scalafmtSbtCheck - name: Test @@ -75,23 +75,23 @@ jobs: run: sbt 'project ${{ matrix.project }}' '++ ${{ matrix.scala }}' test - name: Check binary compatibility - if: matrix.java == 'temurin@8' && matrix.os == 'ubuntu-latest' + if: matrix.java == 'temurin@11' && matrix.os == 'ubuntu-latest' env: MYGET_USERNAME: ${{ secrets.PLATTFORM_MYGET_ENTERPRISE_READ_ID }} MYGET_PASSWORD: ${{ secrets.PLATTFORM_MYGET_ENTERPRISE_READ_SECRET }} run: sbt 'project ${{ matrix.project }}' '++ ${{ matrix.scala }}' mimaReportBinaryIssues - name: Generate API documentation - if: matrix.java == 'temurin@8' && matrix.os == 'ubuntu-latest' + if: matrix.java == 'temurin@11' && matrix.os == 'ubuntu-latest' run: sbt 'project ${{ matrix.project }}' '++ ${{ matrix.scala }}' doc - name: Make target directories if: github.event_name != 'pull_request' && (startsWith(github.ref, 'refs/tags/v')) - run: mkdir -p target .js/target site/target prometheus/.jvm/target testing/.jvm/target .jvm/target .native/target core/.jvm/target project/target + run: mkdir -p target .js/target site/target prometheus/.jvm/target testing/.jvm/target .jvm/target .native/target core/.jvm/target zetasql/.jvm/target project/target - name: Compress target directories if: github.event_name != 'pull_request' && (startsWith(github.ref, 'refs/tags/v')) - run: tar cf targets.tar target .js/target site/target prometheus/.jvm/target testing/.jvm/target .jvm/target .native/target core/.jvm/target project/target + run: tar cf targets.tar target .js/target site/target prometheus/.jvm/target testing/.jvm/target .jvm/target .native/target core/.jvm/target zetasql/.jvm/target project/target - name: Upload target directories if: github.event_name != 'pull_request' && (startsWith(github.ref, 'refs/tags/v')) @@ -107,7 +107,7 @@ jobs: strategy: matrix: os: [ubuntu-latest] - java: [temurin@8] + java: [temurin@11] runs-on: ${{ matrix.os }} steps: - name: Checkout current branch (full) @@ -115,21 +115,21 @@ jobs: with: fetch-depth: 0 - - name: Download Java (temurin@8) - id: download-java-temurin-8 - if: matrix.java == 'temurin@8' + - name: Download Java (temurin@11) + id: download-java-temurin-11 + if: matrix.java == 'temurin@11' uses: typelevel/download-java@v2 with: distribution: temurin - java-version: 8 + java-version: 11 - - name: Setup Java (temurin@8) - if: matrix.java == 'temurin@8' + - name: Setup Java (temurin@11) + if: matrix.java == 'temurin@11' uses: actions/setup-java@v3 with: distribution: jdkfile - java-version: 8 - jdkFile: ${{ steps.download-java-temurin-8.outputs.jdkFile }} + java-version: 11 + jdkFile: ${{ steps.download-java-temurin-11.outputs.jdkFile }} - name: Cache sbt uses: actions/cache@v3 diff --git a/build.sbt b/build.sbt index 86bd7fbe..df0709d9 100644 --- a/build.sbt +++ b/build.sbt @@ -59,6 +59,7 @@ ThisBuild / tlVersionIntroduced := Map( "3" -> "0.1.1", "2.13" -> "0.1.0" ) +ThisBuild / githubWorkflowJavaVersions := Seq(JavaSpec.temurin("11")) val commonSettings = Seq( resolvers += "MyGet - datahub".at(s"https://nrk.myget.org/F/datahub/maven/"), @@ -94,7 +95,7 @@ val commonSettings = Seq( lazy val root = tlCrossRootProject .settings(name := "bigquery-scala") - .aggregate(core, testing, prometheus, docs) + .aggregate(core, testing, prometheus, zetasql, docs) .disablePlugins(TypelevelCiSigningPlugin, Sonatype, SbtGpg) lazy val core = crossProject(JVMPlatform) @@ -156,6 +157,23 @@ lazy val prometheus = crossProject(JVMPlatform) ) .disablePlugins(TypelevelCiSigningPlugin, Sonatype, SbtGpg) +lazy val zetasql = crossProject(JVMPlatform) + .withoutSuffixFor(JVMPlatform) + .crossType(CrossType.Pure) + .in(file("zetasql")) + .settings(commonSettings) + .dependsOn(core) + .settings( + name := "bigquery-zetasql", + tlMimaPreviousVersions := Set.empty, + libraryDependencies ++= Seq( + "com.google.zetasql.toolkit" % "zetasql-toolkit-bigquery" % "0.4.0", + "org.scalameta" %% "munit" % "0.7.29", + "org.typelevel" %% "munit-cats-effect-3" % "1.0.7" + ) + ) + .disablePlugins(TypelevelCiSigningPlugin, Sonatype, SbtGpg) + lazy val testing = crossProject(JVMPlatform) .withoutSuffixFor(JVMPlatform) .crossType(CrossType.Pure) diff --git a/zetasql/src/main/scala/no/nrk/bigquery/ZetaSql.scala b/zetasql/src/main/scala/no/nrk/bigquery/ZetaSql.scala new file mode 100644 index 00000000..0bc68a10 --- /dev/null +++ b/zetasql/src/main/scala/no/nrk/bigquery/ZetaSql.scala @@ -0,0 +1,222 @@ +package no.nrk.bigquery + +import cats.syntax.all._ +import cats.effect.Sync +import no.nrk.bigquery.syntax._ +import com.google.zetasql.{ + AnalyzerOptions, + ParseLocationRange, + Parser, + SimpleColumn, + SimpleTable, + SqlException, + StructType, + Type, + TypeFactory +} +import com.google.zetasql.ZetaSQLType.TypeKind +import com.google.zetasql.resolvedast.ResolvedCreateStatementEnums.{CreateMode, CreateScope} +import com.google.zetasql.resolvedast.ResolvedNodes +import com.google.zetasql.toolkit.catalog.basic.BasicCatalogWrapper +import com.google.zetasql.toolkit.options.BigQueryLanguageOptions +import com.google.zetasql.parser.{ASTNodes, ParseTreeVisitor} +import com.google.zetasql.toolkit.{AnalysisException, AnalyzedStatement, ZetaSQLToolkitAnalyzer} + +import scala.collection.mutable.ListBuffer +import scala.jdk.CollectionConverters._ +import scala.jdk.OptionConverters._ + +class ZetaSql[F[_]](implicit F: Sync[F]) { + import ZetaSql._ + def parse(frag: BQSqlFrag): F[Either[SqlException, BQSqlFrag]] = parseScript(frag).map(_.as(frag)) + + def parseScript(frag: BQSqlFrag): F[Either[SqlException, ASTNodes.ASTScript]] = + F.interruptible { + val options = BigQueryLanguageOptions.get() + + try + Right(Parser.parseScript(frag.asString, options)) + catch { + case e: SqlException => Left(e) // only catch sql exception and let all others bubble up to IO + } + } + + def parseAndBuildAnalysableFragment( + query: String, + allTables: List[BQTableLike[Any]], + toFragment: BQTableLike[Any] => BQSqlFrag = _.unpartitioned.bqShow, + eqv: (BQTableId, BQTableId) => Boolean = _ == _): F[BQSqlFrag] = { + + def evalFragments( + parsedTables: List[(BQTableId, ParseLocationRange)] + ): BQSqlFrag = { + val asString = query + val found = allTables + .flatMap(table => + parsedTables.flatMap { case (id, range) => if (eqv(table.tableId, id)) List(table -> range) else Nil }) + .distinct + val (rest, aggregate) = found.foldLeft((asString, BQSqlFrag.Empty)) { case ((input, agg), (t, loc)) => + val frag = agg ++ BQSqlFrag.Frag(input.substring(0, loc.start() - 1)) ++ toFragment(t) + val rest = input.substring(loc.end()) + rest -> frag + } + + aggregate ++ BQSqlFrag.Frag(rest) + } + + parseScript(BQSqlFrag.Frag(query)) + .flatMap(F.fromEither) + .flatMap { script => + val list = script.getStatementListNode.getStatementList + if (list.size() != 1) { + Sync[F].raiseError[List[(no.nrk.bigquery.BQTableId, com.google.zetasql.ParseLocationRange)]]( + new IllegalArgumentException("Expects only one statement")) + } else + Sync[F].delay { + val buffer = new ListBuffer[(BQTableId, ParseLocationRange)] + list.asScala.headOption.foreach(_.accept(new ParseTreeVisitor { + override def visit(node: ASTNodes.ASTTablePathExpression): Unit = + node.getPathExpr.getNames.forEach(ident => + BQTableId + .fromString(ident.getIdString) + .toOption + .foreach(id => buffer += (id -> ident.getParseLocationRange))) + })) + buffer.toList + } + } + .map(evalFragments) + } + + def queryFields(frag: BQSqlFrag): F[List[BQField]] = + analyzeFirst(frag).map { res => + val builder = List.newBuilder[BQField] + + res + .flatMap(_.getResolvedStatement.toScala.toRight(new AnalysisException("No analysis found"))) + .foreach(tree => + tree.accept(new ResolvedNodes.Visitor { + override def visit(node: ResolvedNodes.ResolvedQueryStmt): Unit = + node.getOutputColumnList.asScala.foreach(col => + builder += fromColumnNameAndType(col.getColumn.getName, col.getColumn.getType)) + })) + builder.result() + } + + def analyzeFirst(frag: BQSqlFrag): F[Either[AnalysisException, AnalyzedStatement]] = + F.interruptible { + val tables = frag.allReferencedTables + val catalog = toCatalog(tables: _*) + val rendered = frag.asString + + val options = BigQueryLanguageOptions.get() + val analyzerOptions = new AnalyzerOptions + analyzerOptions.setLanguageOptions(options) + analyzerOptions.setPreserveColumnAliases(true) + + val analyser = new ZetaSQLToolkitAnalyzer(analyzerOptions) + val analyzed = analyser.analyzeStatements(rendered, catalog) + + if (analyzed.hasNext) + Right(analyzed.next()) + else Left(new AnalysisException("Unable to find any analyzed statements")) + } +} + +object ZetaSql { + def apply[F[_]: Sync]: ZetaSql[F] = new ZetaSql[F] + + def toCatalog(tables: BQTableLike[Any]*): BasicCatalogWrapper = { + val catalog = new BasicCatalogWrapper() + tables.foreach(table => + catalog.register(toSimpleTable(table), CreateMode.CREATE_IF_NOT_EXISTS, CreateScope.CREATE_DEFAULT_SCOPE)) + catalog + } + + def fromColumnNameAndType(name: String, typ: Type): BQField = { + val kind = typ.getKind match { + case TypeKind.TYPE_BOOL => BQField.Type.BOOL + case TypeKind.TYPE_DATE => BQField.Type.DATE + case TypeKind.TYPE_DATETIME => BQField.Type.DATETIME + case TypeKind.TYPE_JSON => BQField.Type.JSON + case TypeKind.TYPE_BYTES => BQField.Type.BYTES + case TypeKind.TYPE_STRING => BQField.Type.STRING + case TypeKind.TYPE_BIGNUMERIC => BQField.Type.BIGNUMERIC + case TypeKind.TYPE_INT64 => BQField.Type.INT64 + case TypeKind.TYPE_INT32 => BQField.Type.INT64 + case TypeKind.TYPE_FLOAT => BQField.Type.FLOAT64 + case TypeKind.TYPE_DOUBLE => BQField.Type.FLOAT64 + case TypeKind.TYPE_TIMESTAMP => BQField.Type.TIMESTAMP + case TypeKind.TYPE_TIME => BQField.Type.TIME + case TypeKind.TYPE_GEOGRAPHY => BQField.Type.GEOGRAPHY + case TypeKind.TYPE_INTERVAL => BQField.Type.INTERVAL + case _ => throw new IllegalArgumentException(s"$name with type ${typ.debugString()} is not supported ") + } + + if (typ.isArray) { + val elem = fromColumnNameAndType(name, typ.asArray().getElementType) + elem.copy(mode = BQField.Mode.REPEATED) + } else if (typ.isStruct) { + BQField.struct(name, BQField.Mode.NULLABLE)( + typ + .asStruct() + .getFieldList + .asScala + .map(subField => fromColumnNameAndType(subField.getName, subField.getType)) + .toList: _*) + } else BQField(name, kind, BQField.Mode.NULLABLE) + } + + def toSimpleTable(table: BQTableLike[Any]): SimpleTable = { + def toType(field: BQField): Type = { + val isArray = field.mode == BQField.Mode.REPEATED + + val elemType = field.tpe match { + case BQField.Type.BOOL => TypeFactory.createSimpleType(TypeKind.TYPE_BOOL) + case BQField.Type.INT64 => TypeFactory.createSimpleType(TypeKind.TYPE_INT64) + case BQField.Type.FLOAT64 => TypeFactory.createSimpleType(TypeKind.TYPE_FLOAT) + case BQField.Type.NUMERIC => TypeFactory.createSimpleType(TypeKind.TYPE_NUMERIC) + case BQField.Type.BIGNUMERIC => TypeFactory.createSimpleType(TypeKind.TYPE_BIGNUMERIC) + case BQField.Type.STRING => TypeFactory.createSimpleType(TypeKind.TYPE_STRING) + case BQField.Type.BYTES => TypeFactory.createSimpleType(TypeKind.TYPE_BYTES) + case BQField.Type.STRUCT => + TypeFactory.createStructType( + field.subFields + .map(sub => new StructType.StructField(sub.name, toType(sub))) + .asJavaCollection + ) + case BQField.Type.ARRAY => + TypeFactory.createArrayType(toType(field.subFields.head)) + case BQField.Type.TIMESTAMP => TypeFactory.createSimpleType(TypeKind.TYPE_TIMESTAMP) + case BQField.Type.DATE => TypeFactory.createSimpleType(TypeKind.TYPE_DATE) + case BQField.Type.TIME => TypeFactory.createSimpleType(TypeKind.TYPE_TIME) + case BQField.Type.DATETIME => TypeFactory.createSimpleType(TypeKind.TYPE_DATETIME) + case BQField.Type.GEOGRAPHY => TypeFactory.createSimpleType(TypeKind.TYPE_GEOGRAPHY) + case BQField.Type.JSON => TypeFactory.createSimpleType(TypeKind.TYPE_JSON) + case BQField.Type.INTERVAL => TypeFactory.createSimpleType(TypeKind.TYPE_INTERVAL) + } + if (isArray) TypeFactory.createArrayType(elemType) else elemType + } + + def toSimpleField(field: BQField) = + new SimpleColumn(table.tableId.tableName, field.name, toType(field), false, true) + + val simple = table match { + case BQTableRef(tableId, _, _) => + new SimpleTable(tableId.tableName) + + case tbl: BQTableDef.Table[_] => + new SimpleTable( + tbl.tableId.tableName, + new java.util.ArrayList(tbl.schema.fields.map(toSimpleField).asJavaCollection) + ) + case view: BQTableDef.ViewLike[_] => + new SimpleTable( + view.tableId.tableName, + new java.util.ArrayList(view.schema.fields.map(toSimpleField).asJavaCollection) + ) + } + simple.setFullName(table.tableId.asString) + simple + } +} diff --git a/zetasql/src/test/scala/no/nrk/bigquery/ZetaTest.scala b/zetasql/src/test/scala/no/nrk/bigquery/ZetaTest.scala new file mode 100644 index 00000000..2443f630 --- /dev/null +++ b/zetasql/src/test/scala/no/nrk/bigquery/ZetaTest.scala @@ -0,0 +1,90 @@ +package no.nrk.bigquery + +import cats.effect.IO +import no.nrk.bigquery.syntax._ +import com.google.zetasql.toolkit.AnalysisException + +import java.time.LocalDate + +class ZetaTest extends munit.CatsEffectSuite { + private val zetaSql = new ZetaSql[IO] + + private val table = BQTableDef.Table( + BQTableId.unsafeOf(BQDataset.unsafeOf(ProjectId("com-example"), "example"), "test"), + BQSchema.of( + BQField("partitionDate", BQField.Type.DATE, BQField.Mode.REQUIRED), + BQField("a", BQField.Type.STRING, BQField.Mode.REQUIRED), + BQField("b", BQField.Type.INT64, BQField.Mode.REQUIRED), + BQField("c", BQField.Type.INT64, BQField.Mode.REQUIRED), + BQField("d", BQField.Type.INT64, BQField.Mode.REQUIRED) + ), + BQPartitionType.DatePartitioned(Ident("partitionDate")) + ) + + test("parses select 1") { + zetaSql.analyzeFirst(bqsql"select 1").map(_.isRight).assertEquals(true) + } + + test("fails to parse select from foo") { + zetaSql.analyzeFirst(bqsql"select from foo").flatMap(IO.fromEither).intercept[AnalysisException] + } + + test("subset select from example") { + val date = LocalDate.of(2023, 1, 1) + + val query = bqsql"select partitionDate, a, b, c from ${table.assertPartition(date)}" + + val expected = table.schema.fields.dropRight(1).map(_.recursivelyNullable.withoutDescription) + zetaSql.queryFields(query).assertEquals(expected) + } + + test("all fields should be selected from example") { + val date = LocalDate.of(2023, 1, 1) + + val query = bqsql"select partitionDate, a, b, c, d from ${table.assertPartition(date)}" + + val expected = table.schema.fields.map(_.recursivelyNullable.withoutDescription) + zetaSql.queryFields(query).assertEquals(expected) + } + + test("CTE selections") { + val query = + bqsql"""|with data as ( + | select partitionDate, a, b, c from ${table.unpartitioned} + |), + | grouped as ( + | select partitionDate, a, b, COUNTIF(c is null) as nullableCs from data + | group by 1, 2, 3 + | ) + |select * from grouped + |""".stripMargin + + val expected = + (table.schema.fields.dropRight(2) ++ List(BQField("nullableCs", BQField.Type.INT64, BQField.Mode.NULLABLE))) + .map(_.recursivelyNullable.withoutDescription) + + zetaSql.queryFields(query).assertEquals(expected) + } + + test("parse then build analysis") { + val query = + """|with data as ( + | select partitionDate, a, b, c from `com-example.example.test` + |), + | grouped as ( + | select partitionDate, a, b, COUNTIF(c is null) as nullableCs from data + | group by 1, 2, 3 + | ) + |select * from grouped + |""".stripMargin + + val expected = + (table.schema.fields.dropRight(2) ++ List(BQField("nullableCs", BQField.Type.INT64, BQField.Mode.NULLABLE))) + .map(_.recursivelyNullable.withoutDescription) + + zetaSql + .parseAndBuildAnalysableFragment(query, List(table)) + .flatMap(zetaSql.queryFields) + .assertEquals(expected) + } +}