diff --git a/build.sbt b/build.sbt
index a3d5141b3..f163f76a2 100644
--- a/build.sbt
+++ b/build.sbt
@@ -61,12 +61,11 @@ lazy val flintCore = (project in file("flint-core"))
         exclude ("com.fasterxml.jackson.core", "jackson-databind")),
     publish / skip := true)
 
-lazy val flintSparkIntegration = (project in file("flint-spark-integration"))
-  .dependsOn(flintCore)
+lazy val pplSparkIntegration = (project in file("ppl-spark-integration"))
   .enablePlugins(AssemblyPlugin, Antlr4Plugin)
   .settings(
     commonSettings,
-    name := "flint-spark-integration",
+    name := "ppl-spark-integration",
     scalaVersion := scala212,
     libraryDependencies ++= Seq(
       "com.amazonaws" % "aws-java-sdk" % "1.12.397" % "provided"
@@ -80,7 +79,7 @@ lazy val flintSparkIntegration = (project in file("flint-spark-integration"))
     libraryDependencies ++= deps(sparkVersion),
     // ANTLR settings
     Antlr4 / antlr4Version := "4.8",
-    Antlr4 / antlr4PackageName := Some("org.opensearch.flint.spark.sql"),
+    Antlr4 / antlr4PackageName := Some("org.opensearch.flint.spark.ppl"),
     Antlr4 / antlr4GenListener := true,
     Antlr4 / antlr4GenVisitor := true,
     // Assembly settings
@@ -99,11 +98,13 @@ lazy val flintSparkIntegration = (project in file("flint-spark-integration"))
       oldStrategy(x)
     },
     assembly / test := (Test / test).value)
-lazy val pplSparkIntegration = (project in file("ppl-spark-integration"))
+
+lazy val flintSparkIntegration = (project in file("flint-spark-integration"))
+  .dependsOn(flintCore, pplSparkIntegration)
   .enablePlugins(AssemblyPlugin, Antlr4Plugin)
   .settings(
     commonSettings,
-    name := "ppl-spark-integration",
+    name := "flint-spark-integration",
     scalaVersion := scala212,
     libraryDependencies ++= Seq(
       "com.amazonaws" % "aws-java-sdk" % "1.12.397" % "provided"
@@ -117,7 +118,7 @@ lazy val pplSparkIntegration = (project in file("ppl-spark-integration"))
     libraryDependencies ++= deps(sparkVersion),
     // ANTLR settings
     Antlr4 / antlr4Version := "4.8",
-    Antlr4 / antlr4PackageName := Some("org.opensearch.flint.spark.ppl"),
+    Antlr4 / antlr4PackageName := Some("org.opensearch.flint.spark.sql"),
     Antlr4 / antlr4GenListener := true,
     Antlr4 / antlr4GenVisitor := true,
     // Assembly settings
diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintGenericSparkExtensions.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintGenericSparkExtensions.scala
new file mode 100644
index 000000000..6a890dde3
--- /dev/null
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintGenericSparkExtensions.scala
@@ -0,0 +1,25 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.flint.spark
+
+import org.apache.spark.sql.SparkSessionExtensions
+import org.opensearch.flint.spark.ppl.FlintSparkPPLParser
+import org.opensearch.flint.spark.sql.FlintSparkSqlParser
+
+/**
+ * Flint Spark extension entrypoint.
+ */
+class FlintGenericSparkExtensions extends (SparkSessionExtensions => Unit) {
+
+  override def apply(extensions: SparkSessionExtensions): Unit = {
+    extensions.injectParser { (spark, parser) =>
+      new FlintSparkParserChain(parser, Seq(new FlintSparkPPLParser(parser), new FlintSparkSqlParser(parser)))
+    }
+    extensions.injectOptimizerRule { spark =>
+      new FlintSparkOptimizer(spark)
+    }
+  }
+}
diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkParserChain.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkParserChain.scala
new file mode 100644
index 000000000..356239cd3
--- /dev/null
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkParserChain.scala
@@ -0,0 +1,56 @@
+package org.opensearch.flint.spark
+
+import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
+import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.parser.{ParseException, ParserInterface}
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.types.{DataType, StructType}
+
+import scala.collection.mutable
+
+class FlintSparkParserChain(sparkParser: ParserInterface, parserChain: Seq[ParserInterface]) extends ParserInterface {
+
+  private val parsers: mutable.ListBuffer[ParserInterface] = mutable.ListBuffer() ++= parserChain
+
+  /**
+   * Goes through the parser chain and tries to parse sqlText with each parser in turn - if a
+   * parser succeeds its logical plan is returned, otherwise the next parser in the chain is tried.
+   *
+   * @param sqlText the query text to parse
+   * @return the logical plan produced by the first parser that succeeds
+   */
+  override def parsePlan(sqlText: String): LogicalPlan = {
+    try {
+      // go through the parser chain and try parsing sqlText - if parsing succeeds, return the logical plan
+      // otherwise move on to the next parser in the chain and try to parse sqlText again
+      for (parser <- parsers) {
+        try {
+          return parser.parsePlan(sqlText)
+        } catch {
+          case _: Exception => // Continue to the next parser
+        }
+      }
+      // Fall back to Spark parse plan logic if all parsers in the chain fail
+      sparkParser.parsePlan(sqlText)
+    }
+  }
+
+
+  override def parseExpression(sqlText: String): Expression = sparkParser.parseExpression(sqlText)
+
+  override def parseTableIdentifier(sqlText: String): TableIdentifier =
+    sparkParser.parseTableIdentifier(sqlText)
+
+  override def parseFunctionIdentifier(sqlText: String): FunctionIdentifier =
+    sparkParser.parseFunctionIdentifier(sqlText)
+
+  override def parseMultipartIdentifier(sqlText: String): Seq[String] =
+    sparkParser.parseMultipartIdentifier(sqlText)
+
+  override def parseTableSchema(sqlText: String): StructType =
+    sparkParser.parseTableSchema(sqlText)
+
+  override def parseDataType(sqlText: String): DataType = sparkParser.parseDataType(sqlText)
+
+  override def parseQuery(sqlText: String): LogicalPlan = sparkParser.parseQuery(sqlText)
+}
diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/FlintSparkSqlParser.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/FlintSparkSqlParser.scala
index 2a673c4bf..642a694fd 100644
--- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/FlintSparkSqlParser.scala
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/FlintSparkSqlParser.scala
@@ -53,12 +53,7 @@ class FlintSparkSqlParser(sparkParser: ParserInterface) extends ParserInterface
   private val flintAstBuilder = new FlintSparkSqlAstBuilder()
 
   override def parsePlan(sqlText: String): LogicalPlan = parse(sqlText) { flintParser =>
-    try {
-      flintAstBuilder.visit(flintParser.singleStatement())
-    } catch {
-      // Fall back to Spark parse plan logic if flint cannot parse
-      case _: ParseException => sparkParser.parsePlan(sqlText)
-    }
+    flintAstBuilder.visit(flintParser.singleStatement())
   }
 
   override def parseExpression(sqlText: String): Expression = sparkParser.parseExpression(sqlText)
diff --git a/flint-spark-integration/src/test/scala/org/apache/spark/FlintSuite.scala b/flint-spark-integration/src/test/scala/org/apache/spark/FlintSuite.scala
index 6577600c8..ee2854d01 100644
--- a/flint-spark-integration/src/test/scala/org/apache/spark/FlintSuite.scala
+++ b/flint-spark-integration/src/test/scala/org/apache/spark/FlintSuite.scala
@@ -11,7 +11,7 @@ import org.apache.spark.sql.flint.config.FlintConfigEntry
 import org.apache.spark.sql.flint.config.FlintSparkConf.HYBRID_SCAN_ENABLED
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
-import org.opensearch.flint.spark.FlintSparkExtensions
+import org.opensearch.flint.spark.FlintGenericSparkExtensions
 
 trait FlintSuite extends SharedSparkSession {
   override protected def sparkConf = {
@@ -24,7 +24,7 @@ trait FlintSuite extends SharedSparkSession {
       // this rule may potentially block testing of other optimization rules such as
       // ConstantPropagation etc.
       .set(SQLConf.OPTIMIZER_EXCLUDED_RULES.key, ConvertToLocalRelation.ruleName)
-      .set("spark.sql.extensions", classOf[FlintSparkExtensions].getName)
+      .set("spark.sql.extensions", classOf[FlintGenericSparkExtensions].getName)
     conf
   }
 
diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkPPLITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkPPLITSuite.scala
index 5fc8c6745..968dff591 100644
--- a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkPPLITSuite.scala
+++ b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkPPLITSuite.scala
@@ -5,6 +5,7 @@
 
 package org.opensearch.flint.spark
 
+import org.apache.spark.FlintSuite
 import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
 import org.apache.spark.sql.catalyst.expressions.{Alias, EqualTo, GreaterThan, LessThanOrEqual, Literal, Not}
 import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LogicalPlan, Project}
@@ -14,7 +15,7 @@ import org.apache.spark.sql.{QueryTest, Row}
 class FlintSparkPPLITSuite
     extends QueryTest
     with LogicalPlanTestUtils
-    with FlintPPLSuite
+    with FlintSuite
     with StreamTest {
 
   /** Test table and index name */
@@ -83,7 +84,9 @@ class FlintSparkPPLITSuite
       Row("Jane",25,"Quebec","Canada",2023,4)
     )
     // Compare the results
-    assert(results === expectedResults)
+    // Compare the results
+    implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0))
+    assert(results.sorted.sameElements(expectedResults.sorted))
 
     // Retrieve the logical plan
     val logicalPlan: LogicalPlan = frame.queryExecution.logical
@@ -109,7 +112,9 @@ class FlintSparkPPLITSuite
       Row("Jane", 25)
     )
     // Compare the results
-    assert(results === expectedResults)
+    // Compare the results
+    implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0))
+    assert(results.sorted.sameElements(expectedResults.sorted))
 
     // Retrieve the logical plan
     val logicalPlan: LogicalPlan = frame.queryExecution.logical
@@ -134,7 +139,9 @@ class FlintSparkPPLITSuite
       Row("Jane", 25)
     )
     // Compare the results
-    assert(results === expectedResults)
+    // Compare the results
+    implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0))
+    assert(results.sorted.sameElements(expectedResults.sorted))
 
 
     // Retrieve the logical plan
@@ -163,7 +170,9 @@ class FlintSparkPPLITSuite
       Row("Hello", 30)
     )
     // Compare the results
-    assert(results === expectedResults)
+    // Compare the results
+    implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0))
+    assert(results.sorted.sameElements(expectedResults.sorted))
 
     // Retrieve the logical plan
     val logicalPlan: LogicalPlan = frame.queryExecution.logical
@@ -192,7 +201,9 @@ class FlintSparkPPLITSuite
      Row("Jane", 25)
     )
     // Compare the results
-    assert(results === expectedResults)
+    // Compare the results
+    implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0))
+    assert(results.sorted.sameElements(expectedResults.sorted))
 
     // Retrieve the logical plan
     val logicalPlan: LogicalPlan = frame.queryExecution.logical
@@ -219,7 +230,9 @@ class FlintSparkPPLITSuite
      Row("Jake", 70)
     )
     // Compare the results
-    assert(results === expectedResults)
+    // Compare the results
+    implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Double](_.getAs[Double](0))
+    assert(results.sorted.sameElements(expectedResults.sorted))
 
     // Retrieve the logical plan
     val logicalPlan: LogicalPlan = frame.queryExecution.logical
@@ -249,7 +262,10 @@
     )
 
     // Compare the results
-    assert(results === expectedResults)
+    // Compare the results
+    implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0))
+    assert(results.sorted.sameElements(expectedResults.sorted))
+
     // Retrieve the logical plan
     val logicalPlan: LogicalPlan = frame.queryExecution.logical
     // Define the expected logical plan
@@ -276,7 +292,9 @@
     )
 
     // Compare the results
-    assert(results === expectedResults)
+    // Compare the results
+    implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Double](_.getAs[Double](0))
+    assert(results.sorted.sameElements(expectedResults.sorted))
 
     // Retrieve the logical plan
     val logicalPlan: LogicalPlan = frame.queryExecution.logical
diff --git a/ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLParser.scala b/ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLParser.scala
index ea78fbdd4..0c074b3b9 100644
--- a/ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLParser.scala
+++ b/ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLParser.scala
@@ -50,15 +50,10 @@ class FlintSparkPPLParser(sparkParser: ParserInterface) extends ParserInterface
   private val pplParser = new PPLSyntaxParser()
 
   override def parsePlan(sqlText: String): LogicalPlan = {
-    try {
-      // if successful build ppl logical plan and translate to catalyst logical plan
-      val context = new CatalystPlanContext
-      planTrnasormer.visit(plan(pplParser, sqlText, false), context)
-      context.getPlan
-    } catch {
-      // Fall back to Spark parse plan logic if flint cannot parse
-      case _: ParseException | _: SyntaxCheckException => sparkParser.parsePlan(sqlText)
-    }
+    // build the PPL logical plan and translate it to a catalyst logical plan
+    val context = new CatalystPlanContext
+    planTrnasormer.visit(plan(pplParser, sqlText, false), context)
+    context.getPlan
   }
 
   override def parseExpression(sqlText: String): Expression = sparkParser.parseExpression(sqlText)
diff --git a/spark-sql-application/src/main/scala/org/opensearch/sql/SQLJob.scala b/spark-sql-application/src/main/scala/org/opensearch/sql/SQLJob.scala
index 9e1d36857..fa8857cde 100644
--- a/spark-sql-application/src/main/scala/org/opensearch/sql/SQLJob.scala
+++ b/spark-sql-application/src/main/scala/org/opensearch/sql/SQLJob.scala
@@ -35,7 +35,7 @@ object SQLJob {
 
     val conf: SparkConf = new SparkConf()
       .setAppName("SQLJob")
-      .set("spark.sql.extensions", "org.opensearch.flint.spark.FlintSparkExtensions")
+      .set("spark.sql.extensions", "org.opensearch.flint.spark.FlintGenericSparkExtensions")
      .set("spark.datasource.flint.host", host)
      .set("spark.datasource.flint.port", port)
      .set("spark.datasource.flint.scheme", scheme)
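
Note (not part of the patch): a minimal sketch of how the chained parser registered by FlintGenericSparkExtensions might be exercised from a Spark application. The session setting mirrors the SQLJob change above; the table name and the PPL/SQL statements are illustrative assumptions, not code taken from this change.

    import org.apache.spark.sql.SparkSession

    object ParserChainExample {
      def main(args: Array[String]): Unit = {
        // Registering FlintGenericSparkExtensions installs FlintSparkParserChain, which tries
        // FlintSparkPPLParser first, then FlintSparkSqlParser, and finally falls back to
        // Spark's default parser.
        val spark = SparkSession
          .builder()
          .appName("parser-chain-example")
          .master("local[*]")
          .config("spark.sql.extensions", "org.opensearch.flint.spark.FlintGenericSparkExtensions")
          .getOrCreate()

        // Hypothetical table, created only so the statements below have something to read.
        spark.sql("CREATE TABLE IF NOT EXISTS default.flint_ppl_test (name STRING, age INT) USING parquet")

        // A PPL statement is handled by FlintSparkPPLParser in the chain ...
        spark.sql("source = default.flint_ppl_test | fields name, age").show()
        // ... while plain SQL falls through the chain to Spark's own parser.
        spark.sql("SELECT name, age FROM default.flint_ppl_test").show()

        spark.stop()
      }
    }

Because every parser failure is swallowed inside FlintSparkParserChain, the stock Spark parser remains the last resort, so existing SQL workloads keep their current behaviour.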