diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLBasicITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLBasicITSuite.scala index e9ccc5b5a..3c94975fd 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLBasicITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLBasicITSuite.scala @@ -8,7 +8,7 @@ package org.opensearch.flint.spark.ppl import org.apache.spark.sql.{QueryTest, Row} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} -import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Descending, EqualTo, Literal, Not, SortOrder} +import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Descending, EqualTo, IsNotNull, Literal, Not, SortOrder} import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.ExplainMode import org.apache.spark.sql.execution.command.{DescribeTableCommand, ExplainCommand} @@ -39,9 +39,9 @@ class FlintSparkPPLBasicITSuite } } - test("explain test") { + test("explain simple mode test") { val frame = sql(s""" - | explain | source = $testTable | where state != 'California' | fields name + | explain simple | source = $testTable | where state != 'California' | fields name | """.stripMargin) // Retrieve the logical plan @@ -55,7 +55,87 @@ class FlintSparkPPLBasicITSuite Project(Seq(UnresolvedAttribute("name")), filter), ExplainMode.fromString("simple")) // Compare the two plans - assert(expectedPlan === logicalPlan) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("explain extended mode test") { + val frame = sql(s""" + | explain extended | source = $testTable + | """.stripMargin) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // 
Define the expected logical plan + val relation = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val expectedPlan: LogicalPlan = + ExplainCommand( + Project(Seq(UnresolvedStar(None)), relation), + ExplainMode.fromString("extended")) + // Compare the two plans + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("explain codegen mode test") { + val frame = sql(s""" + | explain codegen | source = $testTable | dedup name | fields name, state + | """.stripMargin) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val relation = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val nameAttribute = UnresolvedAttribute("name") + val dedup = + Deduplicate(Seq(nameAttribute), Filter(IsNotNull(nameAttribute), relation)) + val expectedPlan: LogicalPlan = + ExplainCommand( + Project(Seq(UnresolvedAttribute("name"), UnresolvedAttribute("state")), dedup), + ExplainMode.fromString("codegen")) + // Compare the two plans + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("explain cost mode test") { + val frame = sql(s""" + | explain cost | source = $testTable | sort name | fields name, age + | """.stripMargin) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val relation = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val sort: LogicalPlan = + Sort( + Seq(SortOrder(UnresolvedAttribute("name"), Ascending)), + global = true, + relation) + val expectedPlan: LogicalPlan = + ExplainCommand( + Project(Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")), sort), + ExplainMode.fromString("cost")) + // Compare the two plans + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("explain formatted mode test") { + val frame = sql(s""" + | explain formatted | source = 
$testTable | fields - name + | """.stripMargin) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val relation = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val dropColumns = DataFrameDropColumns( + Seq(UnresolvedAttribute("name")), relation + ) + val expectedPlan: LogicalPlan = + ExplainCommand( + Project(Seq(UnresolvedStar(Option.empty)), dropColumns), + ExplainMode.fromString("formatted")) + + // Compare the two plans + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) } test("describe (extended) table query test") { diff --git a/ppl-spark-integration/README.md b/ppl-spark-integration/README.md index ad11f96d5..41c171c36 100644 --- a/ppl-spark-integration/README.md +++ b/ppl-spark-integration/README.md @@ -228,8 +228,12 @@ See the next samples of PPL queries : - `describe table` This command is equal to the `DESCRIBE EXTENDED table` SQL command **Explain** - - `explain | source = table | where a = 1 | fields a,b,c` - - `explain | describe table` + - `explain simple | source = table | where a = 1 | fields a,b,c` + - `explain extended | source = table` + - `explain codegen | source = table | dedup a | fields a,b,c` + - `explain cost | source = table | sort a | fields a,b,c` + - `explain formatted | source = table | fields - a` + - `explain simple | describe table` **Fields** - `source = table` diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 index d837e5e0e..6c1f3512d 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 @@ -35,7 +35,6 @@ NEW_FIELD: 'NEW_FIELD'; KMEANS: 'KMEANS'; AD: 'AD'; ML: 'ML'; -EXPLAIN: 'EXPLAIN'; //Native JOIN KEYWORDS JOIN: 'JOIN'; @@ -57,6 +56,14 @@ APPROXIMATE: 'APPROXIMATE'; SCOPE: 'SCOPE'; MAPPING: 'MAPPING'; +//EXPLAIN KEYWORDS +EXPLAIN: 
'EXPLAIN'; +FORMATTED: 'FORMATTED'; +COST: 'COST'; +CODEGEN: 'CODEGEN'; +EXTENDED: 'EXTENDED'; +SIMPLE: 'SIMPLE'; + // COMMAND ASSIST KEYWORDS AS: 'AS'; BY: 'BY'; diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 index 25f2f52e4..99ec803c6 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 @@ -59,7 +59,15 @@ describeCommand ; explainCommand - : EXPLAIN + : EXPLAIN explainMode + ; + +explainMode + : FORMATTED + | COST + | CODEGEN + | EXTENDED + | SIMPLE ; showDataSourcesCommand diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/statement/Explain.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/statement/Explain.java index 4968668ac..9c961b9e6 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/statement/Explain.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/statement/Explain.java @@ -14,17 +14,31 @@ public class Explain extends Statement { private Statement statement; + private ExplainMode explainMode; - public Explain(Query statement) { + public Explain(Query statement, String explainMode) { this.statement = statement; + /* explainMode is the raw keyword text taken from the parser; the lexer literals are uppercase while queries in the tests and README use lowercase, so the letter case can vary. Normalize to the lowercase constant names because Enum.valueOf matches names exactly. */ this.explainMode = ExplainMode.valueOf(explainMode.toLowerCase(java.util.Locale.ROOT)); } public Statement getStatement() { return statement; } + public ExplainMode getExplainMode() { + return explainMode; + } + @Override public R accept(AbstractNodeVisitor visitor, C context) { return visitor.visitExplain(this, context); } + + public enum ExplainMode { + formatted, + cost, + codegen, + extended, + simple + } } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java index 6f565ba80..bf922f6f7 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java
+++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java @@ -134,7 +134,7 @@ public LogicalPlan visitQuery(Query node, CatalystPlanContext context) { @Override public LogicalPlan visitExplain(Explain node, CatalystPlanContext context) { node.getStatement().accept(this, context); - return context.apply(p -> new ExplainCommand(p, ExplainMode.fromString("simple"))); + return context.apply(p -> new ExplainCommand(p, ExplainMode.fromString(node.getExplainMode().name()))); } @Override diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstStatementBuilder.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstStatementBuilder.java index 19878b15d..6a545f091 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstStatementBuilder.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstStatementBuilder.java @@ -34,9 +34,12 @@ public AstStatementBuilder(AstBuilder astBuilder, StatementBuilderContext contex @Override public Statement visitDmlStatement(OpenSearchPPLParser.DmlStatementContext ctx) { - boolean explain = ctx.explainCommand() != null; Query query = new Query(addSelectAll(astBuilder.visit(ctx)), context.getFetchSize()); - return explain ? 
new Explain(query) : query; + OpenSearchPPLParser.ExplainCommandContext explainContext = ctx.explainCommand(); + if (explainContext != null) { + return new Explain(query, explainContext.explainMode().getText()); + } + return query; } @Override diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanBasicQueriesTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanBasicQueriesTranslatorTestSuite.scala index e54fbeaa7..30deecc31 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanBasicQueriesTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanBasicQueriesTranslatorTestSuite.scala @@ -281,7 +281,7 @@ class PPLLogicalPlanBasicQueriesTranslatorTestSuite test("test fields + field list") { val context = new CatalystPlanContext val logPlan = planTransformer.visit( - plan(pplParser, "source=t | sort - A | fields + A, B | head 5", false), + plan(pplParser, "source=t | sort - A | fields + A, B | head 5"), context) val table = UnresolvedRelation(Seq("t")) @@ -298,7 +298,7 @@ class PPLLogicalPlanBasicQueriesTranslatorTestSuite test("test fields - field list") { val context = new CatalystPlanContext val logPlan = planTransformer.visit( - plan(pplParser, "source=t | sort - A | fields - A, B | head 5", false), + plan(pplParser, "source=t | sort - A | fields - A, B | head 5"), context) val table = UnresolvedRelation(Seq("t"))