[POC][DON'T MERGE] Transform ANTLR4 to SqlNode #999

Draft · wants to merge 1 commit into base: main
4 changes: 4 additions & 0 deletions build.sbt
@@ -15,6 +15,7 @@ lazy val jacksonVersion = "2.15.2"
lazy val opensearchVersion = "2.6.0"
lazy val opensearchMavenVersion = "2.6.0.0"
lazy val icebergVersion = "1.5.0"
+lazy val calciteVersion = "1.37.0"

val scalaMinorVersion = scala212.split("\\.").take(2).mkString(".")
val sparkMinorVersion = sparkVersion.split("\\.").take(2).mkString(".")
@@ -120,6 +121,7 @@ lazy val flintCore = (project in file("flint-core"))
exclude ("com.fasterxml.jackson.core", "jackson-core")
exclude ("org.apache.httpcomponents.client5", "httpclient5"),
"org.opensearch" % "opensearch-job-scheduler-spi" % opensearchMavenVersion,
"org.apache.calcite" % "calcite-core" % calciteVersion,
"dev.failsafe" % "failsafe" % "3.3.2",
"com.amazonaws" % "aws-java-sdk" % "1.12.397" % "provided"
exclude ("com.fasterxml.jackson.core", "jackson-databind"),
@@ -193,6 +195,7 @@ lazy val pplSparkIntegration = (project in file("ppl-spark-integration"))
"com.stephenn" %% "scalatest-json-jsonassert" % "0.2.5" % "test",
"com.github.sbt" % "junit-interface" % "0.13.3" % "test",
"org.projectlombok" % "lombok" % "1.18.30",
"org.apache.calcite" % "calcite-core" % calciteVersion,
"com.github.seancfoley" % "ipaddress" % "5.5.1",
),
libraryDependencies ++= deps(sparkVersion),
@@ -228,6 +231,7 @@ lazy val flintSparkIntegration = (project in file("flint-spark-integration"))
libraryDependencies ++= Seq(
"com.amazonaws" % "aws-java-sdk" % "1.12.397" % "provided"
exclude ("com.fasterxml.jackson.core", "jackson-databind"),
"org.apache.calcite" % "calcite-core" % calciteVersion,
"org.scalactic" %% "scalactic" % "3.2.15" % "test",
"org.scalatest" %% "scalatest" % "3.2.15" % "test",
"org.scalatest" %% "scalatest-flatspec" % "3.2.15" % "test",
@@ -0,0 +1,16 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.flint.spark.ppl

import java.util

import org.apache.calcite.sql.parser.SqlParserPos.ZERO
import org.apache.calcite.sql.{SqlCall, SqlLiteral, SqlNode, SqlOperator}

case class Function(functionName: String, sqlOperator: SqlOperator) {

  // Builds a SqlCall for the wrapped Calcite operator. Note that the
  // `function` parameter is currently unused: the call is constructed from
  // `sqlOperator` alone.
  def createCall(function: SqlNode, operands: util.List[SqlNode], qualifier: SqlLiteral): SqlCall =
    sqlOperator.createCall(qualifier, ZERO, operands)
}
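A minimal usage sketch for this wrapper (hypothetical values, not part of the PR; assumes Calcite's SqlStdOperatorTable and SqlIdentifier are also imported):

// Hypothetical sketch: build an AVG(b) call through the wrapper. The null
// qualifier means no DISTINCT/ALL flag is attached to the aggregate call.
val avg = Function("avg", SqlStdOperatorTable.AVG)
val operands: util.List[SqlNode] = util.Arrays.asList(new SqlIdentifier("b", ZERO))
val avgCall: SqlCall = avg.createCall(null, operands, null)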
@@ -0,0 +1,17 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.flint.spark.ppl

import org.apache.calcite.sql.SqlOperator
import org.apache.calcite.sql.fun.SqlStdOperatorTable

case class PPLFunctionResolver() {

  // Maps a PPL function or operator name to its Calcite operator. Only the
  // operators needed by this POC are wired up; anything else fails fast with
  // a descriptive error instead of an opaque MatchError.
  def resolve(name: String): SqlOperator = {
    name match {
      case "=" => SqlStdOperatorTable.EQUALS
      case "avg" => SqlStdOperatorTable.AVG
      case other => throw new IllegalArgumentException(s"Unsupported PPL function: $other")
    }
  }
}
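Mirroring visitCompareExpr in the AST builder below, a hedged sketch of the resolver in use (illustrative values, not part of the PR; assumes the usual Calcite imports and SqlParserPos.ZERO):

// Hypothetical sketch: resolve "=" and build a Calcite call node for a = 1.
val eq: SqlOperator = PPLFunctionResolver().resolve("=")
val predicate = eq.createCall(null, ZERO,
  new SqlIdentifier("a", ZERO),
  SqlLiteral.createExactNumeric("1", ZERO))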
@@ -4,11 +4,11 @@
*/
package org.opensearch.flint.spark.ppl

-import org.antlr.v4.runtime.{CommonTokenStream, Lexer}
import org.antlr.v4.runtime.tree.ParseTree
+import org.antlr.v4.runtime.{CommonTokenStream, Lexer}
import org.opensearch.sql.ast.statement.Statement
import org.opensearch.sql.common.antlr.{CaseInsensitiveCharStream, Parser, SyntaxAnalysisErrorListener}
-import org.opensearch.sql.ppl.parser.{AstBuilder, AstExpressionBuilder, AstStatementBuilder}
+import org.opensearch.sql.ppl.parser.{AstBuilder, AstStatementBuilder}

class PPLSyntaxParser extends Parser {
// Analyze the query syntax
@@ -0,0 +1,191 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/
package org.opensearch.flint.spark.ppl

import scala.collection.JavaConverters._

import org.antlr.v4.runtime.CommonTokenStream
import org.antlr.v4.runtime.atn.PredictionMode
import org.antlr.v4.runtime.misc.ParseCancellationException
import org.opensearch.sql.common.antlr.{CaseInsensitiveCharStream, SyntaxAnalysisErrorListener}

import org.apache.calcite.sql.fun.SqlStdOperatorTable
import org.apache.calcite.sql.parser.SqlParserPos.ZERO
import org.apache.calcite.sql.{SqlBasicCall, SqlIdentifier, SqlLiteral, SqlNode, SqlNodeList, SqlSelect}

class PPLParser {
  val astBuilder = new PPLAstBuilder()

  def parseQuery(query: String): SqlNode = parse(query) { parser =>
    astBuilder.visit(parser.root().pplStatement())
  }

  protected def parse[T](command: String)(toResult: OpenSearchPPLParser => T): T = {
    val lexer = new OpenSearchPPLLexer(new CaseInsensitiveCharStream(command))
    lexer.addErrorListener(new SyntaxAnalysisErrorListener())

    val tokenStream = new CommonTokenStream(lexer)
    val parser = new OpenSearchPPLParser(tokenStream)
    parser.addErrorListener(new SyntaxAnalysisErrorListener())
    // The custom error strategies, parse listeners, and legacy SQL config
    // flags used by Spark's parser are omitted in this POC.

    // https://github.com/antlr/antlr4/issues/192#issuecomment-15238595
    // Save a great deal of time on correct inputs by using a two-stage parsing
    // strategy: try the faster SLL prediction mode first, and fall back to
    // full LL mode only if SLL fails.
    try {
      parser.getInterpreter.setPredictionMode(PredictionMode.SLL)
      toResult(parser)
    } catch {
      case _: ParseCancellationException =>
        // SLL failed; rewind the input stream, reset the parser, and try
        // again with the slower but more general LL prediction mode.
        tokenStream.seek(0)
        parser.reset()
        parser.getInterpreter.setPredictionMode(PredictionMode.LL)
        toResult(parser)
    }
  }
}

class PPLAstBuilder extends OpenSearchPPLParserBaseVisitor[SqlNode] {
  val functionResolver = PPLFunctionResolver()

  override def visitDmlStatement(ctx: OpenSearchPPLParser.DmlStatementContext): SqlNode = {
    visit(ctx.queryStatement())
  }

  // Every piped command is translated to a SqlSelect. The fold threads the
  // pipeline through the FROM clause: each command's FROM becomes the
  // SqlSelect built from the commands before it, nesting inside-out.
  override def visitQueryStatement(ctx: OpenSearchPPLParser.QueryStatementContext): SqlNode = {
    val source = visit(ctx.pplCommands()).asInstanceOf[SqlSelect]
    val commands = ctx.commands().asScala.map(visit).map(_.asInstanceOf[SqlSelect])
    commands.foldLeft(source) { (pre, cur) =>
      cur.setFrom(pre)
      cur
    }
  }

  override def visitPplCommands(ctx: OpenSearchPPLParser.PplCommandsContext): SqlNode = {
    val from = visit(ctx.searchCommand())
    new SqlSelect(ZERO, null, SqlNodeList.SINGLETON_STAR, from, null, null, null, null, null, null, null, null)
  }

  override def visitFromClause(ctx: OpenSearchPPLParser.FromClauseContext): SqlNode = {
    super.visitFromClause(ctx)
  }

  override def visitTableOrSubqueryClause(ctx: OpenSearchPPLParser.TableOrSubqueryClauseContext): SqlNode = {
    if (ctx.subSearch() != null) {
      null // sub-searches are not handled in this POC
    } else {
      visitTableSourceClause(ctx.tableSourceClause())
    }
  }

  override def visitTableSourceClause(ctx: OpenSearchPPLParser.TableSourceClauseContext): SqlNode = {
    val sqlNodes: Seq[SqlNode] =
      (0 until ctx.tableSource().size).map(i => visitTableSource(ctx.tableSource(i)))
    if (ctx.alias == null) sqlNodes.head
    else new SqlBasicCall(SqlStdOperatorTable.AS, sqlNodes.toArray, ZERO)
  }

  override def visitIdentsAsTableQualifiedName(ctx: OpenSearchPPLParser.IdentsAsTableQualifiedNameContext): SqlNode = {
    new SqlIdentifier(ctx.tableIdent().ident().getText, ZERO)
  }

  override def visitWhereCommand(ctx: OpenSearchPPLParser.WhereCommandContext): SqlNode = {
    val where = visitChildren(ctx)
    new SqlSelect(ZERO, null, SqlNodeList.SINGLETON_STAR, null, where, null, null, null, null, null, null, null)
  }

  // "Comparsion" [sic] follows the rule name in the OpenSearch PPL grammar.
  override def visitComparsion(ctx: OpenSearchPPLParser.ComparsionContext): SqlNode = {
    super.visitComparsion(ctx)
  }

  override def visitCompareExpr(ctx: OpenSearchPPLParser.CompareExprContext): SqlNode = {
    functionResolver
      .resolve(ctx.comparisonOperator.getText)
      .createCall(null, ZERO, visit(ctx.left), visit(ctx.right))
  }

  override def visitIdentsAsQualifiedName(ctx: OpenSearchPPLParser.IdentsAsQualifiedNameContext): SqlNode = {
    new SqlIdentifier(ctx.ident().asScala.map(_.getText).mkString("."), ZERO)
  }

  override def visitIdent(ctx: OpenSearchPPLParser.IdentContext): SqlNode = {
    new SqlIdentifier(ctx.getText, ZERO)
  }

  override def visitIntegerLiteral(ctx: OpenSearchPPLParser.IntegerLiteralContext): SqlNode = {
    SqlLiteral.createExactNumeric(ctx.getText, ZERO)
  }

  override def visitFieldsCommand(ctx: OpenSearchPPLParser.FieldsCommandContext): SqlNode = {
    val selectExpr = visitFieldList(ctx.fieldList())
    new SqlSelect(ZERO, null, selectExpr, null, null, null, null, null, null, null, null, null)
  }

  override def visitFieldList(ctx: OpenSearchPPLParser.FieldListContext): SqlNodeList = {
    val fields = ctx.fieldExpression.asScala.map(visit)
    SqlNodeList.of(ZERO, fields.asJava)
  }

  override def visitSortCommand(ctx: OpenSearchPPLParser.SortCommandContext): SqlNode = {
    val orderByList = visitSortbyClause(ctx.sortbyClause())
    new SqlSelect(ZERO, null, SqlNodeList.SINGLETON_STAR, null, null, null, null, null, orderByList, null, null, null)
  }

  override def visitSortbyClause(ctx: OpenSearchPPLParser.SortbyClauseContext): SqlNodeList = {
    val fields = ctx.sortField().asScala.map(visit)
    SqlNodeList.of(ZERO, fields.asJava)
  }

  // stats translates to SELECT <group-by fields, aggregates> ... GROUP BY <fields>.
  override def visitStatsCommand(ctx: OpenSearchPPLParser.StatsCommandContext): SqlNode = {
    val aggList = ctx.statsAggTerm.asScala.map(visit)
    val groupByList = visitStatsByClause(ctx.statsByClause())
    val selectList = SqlNodeList.of(ZERO, (groupByList.getList.asScala ++ aggList).asJava)
    new SqlSelect(ZERO, null, selectList, null, null, groupByList, null, null, null, null, null, null)
  }

  override def visitStatsAggTerm(ctx: OpenSearchPPLParser.StatsAggTermContext): SqlNode = {
    val agg = visit(ctx.statsFunction())
    if (ctx.alias == null) agg
    else new SqlBasicCall(SqlStdOperatorTable.AS, Array[SqlNode](agg, visit(ctx.alias)), ZERO)
  }

  override def visitStatsFunctionCall(ctx: OpenSearchPPLParser.StatsFunctionCallContext): SqlNode = {
    functionResolver.resolve(ctx.statsFunctionName.getText).createCall(null, ZERO, visit(ctx.valueExpression))
  }

  override def visitStatsByClause(ctx: OpenSearchPPLParser.StatsByClauseContext): SqlNodeList = {
    SqlNodeList.of(ZERO, ctx.fieldList().fieldExpression().asScala.map(visit).asJava)
  }
}
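To make the fold in visitQueryStatement concrete: each command yields one SqlSelect, and every stage's FROM is the select built from the stages before it, so the result nests inside-out. A rough sketch of the shape (illustrative only; exact quoting depends on how the node is printed):

// Illustrative, not part of the PR. For the pipeline
//   source=table | where a = 1 | fields c
// the builder yields nested selects, roughly:
//   SELECT c FROM (SELECT * FROM (SELECT * FROM table) WHERE a = 1)
val node: SqlNode = new PPLParser().parseQuery("source=table | where a = 1 | fields c")
println(node) // prints the nested SELECT form sketched above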
@@ -0,0 +1,96 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.flint.spark.ppl

import java.util
import java.util.Objects.requireNonNull

import org.scalatest.matchers.should.Matchers

import org.apache.calcite.adapter.java.AbstractQueryableTable
import org.apache.calcite.config.{CalciteConnectionConfig, Lex}
import org.apache.calcite.jdbc.{CalciteSchema, JavaTypeFactoryImpl}
import org.apache.calcite.linq4j.{Enumerable, Linq4j, QueryProvider, Queryable}
import org.apache.calcite.plan.RelOptCluster
import org.apache.calcite.plan.volcano.VolcanoPlanner
import org.apache.calcite.prepare.{CalciteCatalogReader, PlannerImpl}
import org.apache.calcite.rel.`type`.{RelDataType, RelDataTypeFactory}
import org.apache.calcite.rel.rel2sql.RelToSqlConverter
import org.apache.calcite.rex.RexBuilder
import org.apache.calcite.schema.SchemaPlus
import org.apache.calcite.schema.impl.AbstractTable
import org.apache.calcite.sql.SqlDialect.DatabaseProduct
import org.apache.calcite.sql.`type`.SqlTypeName
import org.apache.calcite.sql.parser.SqlParser
import org.apache.calcite.sql2rel.SqlToRelConverter
import org.apache.calcite.tools.{FrameworkConfig, Frameworks, Programs}
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.plans.PlanTest

class PPLSqlNodeTestSuite
    extends SparkFunSuite
    with PlanTest
    with LogicalPlanTestUtils
    with Matchers {

  // A dummy queryable table exposing integer columns a, b, c, c0, c1, c2.
  val t: AbstractTable = new AbstractQueryableTable(classOf[Integer]) {
    val enumerable: Enumerable[Integer] = Linq4j.asEnumerable(new util.ArrayList[Integer]())

    override def asQueryable[E](queryProvider: QueryProvider, schema: SchemaPlus, tableName: String): Queryable[E] =
      enumerable.asQueryable.asInstanceOf[Queryable[E]]

    override def getRowType(typeFactory: RelDataTypeFactory): RelDataType = {
      val builder: RelDataTypeFactory.Builder = typeFactory.builder
      builder.add("a", SqlTypeName.INTEGER)
      builder.add("b", SqlTypeName.INTEGER)
      builder.add("c", SqlTypeName.INTEGER)
      for (i <- 0 until 3) {
        builder.add(s"c$i", SqlTypeName.INTEGER)
      }
      builder.build
    }
  }

  private def createCatalogReader = {
    // The default schema doubles as the root schema in this test setup.
    val defaultSchema = requireNonNull(config.getDefaultSchema, "defaultSchema")
    new CalciteCatalogReader(
      CalciteSchema.from(defaultSchema),
      CalciteSchema.from(defaultSchema).path(null),
      typeFactory,
      CalciteConnectionConfig.DEFAULT)
  }

  val schema: SchemaPlus = Frameworks.createRootSchema(true)
  schema.add("table", t)
  val config: FrameworkConfig = Frameworks.newConfigBuilder
    .parserConfig(SqlParser.config.withLex(Lex.MYSQL))
    .defaultSchema(schema)
    .programs(Programs.ofRules(Programs.RULE_SET))
    .build
  val typeFactory = new JavaTypeFactoryImpl(config.getTypeSystem)
  val pplParser = new PPLParser()
  val planner = Frameworks.getPlanner(config)
  val cluster: RelOptCluster = RelOptCluster.create(
    requireNonNull(new VolcanoPlanner(config.getCostFactory, config.getContext), "planner"),
    new RexBuilder(typeFactory))
  val sqlToRelConverter = new SqlToRelConverter(
    planner.asInstanceOf[PlannerImpl],
    null,
    createCatalogReader,
    cluster,
    config.getConvertletTable,
    config.getSqlToRelConverterConfig)
  val relToSqlConverter = new RelToSqlConverter(DatabaseProduct.CALCITE.getDialect)
  val pplParserOld = new PPLSyntaxParser()

  test("test") {
    // PPL text -> SqlNode (new parser) -> RelNode, then RelNode -> SQL text.
    val sqlNode = pplParser.parseQuery(
      "source=table | where a = 1| stats avg(b) as avg_b by c | sort c | fields c, avg_b")
    val relNode = sqlToRelConverter.convertQuery(sqlNode, false, true)

    // Round-trip the generated SQL through Calcite's own parser and validator
    // as a sanity check on the SqlNode produced above.
    val sqlNode2 = planner.parse(sqlNode.toString())
    planner.validate(sqlNode2)
    val relNode2 = planner.rel(sqlNode2)
    val sqlNode3 = relToSqlConverter.visitRoot(relNode.rel).asStatement()

    // val relNode = planner.rel(sqlNode)
    // val osPlan = plan(pplParserOld, "source=t")
    //scalastyle:off
    println(sqlNode)
    println(relNode2)
    println(sqlNode3)
    // println(osPlan)
    //scalastyle:on
  }
}