From 1ef53c86b16e762aa40a520d507be7b21884d2cc Mon Sep 17 00:00:00 2001 From: YANGDB Date: Wed, 6 Nov 2024 18:25:25 -0800 Subject: [PATCH] Expand ppl command (#868) * add expand command Signed-off-by: YANGDB * add expand command with visitor Signed-off-by: YANGDB * create unit / integration tests Signed-off-by: YANGDB * update expand tests Signed-off-by: YANGDB * add tests Signed-off-by: YANGDB * update doc Signed-off-by: YANGDB * update docs with examples Signed-off-by: YANGDB * update scala style Signed-off-by: YANGDB * update with additional test case remove outer generator Signed-off-by: YANGDB * update with additional test case remove outer generator Signed-off-by: YANGDB * update documentation Signed-off-by: YANGDB --------- Signed-off-by: YANGDB --- docs/ppl-lang/PPL-Example-Commands.md | 37 ++- docs/ppl-lang/README.md | 2 + docs/ppl-lang/ppl-expand-command.md | 45 +++ .../flint/spark/FlintSparkSuite.scala | 22 ++ .../ppl/FlintSparkPPLExpandITSuite.scala | 255 ++++++++++++++++ .../src/main/antlr4/OpenSearchPPLLexer.g4 | 1 + .../sql/ast/AbstractNodeVisitor.java | 4 + .../org/opensearch/sql/ast/tree/Expand.java | 44 +++ .../org/opensearch/sql/ast/tree/Flatten.java | 4 +- .../opensearch/sql/ppl/parser/AstBuilder.java | 6 + ...PlanExpandCommandTranslatorTestSuite.scala | 281 ++++++++++++++++++ 11 files changed, 686 insertions(+), 15 deletions(-) create mode 100644 docs/ppl-lang/ppl-expand-command.md create mode 100644 integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLExpandITSuite.scala create mode 100644 ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Expand.java create mode 100644 ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanExpandCommandTranslatorTestSuite.scala diff --git a/docs/ppl-lang/PPL-Example-Commands.md b/docs/ppl-lang/PPL-Example-Commands.md index e80f8c906..4ea564111 100644 --- a/docs/ppl-lang/PPL-Example-Commands.md +++ b/docs/ppl-lang/PPL-Example-Commands.md @@ -441,8 +441,30 @@ Assumptions: `a`, `b` are fields of table outer, `c`, `d` are fields of table in _- **Limitation: another command usage of (relation) subquery is in `appendcols` commands which is unsupported**_ ---- -#### Experimental Commands: + +#### **fillnull** +[See additional command details](ppl-fillnull-command.md) +```sql + - `source=accounts | fillnull fields status_code=101` + - `source=accounts | fillnull fields request_path='/not_found', timestamp='*'` + - `source=accounts | fillnull using field1=101` + - `source=accounts | fillnull using field1=concat(field2, field3), field4=2*pi()*field5` + - `source=accounts | fillnull using field1=concat(field2, field3), field4=2*pi()*field5, field6 = 'N/A'` +``` + +#### **expand** +[See additional command details](ppl-expand-command.md) +```sql + - `source = table | expand field_with_array as array_list` + - `source = table | expand employee | stats max(salary) as max by state, company` + - `source = table | expand employee as worker | stats max(salary) as max by state, company` + - `source = table | expand employee as worker | eval bonus = salary * 3 | fields worker, bonus` + - `source = table | expand employee | parse description '(?.+@.+)' | fields employee, email` + - `source = table | eval array=json_array(1, 2, 3) | expand array as uid | fields name, occupation, uid` + - `source = table | expand multi_valueA as multiA | expand multi_valueB as multiB` +``` + +#### Correlation Commands: [See additional command details](ppl-correlation-command.md) ```sql @@ -454,14 +476,3 @@ _- **Limitation: another command usage of (relation) subquery is in `appendcols` > ppl-correlation-command is an experimental command - it may be removed in future versions --- -### Planned Commands: - -#### **fillnull** -[See additional command details](ppl-fillnull-command.md) -```sql - - `source=accounts | fillnull fields status_code=101` - - `source=accounts | fillnull fields request_path='/not_found', timestamp='*'` - - `source=accounts | fillnull using field1=101` - - `source=accounts | fillnull using field1=concat(field2, field3), field4=2*pi()*field5` - - `source=accounts | fillnull using field1=concat(field2, field3), field4=2*pi()*field5, field6 = 'N/A'` -``` diff --git a/docs/ppl-lang/README.md b/docs/ppl-lang/README.md index ef186e5f2..d72c973be 100644 --- a/docs/ppl-lang/README.md +++ b/docs/ppl-lang/README.md @@ -71,6 +71,8 @@ For additional examples see the next [documentation](PPL-Example-Commands.md). - [`correlation commands`](ppl-correlation-command.md) - [`trendline commands`](ppl-trendline-command.md) + + - [`expand commands`](ppl-expand-command.md) * **Functions** diff --git a/docs/ppl-lang/ppl-expand-command.md b/docs/ppl-lang/ppl-expand-command.md new file mode 100644 index 000000000..144c0aafa --- /dev/null +++ b/docs/ppl-lang/ppl-expand-command.md @@ -0,0 +1,45 @@ +## PPL `expand` command + +### Description +Using `expand` command to flatten a field of type: +- `Array` +- `Map` + + +### Syntax +`expand [As alias]` + +* field: to be expanded (exploded). The field must be of supported type. +* alias: Optional to be expanded as the name to be used instead of the original field name + +### Usage Guidelines +The expand command produces a row for each element in the specified array or map field, where: +- Array elements become individual rows. +- Map key-value pairs are broken into separate rows, with each key-value represented as a row. + +- When an alias is provided, the exploded values are represented under the alias instead of the original field name. +- This can be used in combination with other commands, such as stats, eval, and parse to manipulate or extract data post-expansion. + +### Examples: +- `source = table | expand employee | stats max(salary) as max by state, company` +- `source = table | expand employee as worker | stats max(salary) as max by state, company` +- `source = table | expand employee as worker | eval bonus = salary * 3 | fields worker, bonus` +- `source = table | expand employee | parse description '(?.+@.+)' | fields employee, email` +- `source = table | eval array=json_array(1, 2, 3) | expand array as uid | fields name, occupation, uid` +- `source = table | expand multi_valueA as multiA | expand multi_valueB as multiB` + +- Expand command can be used in combination with other commands such as `eval`, `stats` and more +- Using multiple expand commands will create a cartesian product of all the internal elements within each composite array or map + +### Effective SQL push-down query +The expand command is translated into an equivalent SQL operation using LATERAL VIEW explode, allowing for efficient exploding of arrays or maps at the SQL query level. + +```sql +SELECT customer exploded_productId +FROM table +LATERAL VIEW explode(productId) AS exploded_productId +``` +Where the `explode` command offers the following functionality: +- it is a column operation that returns a new column +- it creates a new row for every element in the exploded column +- internal `null`s are ignored as part of the exploded field (no row is created/exploded for null) diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala index c53eee548..68d370791 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala @@ -559,6 +559,28 @@ trait FlintSparkSuite extends QueryTest with FlintSuite with OpenSearchSuite wit |""".stripMargin) } + protected def createMultiColumnArrayTable(testTable: String): Unit = { + // CSV doesn't support struct field + sql(s""" + | CREATE TABLE $testTable + | ( + | int_col INT, + | multi_valueA Array>, + | multi_valueB Array> + | ) + | USING JSON + |""".stripMargin) + + sql(s""" + | INSERT INTO $testTable + | VALUES + | ( 1, array(STRUCT("1_one", 1), STRUCT(null, 11), STRUCT("1_three", null)), array(STRUCT("2_Monday", 2), null) ), + | ( 2, array(STRUCT("2_Monday", 2), null) , array(STRUCT("3_third", 3), STRUCT("3_4th", 4)) ), + | ( 3, array(STRUCT("3_third", 3), STRUCT("3_4th", 4)) , array(STRUCT("1_one", 1))), + | ( 4, null, array(STRUCT("1_one", 1))) + |""".stripMargin) + } + protected def createTableIssue112(testTable: String): Unit = { sql(s""" | CREATE TABLE $testTable ( diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLExpandITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLExpandITSuite.scala new file mode 100644 index 000000000..f0404bf7b --- /dev/null +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLExpandITSuite.scala @@ -0,0 +1,255 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.flint.spark.ppl + +import java.nio.file.Files + +import scala.collection.mutable + +import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq + +import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Alias, EqualTo, Explode, GeneratorOuter, Literal, Or} +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.streaming.StreamTest + +class FlintSparkPPLExpandITSuite + extends QueryTest + with LogicalPlanTestUtils + with FlintPPLSuite + with StreamTest { + + private val testTable = "flint_ppl_test" + private val occupationTable = "spark_catalog.default.flint_ppl_flat_table_test" + private val structNestedTable = "spark_catalog.default.flint_ppl_struct_nested_test" + private val structTable = "spark_catalog.default.flint_ppl_struct_test" + private val multiValueTable = "spark_catalog.default.flint_ppl_multi_value_test" + private val multiArraysTable = "spark_catalog.default.flint_ppl_multi_array_test" + private val tempFile = Files.createTempFile("jsonTestData", ".json") + + override def beforeAll(): Unit = { + super.beforeAll() + + // Create test table + createNestedJsonContentTable(tempFile, testTable) + createOccupationTable(occupationTable) + createStructNestedTable(structNestedTable) + createStructTable(structTable) + createMultiValueStructTable(multiValueTable) + createMultiColumnArrayTable(multiArraysTable) + } + + protected override def afterEach(): Unit = { + super.afterEach() + // Stop all streaming jobs if any + spark.streams.active.foreach { job => + job.stop() + job.awaitTermination() + } + } + + override def afterAll(): Unit = { + super.afterAll() + Files.deleteIfExists(tempFile) + } + + test("expand for eval field of an array") { + val frame = sql( + s""" source = $occupationTable | eval array=json_array(1, 2, 3) | expand array as uid | fields name, occupation, uid + """.stripMargin) + + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = Array( + Row("Jake", "Engineer", 1), + Row("Jake", "Engineer", 2), + Row("Jake", "Engineer", 3), + Row("Hello", "Artist", 1), + Row("Hello", "Artist", 2), + Row("Hello", "Artist", 3), + Row("John", "Doctor", 1), + Row("John", "Doctor", 2), + Row("John", "Doctor", 3), + Row("David", "Doctor", 1), + Row("David", "Doctor", 2), + Row("David", "Doctor", 3), + Row("David", "Unemployed", 1), + Row("David", "Unemployed", 2), + Row("David", "Unemployed", 3), + Row("Jane", "Scientist", 1), + Row("Jane", "Scientist", 2), + Row("Jane", "Scientist", 3)) + + // Compare the results + assert(results.toSet == expectedResults.toSet) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // expected plan + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_flat_table_test")) + val jsonFunc = + UnresolvedFunction("array", Seq(Literal(1), Literal(2), Literal(3)), isDistinct = false) + val aliasA = Alias(jsonFunc, "array")() + val project = Project(seq(UnresolvedStar(None), aliasA), table) + val generate = Generate( + Explode(UnresolvedAttribute("array")), + seq(), + false, + None, + seq(UnresolvedAttribute("uid")), + project) + val dropSourceColumn = + DataFrameDropColumns(Seq(UnresolvedAttribute("array")), generate) + val expectedPlan = Project( + seq( + UnresolvedAttribute("name"), + UnresolvedAttribute("occupation"), + UnresolvedAttribute("uid")), + dropSourceColumn) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + + test("expand for structs") { + val frame = sql( + s""" source = $multiValueTable | expand multi_value AS exploded_multi_value | fields exploded_multi_value + """.stripMargin) + + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = Array( + Row(Row("1_one", 1)), + Row(Row(null, 11)), + Row(Row("1_three", null)), + Row(Row("2_Monday", 2)), + Row(null), + Row(Row("3_third", 3)), + Row(Row("3_4th", 4)), + Row(null)) + // Compare the results + assert(results.toSet == expectedResults.toSet) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // expected plan + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_multi_value_test")) + val generate = Generate( + Explode(UnresolvedAttribute("multi_value")), + seq(), + outer = false, + None, + seq(UnresolvedAttribute("exploded_multi_value")), + table) + val dropSourceColumn = + DataFrameDropColumns(Seq(UnresolvedAttribute("multi_value")), generate) + val expectedPlan = Project(Seq(UnresolvedAttribute("exploded_multi_value")), dropSourceColumn) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("expand for array of structs") { + val frame = sql(s""" + | source = $testTable + | | where country = 'England' or country = 'Poland' + | | expand bridges + | | fields bridges + | """.stripMargin) + + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = Array( + Row(mutable.WrappedArray.make(Array(Row(801, "Tower Bridge"), Row(928, "London Bridge")))), + Row(mutable.WrappedArray.make(Array(Row(801, "Tower Bridge"), Row(928, "London Bridge")))) + // Row(null)) -> in case of outerGenerator = GeneratorOuter(Explode(UnresolvedAttribute("bridges"))) it will include the `null` row + ) + + // Compare the results + assert(results.toSet == expectedResults.toSet) + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("flint_ppl_test")) + val filter = Filter( + Or( + EqualTo(UnresolvedAttribute("country"), Literal("England")), + EqualTo(UnresolvedAttribute("country"), Literal("Poland"))), + table) + val generate = + Generate(Explode(UnresolvedAttribute("bridges")), seq(), outer = false, None, seq(), filter) + val expectedPlan = Project(Seq(UnresolvedAttribute("bridges")), generate) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("expand for array of structs with alias") { + val frame = sql(s""" + | source = $testTable + | | where country = 'England' + | | expand bridges as britishBridges + | | fields britishBridges + | """.stripMargin) + + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = Array( + Row(Row(801, "Tower Bridge")), + Row(Row(928, "London Bridge")), + Row(Row(801, "Tower Bridge")), + Row(Row(928, "London Bridge"))) + // Compare the results + assert(results.toSet == expectedResults.toSet) + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("flint_ppl_test")) + val filter = Filter(EqualTo(UnresolvedAttribute("country"), Literal("England")), table) + val generate = Generate( + Explode(UnresolvedAttribute("bridges")), + seq(), + outer = false, + None, + seq(UnresolvedAttribute("britishBridges")), + filter) + val dropColumn = + DataFrameDropColumns(Seq(UnresolvedAttribute("bridges")), generate) + val expectedPlan = Project(Seq(UnresolvedAttribute("britishBridges")), dropColumn) + + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("expand multi columns array table") { + val frame = sql(s""" + | source = $multiArraysTable + | | expand multi_valueA as multiA + | | expand multi_valueB as multiB + | """.stripMargin) + + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = Array( + Row(1, Row("1_one", 1), Row("2_Monday", 2)), + Row(1, Row("1_one", 1), null), + Row(1, Row(null, 11), Row("2_Monday", 2)), + Row(1, Row(null, 11), null), + Row(1, Row("1_three", null), Row("2_Monday", 2)), + Row(1, Row("1_three", null), null), + Row(2, Row("2_Monday", 2), Row("3_third", 3)), + Row(2, Row("2_Monday", 2), Row("3_4th", 4)), + Row(2, null, Row("3_third", 3)), + Row(2, null, Row("3_4th", 4)), + Row(3, Row("3_third", 3), Row("1_one", 1)), + Row(3, Row("3_4th", 4), Row("1_one", 1))) + // Compare the results + assert(results.toSet == expectedResults.toSet) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_multi_array_test")) + val generatorA = Explode(UnresolvedAttribute("multi_valueA")) + val generateA = + Generate(generatorA, seq(), false, None, seq(UnresolvedAttribute("multiA")), table) + val dropSourceColumnA = + DataFrameDropColumns(Seq(UnresolvedAttribute("multi_valueA")), generateA) + val generatorB = Explode(UnresolvedAttribute("multi_valueB")) + val generateB = Generate( + generatorB, + seq(), + false, + None, + seq(UnresolvedAttribute("multiB")), + dropSourceColumnA) + val dropSourceColumnB = + DataFrameDropColumns(Seq(UnresolvedAttribute("multi_valueB")), generateB) + val expectedPlan = Project(seq(UnresolvedStar(None)), dropSourceColumnB) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + + } +} diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 index 8e81ef521..e205d03d2 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLLexer.g4 @@ -37,6 +37,7 @@ KMEANS: 'KMEANS'; AD: 'AD'; ML: 'ML'; FILLNULL: 'FILLNULL'; +EXPAND: 'EXPAND'; FLATTEN: 'FLATTEN'; TRENDLINE: 'TRENDLINE'; diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java index c7d61b2e2..597e9e8cc 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/AbstractNodeVisitor.java @@ -112,6 +112,10 @@ public T visitFilter(Filter node, C context) { return visitChildren(node, context); } + public T visitExpand(Expand node, C context) { + return visitChildren(node, context); + } + public T visitLookup(Lookup node, C context) { return visitChildren(node, context); } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Expand.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Expand.java new file mode 100644 index 000000000..0e164ccd7 --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Expand.java @@ -0,0 +1,44 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ast.tree; + +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import org.opensearch.sql.ast.AbstractNodeVisitor; +import org.opensearch.sql.ast.Node; +import org.opensearch.sql.ast.expression.Field; +import org.opensearch.sql.ast.expression.UnresolvedAttribute; +import org.opensearch.sql.ast.expression.UnresolvedExpression; + +import java.util.List; +import java.util.Optional; + +/** Logical plan node of Expand */ +@RequiredArgsConstructor +public class Expand extends UnresolvedPlan { + private UnresolvedPlan child; + + @Getter + private final Field field; + @Getter + private final Optional alias; + + @Override + public Expand attach(UnresolvedPlan child) { + this.child = child; + return this; + } + + @Override + public List getChild() { + return child == null ? List.of() : List.of(child); + } + + @Override + public T accept(AbstractNodeVisitor nodeVisitor, C context) { + return nodeVisitor.visitExpand(this, context); + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Flatten.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Flatten.java index e31fbb6e3..9c57d2adf 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Flatten.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/Flatten.java @@ -14,7 +14,7 @@ public class Flatten extends UnresolvedPlan { private UnresolvedPlan child; @Getter - private final Field fieldToBeFlattened; + private final Field field; @Override public UnresolvedPlan attach(UnresolvedPlan child) { @@ -26,7 +26,7 @@ public UnresolvedPlan attach(UnresolvedPlan child) { public List getChild() { return child == null ? List.of() : List.of(child); } - + @Override public T accept(AbstractNodeVisitor nodeVisitor, C context) { return nodeVisitor.visitFlatten(this, context); diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java index 36a34cd06..f6581016f 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java @@ -131,6 +131,12 @@ public UnresolvedPlan visitWhereCommand(OpenSearchPPLParser.WhereCommandContext return new Filter(internalVisitExpression(ctx.logicalExpression())); } + @Override + public UnresolvedPlan visitExpandCommand(OpenSearchPPLParser.ExpandCommandContext ctx) { + return new Expand((Field) internalVisitExpression(ctx.fieldExpression()), + ctx.alias!=null ? Optional.of(internalVisitExpression(ctx.alias)) : Optional.empty()); + } + @Override public UnresolvedPlan visitCorrelateCommand(OpenSearchPPLParser.CorrelateCommandContext ctx) { return new Correlation(ctx.correlationType().getText(), diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanExpandCommandTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanExpandCommandTranslatorTestSuite.scala new file mode 100644 index 000000000..2acaac529 --- /dev/null +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanExpandCommandTranslatorTestSuite.scala @@ -0,0 +1,281 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.opensearch.flint.spark.FlattenGenerator +import org.opensearch.flint.spark.ppl.PlaneUtils.plan +import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor} +import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq +import org.scalatest.matchers.should.Matchers + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Alias, Explode, GeneratorOuter, Literal, RegExpExtract} +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, DataFrameDropColumns, Generate, Project} +import org.apache.spark.sql.types.IntegerType + +class PPLLogicalPlanExpandCommandTranslatorTestSuite + extends SparkFunSuite + with PlanTest + with LogicalPlanTestUtils + with Matchers { + + private val planTransformer = new CatalystQueryPlanVisitor() + private val pplParser = new PPLSyntaxParser() + + test("test expand only field") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit(plan(pplParser, "source=relation | expand field_with_array"), context) + + val relation = UnresolvedRelation(Seq("relation")) + val generator = Explode(UnresolvedAttribute("field_with_array")) + val generate = Generate(generator, seq(), false, None, seq(), relation) + val expectedPlan = Project(seq(UnresolvedStar(None)), generate) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("expand multi columns array table") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit( + plan( + pplParser, + s""" + | source = table + | | expand multi_valueA as multiA + | | expand multi_valueB as multiB + | """.stripMargin), + context) + + val relation = UnresolvedRelation(Seq("table")) + val generatorA = Explode(UnresolvedAttribute("multi_valueA")) + val generateA = + Generate(generatorA, seq(), false, None, seq(UnresolvedAttribute("multiA")), relation) + val dropSourceColumnA = + DataFrameDropColumns(Seq(UnresolvedAttribute("multi_valueA")), generateA) + val generatorB = Explode(UnresolvedAttribute("multi_valueB")) + val generateB = Generate( + generatorB, + seq(), + false, + None, + seq(UnresolvedAttribute("multiB")), + dropSourceColumnA) + val dropSourceColumnB = + DataFrameDropColumns(Seq(UnresolvedAttribute("multi_valueB")), generateB) + val expectedPlan = Project(seq(UnresolvedStar(None)), dropSourceColumnB) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test expand on array field which is eval array=json_array") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + "source = table | eval array=json_array(1, 2, 3) | expand array as uid | fields uid"), + context) + + val relation = UnresolvedRelation(Seq("table")) + val jsonFunc = + UnresolvedFunction("array", Seq(Literal(1), Literal(2), Literal(3)), isDistinct = false) + val aliasA = Alias(jsonFunc, "array")() + val project = Project(seq(UnresolvedStar(None), aliasA), relation) + val generate = Generate( + Explode(UnresolvedAttribute("array")), + seq(), + false, + None, + seq(UnresolvedAttribute("uid")), + project) + val dropSourceColumn = + DataFrameDropColumns(Seq(UnresolvedAttribute("array")), generate) + val expectedPlan = Project(seq(UnresolvedAttribute("uid")), dropSourceColumn) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test expand only field with alias") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan(pplParser, "source=relation | expand field_with_array as array_list "), + context) + + val relation = UnresolvedRelation(Seq("relation")) + val generate = Generate( + Explode(UnresolvedAttribute("field_with_array")), + seq(), + false, + None, + seq(UnresolvedAttribute("array_list")), + relation) + val dropSourceColumn = + DataFrameDropColumns(Seq(UnresolvedAttribute("field_with_array")), generate) + val expectedPlan = Project(seq(UnresolvedStar(None)), dropSourceColumn) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test expand and stats") { + val context = new CatalystPlanContext + val query = + "source = table | expand employee | stats max(salary) as max by state, company" + val logPlan = + planTransformer.visit(plan(pplParser, query), context) + val table = UnresolvedRelation(Seq("table")) + val generate = + Generate(Explode(UnresolvedAttribute("employee")), seq(), false, None, seq(), table) + val average = Alias( + UnresolvedFunction(seq("MAX"), seq(UnresolvedAttribute("salary")), false, None, false), + "max")() + val state = Alias(UnresolvedAttribute("state"), "state")() + val company = Alias(UnresolvedAttribute("company"), "company")() + val groupingState = Alias(UnresolvedAttribute("state"), "state")() + val groupingCompany = Alias(UnresolvedAttribute("company"), "company")() + val aggregate = + Aggregate(Seq(groupingState, groupingCompany), Seq(average, state, company), generate) + val expectedPlan = Project(Seq(UnresolvedStar(None)), aggregate) + + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test expand and stats with alias") { + val context = new CatalystPlanContext + val query = + "source = table | expand employee as workers | stats max(salary) as max by state, company" + val logPlan = + planTransformer.visit(plan(pplParser, query), context) + val table = UnresolvedRelation(Seq("table")) + val generate = Generate( + Explode(UnresolvedAttribute("employee")), + seq(), + false, + None, + seq(UnresolvedAttribute("workers")), + table) + val dropSourceColumn = DataFrameDropColumns(Seq(UnresolvedAttribute("employee")), generate) + val average = Alias( + UnresolvedFunction(seq("MAX"), seq(UnresolvedAttribute("salary")), false, None, false), + "max")() + val state = Alias(UnresolvedAttribute("state"), "state")() + val company = Alias(UnresolvedAttribute("company"), "company")() + val groupingState = Alias(UnresolvedAttribute("state"), "state")() + val groupingCompany = Alias(UnresolvedAttribute("company"), "company")() + val aggregate = Aggregate( + Seq(groupingState, groupingCompany), + Seq(average, state, company), + dropSourceColumn) + val expectedPlan = Project(Seq(UnresolvedStar(None)), aggregate) + + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test expand and eval") { + val context = new CatalystPlanContext + val query = "source = table | expand employee | eval bonus = salary * 3" + val logPlan = planTransformer.visit(plan(pplParser, query), context) + val table = UnresolvedRelation(Seq("table")) + val generate = + Generate(Explode(UnresolvedAttribute("employee")), seq(), false, None, seq(), table) + val bonusProject = Project( + Seq( + UnresolvedStar(None), + Alias( + UnresolvedFunction( + "*", + Seq(UnresolvedAttribute("salary"), Literal(3, IntegerType)), + isDistinct = false), + "bonus")()), + generate) + val expectedPlan = Project(Seq(UnresolvedStar(None)), bonusProject) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test expand and eval with fields and alias") { + val context = new CatalystPlanContext + val query = + "source = table | expand employee as worker | eval bonus = salary * 3 | fields worker, bonus " + val logPlan = planTransformer.visit(plan(pplParser, query), context) + val table = UnresolvedRelation(Seq("table")) + val generate = Generate( + Explode(UnresolvedAttribute("employee")), + seq(), + false, + None, + seq(UnresolvedAttribute("worker")), + table) + val dropSourceColumn = + DataFrameDropColumns(Seq(UnresolvedAttribute("employee")), generate) + val bonusProject = Project( + Seq( + UnresolvedStar(None), + Alias( + UnresolvedFunction( + "*", + Seq(UnresolvedAttribute("salary"), Literal(3, IntegerType)), + isDistinct = false), + "bonus")()), + dropSourceColumn) + val expectedPlan = + Project(Seq(UnresolvedAttribute("worker"), UnresolvedAttribute("bonus")), bonusProject) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test expand and parse and fields") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + "source=table | expand employee | parse description '(?.+@.+)' | fields employee, email"), + context) + val table = UnresolvedRelation(Seq("table")) + val generator = + Generate(Explode(UnresolvedAttribute("employee")), seq(), false, None, seq(), table) + val emailAlias = + Alias( + RegExpExtract(UnresolvedAttribute("description"), Literal("(?.+@.+)"), Literal(1)), + "email")() + val parseProject = Project( + Seq(UnresolvedAttribute("description"), emailAlias, UnresolvedStar(None)), + generator) + val expectedPlan = + Project(Seq(UnresolvedAttribute("employee"), UnresolvedAttribute("email")), parseProject) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test expand and parse and flatten ") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + "source=relation | expand employee | parse description '(?.+@.+)' | flatten roles "), + context) + val table = UnresolvedRelation(Seq("relation")) + val generateEmployee = + Generate(Explode(UnresolvedAttribute("employee")), seq(), false, None, seq(), table) + val emailAlias = + Alias( + RegExpExtract(UnresolvedAttribute("description"), Literal("(?.+@.+)"), Literal(1)), + "email")() + val parseProject = Project( + Seq(UnresolvedAttribute("description"), emailAlias, UnresolvedStar(None)), + generateEmployee) + val generateRoles = Generate( + GeneratorOuter(new FlattenGenerator(UnresolvedAttribute("roles"))), + seq(), + true, + None, + seq(), + parseProject) + val dropSourceColumnRoles = + DataFrameDropColumns(Seq(UnresolvedAttribute("roles")), generateRoles) + val expectedPlan = Project(Seq(UnresolvedStar(None)), dropSourceColumnRoles) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + +}