diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintGenericSparkExtensions.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintGenericSparkExtensions.scala
deleted file mode 100644
index 6a890dde3..000000000
--- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintGenericSparkExtensions.scala
+++ /dev/null
@@ -1,25 +0,0 @@
-/*
- * Copyright OpenSearch Contributors
- * SPDX-License-Identifier: Apache-2.0
- */
-
-package org.opensearch.flint.spark
-
-import org.apache.spark.sql.SparkSessionExtensions
-import org.opensearch.flint.spark.ppl.FlintSparkPPLParser
-import org.opensearch.flint.spark.sql.FlintSparkSqlParser
-
-/**
- * Flint Spark extension entrypoint.
- */
-class FlintGenericSparkExtensions extends (SparkSessionExtensions => Unit) {
-
-  override def apply(extensions: SparkSessionExtensions): Unit = {
-    extensions.injectParser { (spark, parser) =>
-      new FlintSparkParserChain(parser,Seq(new FlintSparkPPLParser(parser),new FlintSparkSqlParser(parser)))
-    }
-    extensions.injectOptimizerRule { spark =>
-      new FlintSparkOptimizer(spark)
-    }
-  }
-}
diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkParserChain.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkParserChain.scala
deleted file mode 100644
index 356239cd3..000000000
--- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/FlintSparkParserChain.scala
+++ /dev/null
@@ -1,56 +0,0 @@
-package org.opensearch.flint.spark
-
-import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
-import org.apache.spark.sql.catalyst.expressions.Expression
-import org.apache.spark.sql.catalyst.parser.{ParseException, ParserInterface}
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.types.{DataType, StructType}
-
-import scala.collection.mutable
-
-class FlintSparkParserChain (sparkParser: ParserInterface, parserChain: Seq[ParserInterface]) extends ParserInterface {
-
-  private val parsers: mutable.ListBuffer[ParserInterface] = mutable.ListBuffer() ++= parserChain
-
-  /**
-   * this method goes threw the parsers chain and try parsing sqlText - if successfully return the logical plan
-   * otherwise go to the next parser in the chain and try to parse the sqlText
-   *
-   * @param sqlText
-   * @return
-   */
-  override def parsePlan(sqlText: String): LogicalPlan = {
-    try {
-      // go threw the parsers chain and try parsing sqlText - if successfully return the logical plan
-      // otherwise go to the next parser in the chain and try to parse the sqlText
-      for (parser <- parsers) {
-        try {
-          return parser.parsePlan(sqlText)
-        } catch {
-          case _: Exception => // Continue to the next parser
-        }
-      }
-      // Fall back to Spark parse plan logic if all parsers in the chain fail
-      sparkParser.parsePlan(sqlText)
-    }
-  }
-
-  override def parseExpression(sqlText: String): Expression = sparkParser.parseExpression(sqlText)
-
-  override def parseTableIdentifier(sqlText: String): TableIdentifier =
-    sparkParser.parseTableIdentifier(sqlText)
-
-  override def parseFunctionIdentifier(sqlText: String): FunctionIdentifier =
-    sparkParser.parseFunctionIdentifier(sqlText)
-
-  override def parseMultipartIdentifier(sqlText: String): Seq[String] =
-    sparkParser.parseMultipartIdentifier(sqlText)
-
-  override def parseTableSchema(sqlText: String): StructType =
-    sparkParser.parseTableSchema(sqlText)
-
-  override def parseDataType(sqlText: String): DataType =
-    sparkParser.parseDataType(sqlText)
-
-  override def parseQuery(sqlText: String): LogicalPlan = sparkParser.parseQuery(sqlText)
-}
diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/FlintSparkSqlParser.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/FlintSparkSqlParser.scala
index 642a694fd..2a673c4bf 100644
--- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/FlintSparkSqlParser.scala
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/FlintSparkSqlParser.scala
@@ -53,7 +53,12 @@ class FlintSparkSqlParser(sparkParser: ParserInterface) extends ParserInterface
   private val flintAstBuilder = new FlintSparkSqlAstBuilder()
 
   override def parsePlan(sqlText: String): LogicalPlan = parse(sqlText) { flintParser =>
-    flintAstBuilder.visit(flintParser.singleStatement())
+    try {
+      flintAstBuilder.visit(flintParser.singleStatement())
+    } catch {
+      // Fall back to Spark parse plan logic if flint cannot parse
+      case _: ParseException => sparkParser.parsePlan(sqlText)
+    }
   }
 
   override def parseExpression(sqlText: String): Expression = sparkParser.parseExpression(sqlText)
diff --git a/flint-spark-integration/src/test/scala/org/apache/spark/FlintSuite.scala b/flint-spark-integration/src/test/scala/org/apache/spark/FlintSuite.scala
index ee2854d01..6577600c8 100644
--- a/flint-spark-integration/src/test/scala/org/apache/spark/FlintSuite.scala
+++ b/flint-spark-integration/src/test/scala/org/apache/spark/FlintSuite.scala
@@ -11,7 +11,7 @@ import org.apache.spark.sql.flint.config.FlintConfigEntry
 import org.apache.spark.sql.flint.config.FlintSparkConf.HYBRID_SCAN_ENABLED
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.test.SharedSparkSession
-import org.opensearch.flint.spark.FlintGenericSparkExtensions
+import org.opensearch.flint.spark.FlintSparkExtensions
 
 trait FlintSuite extends SharedSparkSession {
   override protected def sparkConf = {
@@ -24,7 +24,7 @@ trait FlintSuite extends SharedSparkSession {
       // this rule may potentially block testing of other optimization rules such as
       // ConstantPropagation etc.
       .set(SQLConf.OPTIMIZER_EXCLUDED_RULES.key, ConvertToLocalRelation.ruleName)
-      .set("spark.sql.extensions", classOf[FlintGenericSparkExtensions].getName)
+      .set("spark.sql.extensions", classOf[FlintSparkExtensions].getName)
     conf
   }
 
diff --git a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkPPLITSuite.scala b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkPPLITSuite.scala
index 968dff591..53f6b0c08 100644
--- a/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkPPLITSuite.scala
+++ b/integ-test/src/test/scala/org/opensearch/flint/spark/FlintSparkPPLITSuite.scala
@@ -15,7 +15,7 @@ import org.apache.spark.sql.{QueryTest, Row}
 class FlintSparkPPLITSuite
     extends QueryTest
     with LogicalPlanTestUtils
-    with FlintSuite
+    with FlintPPLSuite
     with StreamTest {
 
   /** Test table and index name */
@@ -309,16 +309,11 @@ class FlintSparkPPLITSuite
 
   test("create ppl simple age avg group by country query test ") {
-    val checkData = sql(s"SELECT country, AVG(age) AS avg_age FROM $testTable group by country");
-    checkData.show()
-    checkData.queryExecution.logical.show()
-
     val frame = sql(s"""
          | source = $testTable| stats avg(age) by country
          | """.stripMargin)
-
     // Retrieve the results
     val results: Array[Row] = frame.collect()
     // Define the expected results
diff --git a/ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLParser.scala b/ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLParser.scala
index 0c074b3b9..ea78fbdd4 100644
--- a/ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLParser.scala
+++ b/ppl-spark-integration/src/main/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLParser.scala
@@ -50,10 +50,15 @@ class FlintSparkPPLParser(sparkParser: ParserInterface) extends ParserInterface
   private val pplParser = new PPLSyntaxParser()
 
   override def parsePlan(sqlText: String): LogicalPlan = {
-    // if successful build ppl logical plan and translate to catalyst logical plan
-    val context = new CatalystPlanContext
-    planTrnasormer.visit(plan(pplParser, sqlText, false), context)
-    context.getPlan
+    try {
+      // if successful build ppl logical plan and translate to catalyst logical plan
+      val context = new CatalystPlanContext
+      planTrnasormer.visit(plan(pplParser, sqlText, false), context)
+      context.getPlan
+    } catch {
+      // Fall back to Spark parse plan logic if flint cannot parse
+      case _: ParseException | _: SyntaxCheckException => sparkParser.parsePlan(sqlText)
+    }
   }
 
   override def parseExpression(sqlText: String): Expression = sparkParser.parseExpression(sqlText)
diff --git a/spark-sql-application/src/main/scala/org/opensearch/sql/SQLJob.scala b/spark-sql-application/src/main/scala/org/opensearch/sql/SQLJob.scala
index fa8857cde..9e1d36857 100644
--- a/spark-sql-application/src/main/scala/org/opensearch/sql/SQLJob.scala
+++ b/spark-sql-application/src/main/scala/org/opensearch/sql/SQLJob.scala
@@ -35,7 +35,7 @@ object SQLJob {
 
     val conf: SparkConf = new SparkConf()
       .setAppName("SQLJob")
-      .set("spark.sql.extensions", "org.opensearch.flint.spark.FlintGenericSparkExtensions")
+      .set("spark.sql.extensions", "org.opensearch.flint.spark.FlintSparkExtensions")
      .set("spark.datasource.flint.host", host)
      .set("spark.datasource.flint.port", port)
      .set("spark.datasource.flint.scheme", scheme)
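
For context, here is a minimal sketch of how the retained `FlintSparkExtensions` entrypoint is typically enabled in an application. The object name, master, app name, and probe query below are illustrative assumptions and are not part of this change set; the `spark.sql.extensions` key is the same setting this diff updates in `FlintSuite.scala` and `SQLJob.scala`. With the try/catch fallback added to both parsers above, any statement the Flint SQL or PPL grammar rejects is delegated to Spark's default parser, which is why the separate `FlintSparkParserChain` and `FlintGenericSparkExtensions` indirection can be deleted.

```scala
// Illustrative only: FlintExtensionsDemo is not part of this change set.
// Assumes the Flint/PPL jars are on the classpath so the extension class resolves.
import org.apache.spark.sql.SparkSession

object FlintExtensionsDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession
      .builder()
      .master("local[*]")
      .appName("flint-extensions-demo")
      // Same setting this diff updates in SQLJob.scala and FlintSuite.scala
      .config("spark.sql.extensions", "org.opensearch.flint.spark.FlintSparkExtensions")
      .getOrCreate()

    // Plain SQL is not valid PPL, so the fallback added in parsePlan above
    // hands this statement straight to Spark's own parser.
    spark.sql("SELECT 1 AS probe").show()

    spark.stop()
  }
}
```

Because the fallback now lives inside each parser's `parsePlan`, the per-parser catch is scoped to parse failures only (`ParseException`, plus `SyntaxCheckException` on the PPL side), rather than the chain's earlier blanket `case _: Exception`, which could silently swallow unrelated errors.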