Skip to content

Commit

Permalink
use tagless final
Browse files Browse the repository at this point in the history
  • Loading branch information
hamnis committed Jul 28, 2023
1 parent c3fefe4 commit 7d880a9
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 38 deletions.
70 changes: 39 additions & 31 deletions zetasql/src/main/scala/no/nrk/bigquery/ZetaSql.scala
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package no.nrk.bigquery

import cats.syntax.all._
import cats.effect.IO
import cats.effect.Sync
import no.nrk.bigquery.syntax._
import com.google.zetasql.{
AnalyzerOptions,
Expand All @@ -26,24 +26,26 @@ import scala.collection.mutable.ListBuffer
import scala.jdk.CollectionConverters._
import scala.jdk.OptionConverters._

object ZetaSql {
def parse(frag: BQSqlFrag): IO[Either[SqlException, BQSqlFrag]] = parseScript(frag).map(_.as(frag))
class ZetaSql[F[_]](implicit F: Sync[F]) {
import ZetaSql._
def parse(frag: BQSqlFrag): F[Either[SqlException, BQSqlFrag]] = parseScript(frag).map(_.as(frag))

def parseScript(frag: BQSqlFrag): IO[Either[SqlException, ASTNodes.ASTScript]] = IO.interruptible {
val options = BigQueryLanguageOptions.get()
def parseScript(frag: BQSqlFrag): F[Either[SqlException, ASTNodes.ASTScript]] =
F.interruptible {
val options = BigQueryLanguageOptions.get()

try
Right(Parser.parseScript(frag.asString, options))
catch {
case e: SqlException => Left(e) // only catch sql exception and let all others bubble up to IO
try
Right(Parser.parseScript(frag.asString, options))
catch {
case e: SqlException => Left(e) // only catch sql exception and let all others bubble up to IO
}
}
}

def parseAndBuildAnalysableFragment(
query: String,
allTables: List[BQTableLike[Any]],
toFragment: BQTableLike[Any] => BQSqlFrag = _.unpartitioned.bqShow,
eqv: (BQTableId, BQTableId) => Boolean = _ == _): IO[BQSqlFrag] = {
eqv: (BQTableId, BQTableId) => Boolean = _ == _): F[BQSqlFrag] = {

def evalFragments(
parsedTables: List[(BQTableId, ParseLocationRange)]
Expand All @@ -63,13 +65,14 @@ object ZetaSql {
}

parseScript(BQSqlFrag.Frag(query))
.flatMap(IO.fromEither)
.flatMap(F.fromEither)
.flatMap { script =>
val list = script.getStatementListNode.getStatementList
if (list.size() != 1) {
IO.raiseError(new IllegalArgumentException("Expects only one statement"))
Sync[F].raiseError[List[(no.nrk.bigquery.BQTableId, com.google.zetasql.ParseLocationRange)]](
new IllegalArgumentException("Expects only one statement"))
} else
IO {
Sync[F].delay {
val buffer = new ListBuffer[(BQTableId, ParseLocationRange)]
list.asScala.headOption.foreach(_.accept(new ParseTreeVisitor {
override def visit(node: ASTNodes.ASTTablePathExpression): Unit =
Expand All @@ -85,7 +88,7 @@ object ZetaSql {
.map(evalFragments)
}

def queryFields(frag: BQSqlFrag): IO[List[BQField]] =
def queryFields(frag: BQSqlFrag): F[List[BQField]] =
analyzeFirst(frag).map { res =>
val builder = List.newBuilder[BQField]

Expand All @@ -95,28 +98,33 @@ object ZetaSql {
tree.accept(new ResolvedNodes.Visitor {
override def visit(node: ResolvedNodes.ResolvedQueryStmt): Unit =
node.getOutputColumnList.asScala.foreach(col =>
builder += ZetaSql.fromColumnNameAndType(col.getColumn.getName, col.getColumn.getType))
builder += fromColumnNameAndType(col.getColumn.getName, col.getColumn.getType))
}))
builder.result()
}

def analyzeFirst(frag: BQSqlFrag): IO[Either[AnalysisException, AnalyzedStatement]] = IO.interruptible {
val tables = frag.allReferencedTables
val catalog = toCatalog(tables: _*)
val rendered = frag.asString
def analyzeFirst(frag: BQSqlFrag): F[Either[AnalysisException, AnalyzedStatement]] =
F.interruptible {
val tables = frag.allReferencedTables
val catalog = toCatalog(tables: _*)
val rendered = frag.asString

val options = BigQueryLanguageOptions.get()
val analyzerOptions = new AnalyzerOptions
analyzerOptions.setLanguageOptions(options)
analyzerOptions.setPreserveColumnAliases(true)
val options = BigQueryLanguageOptions.get()
val analyzerOptions = new AnalyzerOptions
analyzerOptions.setLanguageOptions(options)
analyzerOptions.setPreserveColumnAliases(true)

val analyser = new ZetaSQLToolkitAnalyzer(analyzerOptions)
val analyzed = analyser.analyzeStatements(rendered, catalog)
val analyser = new ZetaSQLToolkitAnalyzer(analyzerOptions)
val analyzed = analyser.analyzeStatements(rendered, catalog)

if (analyzed.hasNext)
Right(analyzed.next())
else Left(new AnalysisException("Unable to find any analyzed statements"))
}
if (analyzed.hasNext)
Right(analyzed.next())
else Left(new AnalysisException("Unable to find any analyzed statements"))
}
}

object ZetaSql {
def apply[F[_]: Sync]: ZetaSql[F] = new ZetaSql[F]

def toCatalog(tables: BQTableLike[Any]*): BasicCatalogWrapper = {
val catalog = new BasicCatalogWrapper()
Expand Down Expand Up @@ -192,6 +200,7 @@ object ZetaSql {

def toSimpleField(field: BQField) =
new SimpleColumn(table.tableId.tableName, field.name, toType(field), false, true)

val simple = table match {
case BQTableRef(tableId, _, _) =>
new SimpleTable(tableId.tableName)
Expand All @@ -210,5 +219,4 @@ object ZetaSql {
simple.setFullName(table.tableId.asString)
simple
}

}
16 changes: 9 additions & 7 deletions zetasql/src/test/scala/no/nrk/bigquery/ZetaTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ import com.google.zetasql.toolkit.AnalysisException
import java.time.LocalDate

class ZetaTest extends munit.CatsEffectSuite {
private val zetaSql = new ZetaSql[IO]

private val table = BQTableDef.Table(
BQTableId.unsafeOf(BQDataset.unsafeOf(ProjectId("com-example"), "example"), "test"),
BQSchema.of(
Expand All @@ -20,11 +22,11 @@ class ZetaTest extends munit.CatsEffectSuite {
)

test("parses select 1") {
ZetaSql.analyzeFirst(bqsql"select 1").map(_.isRight).assertEquals(true)
zetaSql.analyzeFirst(bqsql"select 1").map(_.isRight).assertEquals(true)
}

test("fails to parse select from foo") {
ZetaSql.analyzeFirst(bqsql"select from foo").flatMap(IO.fromEither).intercept[AnalysisException]
zetaSql.analyzeFirst(bqsql"select from foo").flatMap(IO.fromEither).intercept[AnalysisException]
}

test("subset select from example") {
Expand All @@ -33,7 +35,7 @@ class ZetaTest extends munit.CatsEffectSuite {
val query = bqsql"select partitionDate, a, b, c from ${table.assertPartition(date)}"

val expected = table.schema.fields.dropRight(1).map(_.recursivelyNullable.withoutDescription)
ZetaSql.queryFields(query).assertEquals(expected)
zetaSql.queryFields(query).assertEquals(expected)
}

test("all fields should be selected from example") {
Expand All @@ -42,7 +44,7 @@ class ZetaTest extends munit.CatsEffectSuite {
val query = bqsql"select partitionDate, a, b, c, d from ${table.assertPartition(date)}"

val expected = table.schema.fields.map(_.recursivelyNullable.withoutDescription)
ZetaSql.queryFields(query).assertEquals(expected)
zetaSql.queryFields(query).assertEquals(expected)
}

test("CTE selections") {
Expand All @@ -61,7 +63,7 @@ class ZetaTest extends munit.CatsEffectSuite {
(table.schema.fields.dropRight(2) ++ List(BQField("nullableCs", BQField.Type.INT64, BQField.Mode.NULLABLE)))
.map(_.recursivelyNullable.withoutDescription)

ZetaSql.queryFields(query).assertEquals(expected)
zetaSql.queryFields(query).assertEquals(expected)
}

test("parse then build analysis") {
Expand All @@ -80,9 +82,9 @@ class ZetaTest extends munit.CatsEffectSuite {
(table.schema.fields.dropRight(2) ++ List(BQField("nullableCs", BQField.Type.INT64, BQField.Mode.NULLABLE)))
.map(_.recursivelyNullable.withoutDescription)

ZetaSql
zetaSql
.parseAndBuildAnalysableFragment(query, List(table))
.flatMap(ZetaSql.queryFields)
.flatMap(zetaSql.queryFields)
.assertEquals(expected)
}
}

0 comments on commit 7d880a9

Please sign in to comment.