Skip to content

Commit

Permalink
Find mv query in origin
Browse files Browse the repository at this point in the history
Signed-off-by: Chen Dai <[email protected]>
  • Loading branch information
dai-chen committed Oct 17, 2023
1 parent 0fc86f6 commit c5efc62
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -84,14 +84,22 @@ materializedViewStatement

// CREATE MATERIALIZED VIEW <name> AS <query> [WITH (<properties>)]
// The query is captured by the dedicated materializedViewQuery rule so the
// optional trailing WITH clause is not swallowed into the query text.
createMaterializedViewStatement
    : CREATE MATERIALIZED VIEW mvName=multipartIdentifier
        AS query=materializedViewQuery
        (WITH LEFT_PAREN propertyList RIGHT_PAREN)?
    ;

// DROP MATERIALIZED VIEW <name>
dropMaterializedViewStatement
    : DROP MATERIALIZED VIEW mvName=multipartIdentifier
    ;

/*
 * Match all remaining tokens in a non-greedy way so that the optional
 * trailing WITH clause is not captured by this rule.
 */
materializedViewQuery
    : .+?
    ;

// Comma-separated list of indexed columns with their types.
indexColTypeList
    : indexColType (COMMA indexColType)*
    ;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.{FunctionIdentifier, TableIdentifier}
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.parser._
import org.apache.spark.sql.catalyst.parser.ParserUtils.withOrigin
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.trees.Origin
import org.apache.spark.sql.types.{DataType, StructType}
Expand All @@ -54,7 +55,10 @@ class FlintSparkSqlParser(sparkParser: ParserInterface) extends ParserInterface

override def parsePlan(sqlText: String): LogicalPlan = parse(sqlText) { flintParser =>
try {
flintAstBuilder.visit(flintParser.singleStatement())
val ctx = flintParser.singleStatement()
withOrigin(ctx, Some(sqlText)) {
flintAstBuilder.visit(ctx)
}
} catch {
// Fall back to Spark parse plan logic if flint cannot parse
case _: ParseException => sparkParser.parsePlan(sqlText)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@

package org.opensearch.flint.spark.sql.mv

import org.opensearch.flint.spark.sql.FlintSparkSqlExtensionsParser.{CreateMaterializedViewStatementContext, DropMaterializedViewStatementContext, MaterializedViewQueryContext}
import org.opensearch.flint.spark.sql.FlintSparkSqlExtensionsVisitor

import org.apache.spark.sql.catalyst.trees.CurrentOrigin

/**
* Flint Spark AST builder that builds Spark command for Flint materialized view statement.
*/
Expand All @@ -16,12 +18,20 @@ trait FlintSparkMaterializedViewAstBuilder extends FlintSparkSqlExtensionsVisito
/**
 * Handles CREATE MATERIALIZED VIEW statement.
 *
 * Not yet supported: always throws, reporting the parsed view name and the
 * original query text recovered from the SQL origin via [[getMvQuery]].
 */
override def visitCreateMaterializedViewStatement(
    ctx: CreateMaterializedViewStatementContext): AnyRef = {
  val mvName = ctx.mvName.getText
  // Use the raw SQL slice rather than ctx.getText so the query keeps its
  // original formatting (token text alone may not preserve whitespace).
  val query = getMvQuery(ctx.query)
  throw new UnsupportedOperationException(s"Create MV $mvName with query $query")
}

/**
 * Handles DROP MATERIALIZED VIEW statement.
 *
 * Not yet supported: always throws with the parsed view name.
 */
override def visitDropMaterializedViewStatement(
    ctx: DropMaterializedViewStatementContext): AnyRef = {
  val mvName = ctx.mvName.getText
  throw new UnsupportedOperationException(s"Drop MV $mvName")
}

/**
 * Recovers the exact materialized view query text by slicing the original
 * SQL string with the query context's character offsets.
 *
 * Requires that the SQL text was stored in the parse origin before visiting
 * (the parser wraps visiting in withOrigin with the full SQL text); fails
 * with a descriptive error instead of a bare NoSuchElementException if not.
 */
private def getMvQuery(ctx: MaterializedViewQueryContext): String = {
  val sqlText = CurrentOrigin.get.sqlText.getOrElse(
    throw new IllegalStateException(
      "SQL text not preserved in parser origin; cannot extract MV query"))
  val startIndex = ctx.getStart.getStartIndex
  // Stop index is inclusive in ANTLR, hence the +1 for substring's
  // exclusive end bound.
  val stopIndex = ctx.getStop.getStopIndex
  sqlText.substring(startIndex, stopIndex + 1)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ class FlintSparkSqlSuite extends FlintSuite {
| SELECT elb, COUNT(*)
| FROM alb_logs
| GROUP BY TUMBLE(time, '1 Minute')
| WITH (
| auto_refresh = true
| )
|""".stripMargin)

the[UnsupportedOperationException] thrownBy
Expand Down

0 comments on commit c5efc62

Please sign in to comment.