diff --git a/flint-spark-integration/src/main/antlr4/FlintSparkSqlExtensions.g4 b/flint-spark-integration/src/main/antlr4/FlintSparkSqlExtensions.g4
index 44f873f73..3f999aa08 100644
--- a/flint-spark-integration/src/main/antlr4/FlintSparkSqlExtensions.g4
+++ b/flint-spark-integration/src/main/antlr4/FlintSparkSqlExtensions.g4
@@ -84,7 +84,7 @@ materializedViewStatement
 
 createMaterializedViewStatement
     : CREATE MATERIALIZED VIEW mvName=multipartIdentifier
-        AS mvQuery=.*
+        AS query=materializedViewQuery
         (WITH LEFT_PAREN propertyList RIGHT_PAREN)?
     ;
 
@@ -92,6 +92,14 @@ dropMaterializedViewStatement
     : DROP MATERIALIZED VIEW mvName=multipartIdentifier
     ;
 
+/*
+ * Match all remaining tokens in non-greedy way
+ * so WITH clause won't be captured by this rule.
+ */
+materializedViewQuery
+    : .+?
+    ;
+
 indexColTypeList
     : indexColType (COMMA indexColType)*
     ;
diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/FlintSparkSqlParser.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/FlintSparkSqlParser.scala
index 78a9c0628..bb4f1e127 100644
--- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/FlintSparkSqlParser.scala
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/FlintSparkSqlParser.scala
@@ -37,6 +37,7 @@ import org.apache.spark.sql.AnalysisException
 import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
 import org.apache.spark.sql.catalyst.expressions.Expression
 import org.apache.spark.sql.catalyst.parser._
+import org.apache.spark.sql.catalyst.parser.ParserUtils.withOrigin
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.catalyst.trees.Origin
 import org.apache.spark.sql.types.{DataType, StructType}
@@ -54,7 +55,10 @@ class FlintSparkSqlParser(sparkParser: ParserInterface) extends ParserInterface
 
   override def parsePlan(sqlText: String): LogicalPlan = parse(sqlText) { flintParser =>
     try {
-      flintAstBuilder.visit(flintParser.singleStatement())
+      val ctx = flintParser.singleStatement()
+      withOrigin(ctx, Some(sqlText)) {
+        flintAstBuilder.visit(ctx)
+      }
     } catch {
       // Fall back to Spark parse plan logic if flint cannot parse
       case _: ParseException => sparkParser.parsePlan(sqlText)
diff --git a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/mv/FlintSparkMaterializedViewAstBuilder.scala b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/mv/FlintSparkMaterializedViewAstBuilder.scala
index 82804bb50..95541f83f 100644
--- a/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/mv/FlintSparkMaterializedViewAstBuilder.scala
+++ b/flint-spark-integration/src/main/scala/org/opensearch/flint/spark/sql/mv/FlintSparkMaterializedViewAstBuilder.scala
@@ -5,9 +5,11 @@
 
 package org.opensearch.flint.spark.sql.mv
 
-import org.opensearch.flint.spark.sql.FlintSparkSqlExtensionsParser.{CreateMaterializedViewStatementContext, DropMaterializedViewStatementContext}
+import org.opensearch.flint.spark.sql.FlintSparkSqlExtensionsParser.{CreateMaterializedViewStatementContext, DropMaterializedViewStatementContext, MaterializedViewQueryContext}
 import org.opensearch.flint.spark.sql.FlintSparkSqlExtensionsVisitor
 
+import org.apache.spark.sql.catalyst.trees.CurrentOrigin
+
 /**
  * Flint Spark AST builder that builds Spark command for Flint materialized view statement.
  */
@@ -16,7 +18,7 @@ trait FlintSparkMaterializedViewAstBuilder extends FlintSparkSqlExtensionsVisito
   override def visitCreateMaterializedViewStatement(
       ctx: CreateMaterializedViewStatementContext): AnyRef = {
     val mvName = ctx.mvName.getText
-    val query = ctx.mvQuery.getText
+    val query = getMvQuery(ctx.query)
     throw new UnsupportedOperationException(s"Create MV $mvName with query $query")
   }
 
@@ -24,4 +26,12 @@ trait FlintSparkMaterializedViewAstBuilder extends FlintSparkSqlExtensionsVisito
     ctx: DropMaterializedViewStatementContext): AnyRef = {
     throw new UnsupportedOperationException(s"Drop MV ${ctx.mvName.getText}")
   }
+
+  private def getMvQuery(ctx: MaterializedViewQueryContext): String = {
+    // Assume origin must be preserved at the beginning of parsing
+    val sqlText = CurrentOrigin.get.sqlText.get
+    val startIndex = ctx.getStart.getStartIndex
+    val stopIndex = ctx.getStop.getStopIndex
+    sqlText.substring(startIndex, stopIndex + 1)
+  }
 }
diff --git a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/sql/FlintSparkSqlSuite.scala b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/sql/FlintSparkSqlSuite.scala
index 49f56fbed..a043db8c8 100644
--- a/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/sql/FlintSparkSqlSuite.scala
+++ b/flint-spark-integration/src/test/scala/org/opensearch/flint/spark/sql/FlintSparkSqlSuite.scala
@@ -19,6 +19,9 @@ class FlintSparkSqlSuite extends FlintSuite {
       | SELECT elb, COUNT(*)
      | FROM alb_logs
       | GROUP BY TUMBLE(time, '1 Minute')
+      | WITH (
+      |   auto_refresh = true
+      | )
       |""".stripMargin)
 
     the[UnsupportedOperationException] thrownBy