Skip to content

Commit

Permalink
Merge pull request #139 from nrkno/zetasql
Browse files Browse the repository at this point in the history
We can use offline ZetaSQL to verify queries
  • Loading branch information
hamnis authored Aug 22, 2023
2 parents 04af07f + 7d880a9 commit ee388c3
Show file tree
Hide file tree
Showing 4 changed files with 354 additions and 24 deletions.
46 changes: 23 additions & 23 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ jobs:
matrix:
os: [ubuntu-latest]
scala: [2.13.11, 2.12.18, 3.3.0]
java: [temurin@8]
java: [temurin@11]
project: [rootJVM]
runs-on: ${{ matrix.os }}
steps:
Expand All @@ -33,21 +33,21 @@ jobs:
with:
fetch-depth: 0

- name: Download Java (temurin@8)
id: download-java-temurin-8
if: matrix.java == 'temurin@8'
- name: Download Java (temurin@11)
id: download-java-temurin-11
if: matrix.java == 'temurin@11'
uses: typelevel/download-java@v2
with:
distribution: temurin
java-version: 8
java-version: 11

- name: Setup Java (temurin@8)
if: matrix.java == 'temurin@8'
- name: Setup Java (temurin@11)
if: matrix.java == 'temurin@11'
uses: actions/setup-java@v3
with:
distribution: jdkfile
java-version: 8
jdkFile: ${{ steps.download-java-temurin-8.outputs.jdkFile }}
java-version: 11
jdkFile: ${{ steps.download-java-temurin-11.outputs.jdkFile }}

- name: Cache sbt
uses: actions/cache@v3
Expand All @@ -65,7 +65,7 @@ jobs:
run: sbt githubWorkflowCheck

- name: Check formatting
if: matrix.java == 'temurin@8' && matrix.os == 'ubuntu-latest'
if: matrix.java == 'temurin@11' && matrix.os == 'ubuntu-latest'
run: sbt 'project ${{ matrix.project }}' '++ ${{ matrix.scala }}' scalafmtCheckAll 'project /' scalafmtSbtCheck

- name: Test
Expand All @@ -75,23 +75,23 @@ jobs:
run: sbt 'project ${{ matrix.project }}' '++ ${{ matrix.scala }}' test

- name: Check binary compatibility
if: matrix.java == 'temurin@8' && matrix.os == 'ubuntu-latest'
if: matrix.java == 'temurin@11' && matrix.os == 'ubuntu-latest'
env:
MYGET_USERNAME: ${{ secrets.PLATTFORM_MYGET_ENTERPRISE_READ_ID }}
MYGET_PASSWORD: ${{ secrets.PLATTFORM_MYGET_ENTERPRISE_READ_SECRET }}
run: sbt 'project ${{ matrix.project }}' '++ ${{ matrix.scala }}' mimaReportBinaryIssues

- name: Generate API documentation
if: matrix.java == 'temurin@8' && matrix.os == 'ubuntu-latest'
if: matrix.java == 'temurin@11' && matrix.os == 'ubuntu-latest'
run: sbt 'project ${{ matrix.project }}' '++ ${{ matrix.scala }}' doc

- name: Make target directories
if: github.event_name != 'pull_request' && (startsWith(github.ref, 'refs/tags/v'))
run: mkdir -p target .js/target site/target prometheus/.jvm/target testing/.jvm/target .jvm/target .native/target core/.jvm/target project/target
run: mkdir -p target .js/target site/target prometheus/.jvm/target testing/.jvm/target .jvm/target .native/target core/.jvm/target zetasql/.jvm/target project/target

- name: Compress target directories
if: github.event_name != 'pull_request' && (startsWith(github.ref, 'refs/tags/v'))
run: tar cf targets.tar target .js/target site/target prometheus/.jvm/target testing/.jvm/target .jvm/target .native/target core/.jvm/target project/target
run: tar cf targets.tar target .js/target site/target prometheus/.jvm/target testing/.jvm/target .jvm/target .native/target core/.jvm/target zetasql/.jvm/target project/target

- name: Upload target directories
if: github.event_name != 'pull_request' && (startsWith(github.ref, 'refs/tags/v'))
Expand All @@ -107,29 +107,29 @@ jobs:
strategy:
matrix:
os: [ubuntu-latest]
java: [temurin@8]
java: [temurin@11]
runs-on: ${{ matrix.os }}
steps:
- name: Checkout current branch (full)
uses: actions/checkout@v3
with:
fetch-depth: 0

- name: Download Java (temurin@8)
id: download-java-temurin-8
if: matrix.java == 'temurin@8'
- name: Download Java (temurin@11)
id: download-java-temurin-11
if: matrix.java == 'temurin@11'
uses: typelevel/download-java@v2
with:
distribution: temurin
java-version: 8
java-version: 11

- name: Setup Java (temurin@8)
if: matrix.java == 'temurin@8'
- name: Setup Java (temurin@11)
if: matrix.java == 'temurin@11'
uses: actions/setup-java@v3
with:
distribution: jdkfile
java-version: 8
jdkFile: ${{ steps.download-java-temurin-8.outputs.jdkFile }}
java-version: 11
jdkFile: ${{ steps.download-java-temurin-11.outputs.jdkFile }}

- name: Cache sbt
uses: actions/cache@v3
Expand Down
20 changes: 19 additions & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ ThisBuild / tlVersionIntroduced := Map(
"3" -> "0.1.1",
"2.13" -> "0.1.0"
)
ThisBuild / githubWorkflowJavaVersions := Seq(JavaSpec.temurin("11"))

val commonSettings = Seq(
resolvers += "MyGet - datahub".at(s"https://nrk.myget.org/F/datahub/maven/"),
Expand Down Expand Up @@ -94,7 +95,7 @@ val commonSettings = Seq(

lazy val root = tlCrossRootProject
.settings(name := "bigquery-scala")
.aggregate(core, testing, prometheus, docs)
.aggregate(core, testing, prometheus, zetasql, docs)
.disablePlugins(TypelevelCiSigningPlugin, Sonatype, SbtGpg)

lazy val core = crossProject(JVMPlatform)
Expand Down Expand Up @@ -156,6 +157,23 @@ lazy val prometheus = crossProject(JVMPlatform)
)
.disablePlugins(TypelevelCiSigningPlugin, Sonatype, SbtGpg)

// Offline query verification module backed by Google's ZetaSQL analyzer (JVM only).
lazy val zetasql = crossProject(JVMPlatform)
  .withoutSuffixFor(JVMPlatform)
  .crossType(CrossType.Pure)
  .in(file("zetasql"))
  .settings(commonSettings)
  .dependsOn(core)
  .settings(
    name := "bigquery-zetasql",
    // New module in this release: no previous artifact exists for MiMa to diff against.
    tlMimaPreviousVersions := Set.empty,
    libraryDependencies ++= Seq(
      "com.google.zetasql.toolkit" % "zetasql-toolkit-bigquery" % "0.4.0",
      // NOTE(review): the munit dependencies below are not scoped `% Test`, so they
      // end up on the published compile classpath — confirm this is intentional.
      "org.scalameta" %% "munit" % "0.7.29",
      "org.typelevel" %% "munit-cats-effect-3" % "1.0.7"
    )
  )
  .disablePlugins(TypelevelCiSigningPlugin, Sonatype, SbtGpg)

lazy val testing = crossProject(JVMPlatform)
.withoutSuffixFor(JVMPlatform)
.crossType(CrossType.Pure)
Expand Down
222 changes: 222 additions & 0 deletions zetasql/src/main/scala/no/nrk/bigquery/ZetaSql.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,222 @@
package no.nrk.bigquery

import cats.syntax.all._
import cats.effect.Sync
import no.nrk.bigquery.syntax._
import com.google.zetasql.{
AnalyzerOptions,
ParseLocationRange,
Parser,
SimpleColumn,
SimpleTable,
SqlException,
StructType,
Type,
TypeFactory
}
import com.google.zetasql.ZetaSQLType.TypeKind
import com.google.zetasql.resolvedast.ResolvedCreateStatementEnums.{CreateMode, CreateScope}
import com.google.zetasql.resolvedast.ResolvedNodes
import com.google.zetasql.toolkit.catalog.basic.BasicCatalogWrapper
import com.google.zetasql.toolkit.options.BigQueryLanguageOptions
import com.google.zetasql.parser.{ASTNodes, ParseTreeVisitor}
import com.google.zetasql.toolkit.{AnalysisException, AnalyzedStatement, ZetaSQLToolkitAnalyzer}

import scala.collection.mutable.ListBuffer
import scala.jdk.CollectionConverters._
import scala.jdk.OptionConverters._

/** Wrapper around Google's ZetaSQL parser/analyzer, used to validate and
  * introspect BigQuery SQL offline, without calling the BigQuery service.
  */
class ZetaSql[F[_]](implicit F: Sync[F]) {
  import ZetaSql._

  /** Parses the fragment and returns it unchanged when it is valid SQL. */
  def parse(frag: BQSqlFrag): F[Either[SqlException, BQSqlFrag]] = parseScript(frag).map(_.as(frag))

  /** Parses the fragment into a ZetaSQL AST script.
    *
    * Runs under `F.interruptible` because the parse is blocking native work.
    */
  def parseScript(frag: BQSqlFrag): F[Either[SqlException, ASTNodes.ASTScript]] =
    F.interruptible {
      val options = BigQueryLanguageOptions.get()

      try
        Right(Parser.parseScript(frag.asString, options))
      catch {
        case e: SqlException => Left(e) // only catch sql exception and let all others bubble up to IO
      }
    }

  /** Parses `query` (which must contain exactly one statement) and rebuilds it
    * as a [[BQSqlFrag]], replacing every table reference that matches a table
    * in `allTables` with `toFragment(table)`.
    *
    * @param query      the SQL text to parse
    * @param allTables  tables whose references should be substituted
    * @param toFragment how a matched table is rendered (defaults to the unpartitioned table)
    * @param eqv        equality used to match parsed identifiers against `allTables` ids
    */
  def parseAndBuildAnalysableFragment(
      query: String,
      allTables: List[BQTableLike[Any]],
      toFragment: BQTableLike[Any] => BQSqlFrag = _.unpartitioned.bqShow,
      eqv: (BQTableId, BQTableId) => Boolean = _ == _): F[BQSqlFrag] = {

    def evalFragments(
        parsedTables: List[(BQTableId, ParseLocationRange)]
    ): BQSqlFrag = {
      val found = allTables
        .flatMap(table =>
          parsedTables.flatMap { case (id, range) => if (eqv(table.tableId, id)) List(table -> range) else Nil })
        .distinct
        // The substitution walk below requires matches in source order; the order of
        // `allTables` says nothing about positions inside `query`, so sort explicitly.
        .sortBy { case (_, range) => range.start() }

      // Walk the matches left to right, keeping `pos` as an ABSOLUTE offset into
      // `query`. The previous implementation sliced a progressively shortened
      // string with absolute parse offsets, which produced wrong output (or an
      // IndexOutOfBoundsException) whenever more than one table matched.
      // NOTE(review): keeps the original `start() - 1` / `end()` convention, which
      // assumes 1-based parse offsets — verify against ZetaSQL's ParseLocationRange.
      val (lastPos, aggregate) = found.foldLeft((0, BQSqlFrag.Empty)) { case ((pos, agg), (table, loc)) =>
        val frag = agg ++ BQSqlFrag.Frag(query.substring(pos, loc.start() - 1)) ++ toFragment(table)
        loc.end() -> frag
      }

      aggregate ++ BQSqlFrag.Frag(query.substring(lastPos))
    }

    parseScript(BQSqlFrag.Frag(query))
      .flatMap(F.fromEither)
      .flatMap { script =>
        val list = script.getStatementListNode.getStatementList
        if (list.size() != 1) {
          F.raiseError[List[(BQTableId, ParseLocationRange)]](
            new IllegalArgumentException("Expects only one statement"))
        } else
          F.delay {
            // Collect every table path expression together with where it occurs.
            val buffer = new ListBuffer[(BQTableId, ParseLocationRange)]
            list.asScala.headOption.foreach(_.accept(new ParseTreeVisitor {
              override def visit(node: ASTNodes.ASTTablePathExpression): Unit =
                node.getPathExpr.getNames.forEach(ident =>
                  BQTableId
                    .fromString(ident.getIdString)
                    .toOption // identifiers that do not parse as table ids are skipped
                    .foreach(id => buffer += (id -> ident.getParseLocationRange)))
            }))
            buffer.toList
          }
      }
      .map(evalFragments)
  }

  /** Analyzes the fragment and returns the output columns of the resolved query
    * statement. Note: analysis failures yield an empty list, not an error.
    */
  def queryFields(frag: BQSqlFrag): F[List[BQField]] =
    analyzeFirst(frag).map { res =>
      val builder = List.newBuilder[BQField]

      res
        .flatMap(_.getResolvedStatement.toScala.toRight(new AnalysisException("No analysis found")))
        .foreach(tree =>
          tree.accept(new ResolvedNodes.Visitor {
            override def visit(node: ResolvedNodes.ResolvedQueryStmt): Unit =
              node.getOutputColumnList.asScala.foreach(col =>
                builder += fromColumnNameAndType(col.getColumn.getName, col.getColumn.getType))
          }))
      builder.result()
    }

  /** Analyzes the fragment against a catalog built from its referenced tables
    * and returns the first analyzed statement, or an [[AnalysisException]] when
    * there is none.
    */
  def analyzeFirst(frag: BQSqlFrag): F[Either[AnalysisException, AnalyzedStatement]] =
    F.interruptible {
      val tables = frag.allReferencedTables
      val catalog = toCatalog(tables: _*)
      val rendered = frag.asString

      val options = BigQueryLanguageOptions.get()
      val analyzerOptions = new AnalyzerOptions
      analyzerOptions.setLanguageOptions(options)
      analyzerOptions.setPreserveColumnAliases(true)

      val analyser = new ZetaSQLToolkitAnalyzer(analyzerOptions)
      val analyzed = analyser.analyzeStatements(rendered, catalog)

      if (analyzed.hasNext)
        Right(analyzed.next())
      else Left(new AnalysisException("Unable to find any analyzed statements"))
    }
}

object ZetaSql {
  def apply[F[_]: Sync]: ZetaSql[F] = new ZetaSql[F]

  /** Builds an in-memory ZetaSQL catalog containing all the given tables, so
    * queries against them can be analyzed offline.
    */
  def toCatalog(tables: BQTableLike[Any]*): BasicCatalogWrapper = {
    val catalog = new BasicCatalogWrapper()
    tables.foreach(table =>
      catalog.register(toSimpleTable(table), CreateMode.CREATE_IF_NOT_EXISTS, CreateScope.CREATE_DEFAULT_SCOPE))
    catalog
  }

  /** Converts a ZetaSQL column type into a [[BQField]].
    *
    * Arrays become the element field with mode REPEATED; structs recurse into
    * their members; scalars map kind-for-kind.
    *
    * @throws IllegalArgumentException for scalar kinds that are not supported
    */
  def fromColumnNameAndType(name: String, typ: Type): BQField =
    if (typ.isArray) {
      val elem = fromColumnNameAndType(name, typ.asArray().getElementType)
      elem.copy(mode = BQField.Mode.REPEATED)
    } else if (typ.isStruct) {
      BQField.struct(name, BQField.Mode.NULLABLE)(
        typ
          .asStruct()
          .getFieldList
          .asScala
          .map(subField => fromColumnNameAndType(subField.getName, subField.getType))
          .toList: _*)
    } else {
      // Scalar kinds only — arrays/structs are handled above. The previous
      // version computed `kind` eagerly before the isArray/isStruct checks, so
      // every ARRAY or STRUCT column fell into the default branch and threw,
      // making those branches unreachable.
      val kind = typ.getKind match {
        case TypeKind.TYPE_BOOL => BQField.Type.BOOL
        case TypeKind.TYPE_DATE => BQField.Type.DATE
        case TypeKind.TYPE_DATETIME => BQField.Type.DATETIME
        case TypeKind.TYPE_JSON => BQField.Type.JSON
        case TypeKind.TYPE_BYTES => BQField.Type.BYTES
        case TypeKind.TYPE_STRING => BQField.Type.STRING
        case TypeKind.TYPE_NUMERIC => BQField.Type.NUMERIC // inverse of toType's NUMERIC mapping
        case TypeKind.TYPE_BIGNUMERIC => BQField.Type.BIGNUMERIC
        case TypeKind.TYPE_INT64 => BQField.Type.INT64
        case TypeKind.TYPE_INT32 => BQField.Type.INT64 // BigQuery only exposes INT64
        case TypeKind.TYPE_FLOAT => BQField.Type.FLOAT64 // BigQuery only exposes FLOAT64
        case TypeKind.TYPE_DOUBLE => BQField.Type.FLOAT64
        case TypeKind.TYPE_TIMESTAMP => BQField.Type.TIMESTAMP
        case TypeKind.TYPE_TIME => BQField.Type.TIME
        case TypeKind.TYPE_GEOGRAPHY => BQField.Type.GEOGRAPHY
        case TypeKind.TYPE_INTERVAL => BQField.Type.INTERVAL
        case _ => throw new IllegalArgumentException(s"$name with type ${typ.debugString()} is not supported ")
      }
      BQField(name, kind, BQField.Mode.NULLABLE)
    }

  /** Converts one of our table definitions into a ZetaSQL [[SimpleTable]]
    * suitable for registration in a catalog.
    */
  def toSimpleTable(table: BQTableLike[Any]): SimpleTable = {
    // Inverse of `fromColumnNameAndType`: BQField -> ZetaSQL Type.
    def toType(field: BQField): Type = {
      val isArray = field.mode == BQField.Mode.REPEATED

      val elemType = field.tpe match {
        case BQField.Type.BOOL => TypeFactory.createSimpleType(TypeKind.TYPE_BOOL)
        case BQField.Type.INT64 => TypeFactory.createSimpleType(TypeKind.TYPE_INT64)
        case BQField.Type.FLOAT64 => TypeFactory.createSimpleType(TypeKind.TYPE_FLOAT)
        case BQField.Type.NUMERIC => TypeFactory.createSimpleType(TypeKind.TYPE_NUMERIC)
        case BQField.Type.BIGNUMERIC => TypeFactory.createSimpleType(TypeKind.TYPE_BIGNUMERIC)
        case BQField.Type.STRING => TypeFactory.createSimpleType(TypeKind.TYPE_STRING)
        case BQField.Type.BYTES => TypeFactory.createSimpleType(TypeKind.TYPE_BYTES)
        case BQField.Type.STRUCT =>
          TypeFactory.createStructType(
            field.subFields
              .map(sub => new StructType.StructField(sub.name, toType(sub)))
              .asJavaCollection
          )
        case BQField.Type.ARRAY =>
          // NOTE(review): assumes an ARRAY-typed field always carries its element
          // type as the first subField — `head` throws on an empty list; confirm.
          TypeFactory.createArrayType(toType(field.subFields.head))
        case BQField.Type.TIMESTAMP => TypeFactory.createSimpleType(TypeKind.TYPE_TIMESTAMP)
        case BQField.Type.DATE => TypeFactory.createSimpleType(TypeKind.TYPE_DATE)
        case BQField.Type.TIME => TypeFactory.createSimpleType(TypeKind.TYPE_TIME)
        case BQField.Type.DATETIME => TypeFactory.createSimpleType(TypeKind.TYPE_DATETIME)
        case BQField.Type.GEOGRAPHY => TypeFactory.createSimpleType(TypeKind.TYPE_GEOGRAPHY)
        case BQField.Type.JSON => TypeFactory.createSimpleType(TypeKind.TYPE_JSON)
        case BQField.Type.INTERVAL => TypeFactory.createSimpleType(TypeKind.TYPE_INTERVAL)
      }
      if (isArray) TypeFactory.createArrayType(elemType) else elemType
    }

    def toSimpleField(field: BQField) =
      // NOTE(review): the booleans are assumed to be
      // (isPseudoColumn = false, isWriteableColumn = true) — verify against SimpleColumn.
      new SimpleColumn(table.tableId.tableName, field.name, toType(field), false, true)

    val simple = table match {
      // Tables referenced by id only: name is known but no schema is available.
      case BQTableRef(tableId, _, _) =>
        new SimpleTable(tableId.tableName)

      case tbl: BQTableDef.Table[_] =>
        new SimpleTable(
          tbl.tableId.tableName,
          new java.util.ArrayList(tbl.schema.fields.map(toSimpleField).asJavaCollection)
        )
      case view: BQTableDef.ViewLike[_] =>
        new SimpleTable(
          view.tableId.tableName,
          new java.util.ArrayList(view.schema.fields.map(toSimpleField).asJavaCollection)
        )
    }
    simple.setFullName(table.tableId.asString)
    simple
  }
}
Loading

0 comments on commit ee388c3

Please sign in to comment.