Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add explain command #687

Merged
merged 10 commits into from
Sep 27, 2024
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,10 @@ package org.opensearch.flint.spark.ppl
import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.catalyst.TableIdentifier
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Descending, Literal, SortOrder}
import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Descending, EqualTo, IsNotNull, Literal, Not, SortOrder}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.command.DescribeTableCommand
import org.apache.spark.sql.execution.ExplainMode
import org.apache.spark.sql.execution.command.{DescribeTableCommand, ExplainCommand}
import org.apache.spark.sql.streaming.StreamTest

class FlintSparkPPLBasicITSuite
Expand Down Expand Up @@ -38,6 +39,100 @@ class FlintSparkPPLBasicITSuite
}
}

test("explain simple mode test") {
  // Run an `explain simple` PPL query and verify the unresolved logical plan
  // the parser produces (analysis is intentionally skipped in the comparison).
  val df = sql(s"""
       | explain simple | source = $testTable | where state != 'California' | fields name
       | """.stripMargin)

  val actual: LogicalPlan = df.queryExecution.logical

  // Expected shape: ExplainCommand(Project(name, Filter(NOT state = 'California', relation)))
  val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))
  val predicate = Not(EqualTo(UnresolvedAttribute("state"), Literal("California")))
  val projected = Project(Seq(UnresolvedAttribute("name")), Filter(predicate, table))
  val expected: LogicalPlan = ExplainCommand(projected, ExplainMode.fromString("simple"))

  comparePlans(actual, expected, checkAnalysis = false)
}

test("explain extended mode test") {
  // `explain extended` over a bare `source = table` query: the plan should be
  // an ExplainCommand wrapping a star projection over the relation.
  val df = sql(s"""
       | explain extended | source = $testTable
       | """.stripMargin)

  val actual: LogicalPlan = df.queryExecution.logical

  val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))
  val expected: LogicalPlan =
    ExplainCommand(Project(Seq(UnresolvedStar(None)), table), ExplainMode.fromString("extended"))

  comparePlans(actual, expected, checkAnalysis = false)
}

test("explain codegen mode test") {
  // `explain codegen` with dedup: `dedup name` lowers to a not-null filter on
  // `name` followed by a Deduplicate over that attribute.
  val df = sql(s"""
       | explain codegen | source = $testTable | dedup name | fields name, state
       | """.stripMargin)

  val actual: LogicalPlan = df.queryExecution.logical

  val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))
  val name = UnresolvedAttribute("name")
  val deduplicated = Deduplicate(Seq(name), Filter(IsNotNull(name), table))
  val expected: LogicalPlan =
    ExplainCommand(
      Project(Seq(UnresolvedAttribute("name"), UnresolvedAttribute("state")), deduplicated),
      ExplainMode.fromString("codegen"))

  comparePlans(actual, expected, checkAnalysis = false)
}

test("explain cost mode test") {
  // `explain cost` with sort: `sort name` lowers to a global ascending Sort
  // over the relation, then the two requested fields are projected.
  val df = sql(s"""
       | explain cost | source = $testTable | sort name | fields name, age
       | """.stripMargin)

  val actual: LogicalPlan = df.queryExecution.logical

  val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))
  val sorted: LogicalPlan =
    Sort(Seq(SortOrder(UnresolvedAttribute("name"), Ascending)), global = true, table)
  val expected: LogicalPlan =
    ExplainCommand(
      Project(Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")), sorted),
      ExplainMode.fromString("cost"))

  comparePlans(actual, expected, checkAnalysis = false)
}

test("explain formatted mode test") {
  // `explain formatted` with field exclusion: `fields - name` lowers to
  // DataFrameDropColumns dropping `name`, then a star projection over the rest.
  val df = sql(s"""
       | explain formatted | source = $testTable | fields - name
       | """.stripMargin)

  val actual: LogicalPlan = df.queryExecution.logical

  val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))
  val dropped = DataFrameDropColumns(Seq(UnresolvedAttribute("name")), table)
  val expected: LogicalPlan =
    ExplainCommand(
      // Use `None` (not `Option.empty`) for consistency with the sibling tests.
      Project(Seq(UnresolvedStar(None)), dropped),
      ExplainMode.fromString("formatted"))

  comparePlans(actual, expected, checkAnalysis = false)
}

test("describe (extended) table query test") {
val frame = sql(s"""
describe flint_ppl_test
Expand Down
8 changes: 8 additions & 0 deletions ppl-spark-integration/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,14 @@ See the next samples of PPL queries :
**Describe**
- `describe table` This command is equivalent to the `DESCRIBE EXTENDED table` SQL command

**Explain**
- `explain simple | source = table | where a = 1 | fields a,b,c`
- `explain extended | source = table`
- `explain codegen | source = table | dedup a | fields a,b,c`
- `explain cost | source = table | sort a | fields a,b,c`
- `explain formatted | source = table | fields - a`
- `explain simple | describe table`

**Fields**
- `source = table`
- `source = table | fields a,b,c`
Expand Down
8 changes: 8 additions & 0 deletions ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,14 @@ APPROXIMATE: 'APPROXIMATE';
SCOPE: 'SCOPE';
MAPPING: 'MAPPING';

// EXPLAIN KEYWORDS (explain command and its supported modes)
EXPLAIN: 'EXPLAIN';
FORMATTED: 'FORMATTED';
COST: 'COST';
CODEGEN: 'CODEGEN';
EXTENDED: 'EXTENDED';
SIMPLE: 'SIMPLE';

// COMMAND ASSIST KEYWORDS
AS: 'AS';
BY: 'BY';
Expand Down
15 changes: 14 additions & 1 deletion ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ pplStatement
;

// A DML statement is a query, optionally prefixed by an explain command
// (e.g. `explain simple | source = t | ...`). The stale pre-change
// alternative `: queryStatement` left by the diff rendering is removed —
// two leading `:` lines would not be a valid ANTLR rule.
dmlStatement
   : (explainCommand PIPE)? queryStatement
   ;

queryStatement
Expand Down Expand Up @@ -58,6 +58,18 @@ describeCommand
: DESCRIBE tableSourceClause
;

// `EXPLAIN <mode>` prefix that wraps the piped query in an explain request.
explainCommand
   : EXPLAIN explainMode
   ;

// Supported explain output modes, mirroring Spark's ExplainMode values.
explainMode
   : FORMATTED
   | COST
   | CODEGEN
   | EXTENDED
   | SIMPLE
   ;

// `SHOW DATASOURCES` — lists the configured data sources.
showDataSourcesCommand
   : SHOW DATASOURCES
   ;
Expand Down Expand Up @@ -915,6 +927,7 @@ keywordsCanBeId
| KMEANS
| AD
| ML
| EXPLAIN
// commands assist keywords
| SOURCE
| INDEX
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,23 @@
public class Explain extends Statement {

  // The wrapped query to be explained rather than executed.
  private final Statement statement;
  // Requested explain output mode (mirrors Spark's ExplainMode names).
  private final ExplainMode explainMode;

  /**
   * Wraps {@code statement} in an explain request.
   *
   * @param statement the query to explain
   * @param explainMode mode keyword text as written by the user
   * @throws IllegalArgumentException if the text matches no {@link ExplainMode}
   */
  public Explain(Query statement, String explainMode) {
    this.statement = statement;
    // The PPL keywords appear to be case-insensitive (tests use lowercase
    // `explain simple` against uppercase lexer tokens), so normalize before
    // valueOf — otherwise `EXPLAIN SIMPLE` would throw on the lowercase
    // enum constants. Locale.ROOT avoids locale-dependent case mapping.
    this.explainMode = ExplainMode.valueOf(explainMode.toLowerCase(java.util.Locale.ROOT));
  }

  @Override
  public <R, C> R accept(AbstractNodeVisitor<R, C> visitor, C context) {
    return visitor.visitExplain(this, context);
  }

  /** Explain modes, lowercase to match Spark's ExplainMode.fromString names. */
  public enum ExplainMode {
    formatted,
    cost,
    codegen,
    extended,
    simple
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,9 @@
import org.apache.spark.sql.catalyst.plans.logical.Limit;
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan;
import org.apache.spark.sql.catalyst.plans.logical.Project$;
import org.apache.spark.sql.execution.ExplainMode;
import org.apache.spark.sql.execution.command.DescribeTableCommand;
import org.apache.spark.sql.execution.command.ExplainCommand;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.util.CaseInsensitiveStringMap;
import org.opensearch.sql.ast.AbstractNodeVisitor;
Expand Down Expand Up @@ -131,7 +133,8 @@ public LogicalPlan visitQuery(Query node, CatalystPlanContext context) {

@Override
public LogicalPlan visitExplain(Explain node, CatalystPlanContext context) {
  // Build the wrapped statement's plan into the context first, then wrap the
  // resulting plan in Spark's ExplainCommand with the requested mode.
  // (The stale pre-change line `return node.getStatement().accept(...)` left
  // by the diff rendering is removed — it made the lines below unreachable.)
  node.getStatement().accept(this, context);
  return context.apply(
      p -> new ExplainCommand(p, ExplainMode.fromString(node.getExplainMode().name())));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,11 @@ public AstStatementBuilder(AstBuilder astBuilder, StatementBuilderContext contex
@Override
public Statement visitDmlStatement(OpenSearchPPLParser.DmlStatementContext ctx) {
  // Build the query AST; wrap it in Explain only when the statement carries an
  // `explain <mode>` prefix. (The stale pre-change line
  // `return context.isExplain ? new Explain(query) : query;` left by the diff
  // rendering is removed — it shadowed the code below and referenced a
  // no-longer-existing one-argument Explain constructor.)
  Query query = new Query(addSelectAll(astBuilder.visit(ctx)), context.getFetchSize());
  OpenSearchPPLParser.ExplainCommandContext explainContext = ctx.explainCommand();
  if (explainContext != null) {
    return new Explain(query, explainContext.explainMode().getText());
  }
  return query;
}

@Override
Expand All @@ -52,22 +56,15 @@ public StatementBuilderContext getContext() {
}

public static class StatementBuilderContext {
private boolean isExplain;
private int fetchSize;

/**
 * @param fetchSize maximum number of rows a built statement should fetch.
 * (The interleaved pre-change two-argument constructor taking `isExplain`
 * was removed; explain handling now lives in the parsed AST itself.)
 */
public StatementBuilderContext(int fetchSize) {
  this.fetchSize = fetchSize;
}

/**
 * Creates a context with default parameters. (The interleaved pre-change
 * `explain(boolean)` method and two-argument constructor call left by the
 * diff rendering are removed.)
 */
public static StatementBuilderContext builder() {
  // TODO: make the default statement builder init params configurable
  return new StatementBuilderContext(1000);
}

public int getFetchSize() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class FlintSparkPPLParser(sparkParser: ParserInterface) extends ParserInterface
try {
// if successful build ppl logical plan and translate to catalyst logical plan
val context = new CatalystPlanContext
planTransformer.visit(plan(pplParser, sqlText, false), context)
planTransformer.visit(plan(pplParser, sqlText), context)
context.getPlan
} catch {
// Fall back to Spark parse plan logic if flint cannot parse
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class PPLSyntaxParser extends Parser {
}

object PlaneUtils {
def plan(parser: PPLSyntaxParser, query: String, isExplain: Boolean): Statement = {
def plan(parser: PPLSyntaxParser, query: String): Statement = {
val builder = new AstStatementBuilder(
new AstBuilder(new AstExpressionBuilder(), query),
AstStatementBuilder.StatementBuilderContext.builder())
Expand Down
Loading
Loading