From a91a9867afffe8aaf65ecf4b89208d00ab21e652 Mon Sep 17 00:00:00 2001 From: YANGDB Date: Tue, 5 Nov 2024 14:51:50 -0800 Subject: [PATCH] update with additional test case remove outer generator Signed-off-by: YANGDB --- docs/ppl-lang/PPL-Example-Commands.md | 9 +- docs/ppl-lang/ppl-expand-command.md | 9 -- .../ppl/FlintSparkPPLExpandITSuite.scala | 135 ++++++++++-------- .../sql/ppl/CatalystQueryPlanVisitor.java | 6 +- ...PlanExpandCommandTranslatorTestSuite.scala | 81 ++++++----- 5 files changed, 126 insertions(+), 114 deletions(-) diff --git a/docs/ppl-lang/PPL-Example-Commands.md b/docs/ppl-lang/PPL-Example-Commands.md index 11709b32b..5d4f68cb6 100644 --- a/docs/ppl-lang/PPL-Example-Commands.md +++ b/docs/ppl-lang/PPL-Example-Commands.md @@ -451,11 +451,12 @@ _- **Limitation: another command usage of (relation) subquery is in `appendcols` #### **expand** [See additional command details](ppl-expand-command.md) ```sql - - `source= table | expand field_with_array as array_list` + - `source = table | expand field_with_array as array_list` - `source = table | expand employee | stats max(salary) as max by state, company` - - `source = table | expand employee as worker | stats max(salary) as max by state, company` - - `source = table | expand employee as worker | eval bonus = salary * 3 | fields worker, bonus` - - `source = table | expand employee | parse description '(?.+@.+)' | fields employee, email` + - `source = table | expand employee as worker | stats max(salary) as max by state, company` + - `source = table | expand employee as worker | eval bonus = salary * 3 | fields worker, bonus` + - `source = table | expand employee | parse description '(?.+@.+)' | fields employee, email` + - `source = table | eval array=json_array(1, 2, 3) | expand array as uid | fields name, occupation, uid` ``` #### Correlation Commands: diff --git a/docs/ppl-lang/ppl-expand-command.md b/docs/ppl-lang/ppl-expand-command.md index 0a52b4eb9..1e9fc319f 100644 --- a/docs/ppl-lang/ppl-expand-command.md +++ b/docs/ppl-lang/ppl-expand-command.md @@ -39,15 +39,6 @@ This example shows how to expand an array of struct field. PPL query: - `source=table | expand bridges as britishBridge | fields britishBridge` -| \_time | bridges | city | country | alt | lat | long | -|---------------------|----------------------------------------------|---------|---------------|-----|--------|--------| -| 2024-09-13T12:00:00 | [{801, Tower Bridge}, {928, London Bridge}] | London | England | 35 | 51.5074| -0.1278| -| 2024-09-13T12:00:00 | [{232, Pont Neuf}, {160, Pont Alexandre III}]| Paris | France | 35 | 48.8566| 2.3522 | -| 2024-09-13T12:00:00 | [{48, Rialto Bridge}, {11, Bridge of Sighs}] | Venice | Italy | 2 | 45.4408| 12.3155| -| 2024-09-13T12:00:00 | [{516, Charles Bridge}, {343, Legion Bridge}]| Prague | Czech Republic| 200 | 50.0755| 14.4378| -| 2024-09-13T12:00:00 | [{375, Chain Bridge}, {333, Liberty Bridge}] | Budapest| Hungary | 96 | 47.4979| 19.0402| -| 1990-09-13T12:00:00 | NULL | Warsaw | Poland | NULL| NULL | NULL | - ### Example 2: expand array diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLExpandITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLExpandITSuite.scala index 11be46756..a2b780c59 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLExpandITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLExpandITSuite.scala @@ -23,6 +23,7 @@ class FlintSparkPPLExpandITSuite with StreamTest { private val testTable = "flint_ppl_test" + private val occupationTable = "spark_catalog.default.flint_ppl_flat_table_test" private val structNestedTable = "spark_catalog.default.flint_ppl_struct_nested_test" private val structTable = "spark_catalog.default.flint_ppl_struct_test" private val multiValueTable = "spark_catalog.default.flint_ppl_multi_value_test" @@ -33,6 +34,7 @@ class FlintSparkPPLExpandITSuite // Create test table createNestedJsonContentTable(tempFile, testTable) + createOccupationTable(occupationTable) createStructNestedTable(structNestedTable) createStructTable(structTable) createMultiValueStructTable(multiValueTable) @@ -52,7 +54,61 @@ class FlintSparkPPLExpandITSuite Files.deleteIfExists(tempFile) } - test("flatten for structs") { + test("expand for eval field of an array") { + val frame = sql( + s""" source = $occupationTable | eval array=json_array(1, 2, 3) | expand array as uid | fields name, occupation, uid + """.stripMargin) + + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = Array( + Row("Jake", "Engineer", 1), + Row("Jake", "Engineer", 2), + Row("Jake", "Engineer", 3), + Row("Hello", "Artist", 1), + Row("Hello", "Artist", 2), + Row("Hello", "Artist", 3), + Row("John", "Doctor", 1), + Row("John", "Doctor", 2), + Row("John", "Doctor", 3), + Row("David", "Doctor", 1), + Row("David", "Doctor", 2), + Row("David", "Doctor", 3), + Row("David", "Unemployed", 1), + Row("David", "Unemployed", 2), + Row("David", "Unemployed", 3), + Row("Jane", "Scientist", 1), + Row("Jane", "Scientist", 2), + Row("Jane", "Scientist", 3)) + + // Compare the results + assert(results.toSet == expectedResults.toSet) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // expected plan + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_flat_table_test")) + val jsonFunc = + UnresolvedFunction("array", Seq(Literal(1), Literal(2), Literal(3)), isDistinct = false) + val aliasA = Alias(jsonFunc, "array")() + val project = Project(seq(UnresolvedStar(None), aliasA), table) + val generate = Generate( + Explode(UnresolvedAttribute("array")), + seq(), + false, + None, + seq(UnresolvedAttribute("uid")), + project) + val dropSourceColumn = + DataFrameDropColumns(Seq(UnresolvedAttribute("array")), generate) + val expectedPlan = Project( + seq( + UnresolvedAttribute("name"), + UnresolvedAttribute("occupation"), + UnresolvedAttribute("uid")), + dropSourceColumn) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + + test("expand for structs") { val frame = sql( s""" source = $multiValueTable | expand multi_value AS exploded_multi_value | fields exploded_multi_value """.stripMargin) @@ -73,11 +129,10 @@ class FlintSparkPPLExpandITSuite val logicalPlan: LogicalPlan = frame.queryExecution.logical // expected plan val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_multi_value_test")) - val outerGenerator = GeneratorOuter(Explode(UnresolvedAttribute("multi_value"))) val generate = Generate( - outerGenerator, + Explode(UnresolvedAttribute("multi_value")), seq(), - outer = true, + outer = false, None, seq(UnresolvedAttribute("exploded_multi_value")), table) @@ -98,8 +153,10 @@ class FlintSparkPPLExpandITSuite val results: Array[Row] = frame.collect() val expectedResults: Array[Row] = Array( Row(mutable.WrappedArray.make(Array(Row(801, "Tower Bridge"), Row(928, "London Bridge")))), - Row(mutable.WrappedArray.make(Array(Row(801, "Tower Bridge"), Row(928, "London Bridge")))), - Row(null)) + Row(mutable.WrappedArray.make(Array(Row(801, "Tower Bridge"), Row(928, "London Bridge")))) + // Row(null)) -> in case of outerGenerator = GeneratorOuter(Explode(UnresolvedAttribute("bridges"))) it will include the `null` row + ) + // Compare the results assert(results.toSet == expectedResults.toSet) val logicalPlan: LogicalPlan = frame.queryExecution.logical @@ -109,8 +166,8 @@ class FlintSparkPPLExpandITSuite EqualTo(UnresolvedAttribute("country"), Literal("England")), EqualTo(UnresolvedAttribute("country"), Literal("Poland"))), table) - val outerGenerator = GeneratorOuter(Explode(UnresolvedAttribute("bridges"))) - val generate = Generate(outerGenerator, seq(), outer = true, None, seq(), filter) + val generate = + Generate(Explode(UnresolvedAttribute("bridges")), seq(), outer = false, None, seq(), filter) val expectedPlan = Project(Seq(UnresolvedAttribute("bridges")), generate) comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) } @@ -134,11 +191,10 @@ class FlintSparkPPLExpandITSuite val logicalPlan: LogicalPlan = frame.queryExecution.logical val table = UnresolvedRelation(Seq("flint_ppl_test")) val filter = Filter(EqualTo(UnresolvedAttribute("country"), Literal("England")), table) - val outerGenerator = GeneratorOuter(Explode(UnresolvedAttribute("bridges"))) val generate = Generate( - outerGenerator, + Explode(UnresolvedAttribute("bridges")), seq(), - outer = true, + outer = false, None, seq(UnresolvedAttribute("britishBridges")), filter) @@ -166,48 +222,10 @@ class FlintSparkPPLExpandITSuite val logicalPlan: LogicalPlan = frame.queryExecution.logical val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_struct_test")) - val outerGenerator = GeneratorOuter(Explode(UnresolvedAttribute("bridges"))) - val generate = Generate( - outerGenerator, - seq(), - outer = true, - None, - seq(UnresolvedAttribute("britishBridges")), - table) - val expectedPlan = Project(Seq(UnresolvedStar(None)), generate) - comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) - } - - ignore("expand multi value nullable") { - val frame = sql(s""" - | source = $multiValueTable - | | expand multi_value as expand_field - | | fields expand_field - | """.stripMargin) - - assert(frame.columns.sameElements(Array("expand_field"))) - val results: Array[Row] = frame.collect() - val expectedResults: Array[Row] = - Array( - Row(1, "1_one", 1), - Row(1, null, 11), - Row(1, "1_three", null), - Row(2, "2_Monday", 2), - Row(2, null, null), - Row(3, "3_third", 3), - Row(3, "3_4th", 4), - Row(4, null, null)) - // Compare the results - implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Int](_.getAs[Int](0)) - assert(results.sorted.sameElements(expectedResults.sorted)) - - val logicalPlan: LogicalPlan = frame.queryExecution.logical - val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_multi_value_test")) - val outerGenerator = GeneratorOuter(Explode(UnresolvedAttribute("bridges"))) val generate = Generate( - outerGenerator, + Explode(UnresolvedAttribute("bridges")), seq(), - outer = true, + outer = false, None, seq(UnresolvedAttribute("britishBridges")), table) @@ -241,15 +259,15 @@ class FlintSparkPPLExpandITSuite val logicalPlan: LogicalPlan = frame.queryExecution.logical val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_struct_nested_test")) -// val flattenStructCol = generator("struct_col", table) -// val flattenField1 = generator("field1", flattenStructCol) -// val flattenStructCol2 = generator("struct_col2", flattenField1) -// val flattenField1Again = generator("field1", flattenStructCol2) +// val expandStructCol = generator("struct_col", table) +// val expandField1 = generator("field1", expandStructCol) +// val expandStructCol2 = generator("struct_col2", expandField1) +// val expandField1Again = generator("field1", expandStructCol2) val expectedPlan = Project(Seq(UnresolvedStar(None)), table) comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) } - ignore("flatten multi value nullable") { + ignore("expand multi value nullable") { val frame = sql(s""" | source = $multiValueTable | | expand multi_value as expand_field @@ -274,11 +292,10 @@ class FlintSparkPPLExpandITSuite val logicalPlan: LogicalPlan = frame.queryExecution.logical val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_multi_value_test")) - val outerGenerator = GeneratorOuter(Explode(UnresolvedAttribute("bridges"))) val generate = Generate( - outerGenerator, + Explode(UnresolvedAttribute("bridges")), seq(), - outer = true, + outer = false, None, seq(UnresolvedAttribute("britishBridges")), table) diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java index a5c68574f..e4df7b16d 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java @@ -479,12 +479,12 @@ public LogicalPlan visitExpand(org.opensearch.sql.ast.tree.Expand node, Catalyst Optional alias = node.getAlias().map(aliasNode -> visitExpression(aliasNode, context)); context.retainAllNamedParseExpressions(p -> (NamedExpression) p); Explode explodeGenerator = new Explode(field); - scala.collection.mutable.Seq seq = alias.isEmpty() ? seq() : seq(alias.get()); + scala.collection.mutable.Seq outputs = alias.isEmpty() ? seq() : seq(alias.get()); if(alias.isEmpty()) - return context.apply(p -> new Generate(new GeneratorOuter(explodeGenerator), seq(), true, (Option) None$.MODULE$, seq, p)); + return context.apply(p -> new Generate(explodeGenerator, seq(), false, (Option) None$.MODULE$, outputs, p)); else { //in case an alias does appear - remove the original field from the returning columns - context.apply(p -> new Generate(new GeneratorOuter(explodeGenerator), seq(), true, (Option) None$.MODULE$, seq, p)); + context.apply(p -> new Generate(explodeGenerator, seq(), false, (Option) None$.MODULE$, outputs, p)); return context.apply(logicalPlan -> DataFrameDropColumns$.MODULE$.apply(seq(field), logicalPlan)); } } diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanExpandCommandTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanExpandCommandTranslatorTestSuite.scala index 92d0d521e..5f15c93e4 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanExpandCommandTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanExpandCommandTranslatorTestSuite.scala @@ -34,12 +34,38 @@ class PPLLogicalPlanExpandCommandTranslatorTestSuite val relation = UnresolvedRelation(Seq("relation")) val generator = Explode(UnresolvedAttribute("field_with_array")) - val outerGenerator = GeneratorOuter(generator) - val generate = Generate(outerGenerator, seq(), true, None, seq(), relation) + val generate = Generate(generator, seq(), false, None, seq(), relation) val expectedPlan = Project(seq(UnresolvedStar(None)), generate) comparePlans(expectedPlan, logPlan, checkAnalysis = false) } + test("test expand on array field which is eval array=json_array") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + "source = table | eval array=json_array(1, 2, 3) | expand array as uid | fields uid"), + context) + + val relation = UnresolvedRelation(Seq("table")) + val jsonFunc = + UnresolvedFunction("array", Seq(Literal(1), Literal(2), Literal(3)), isDistinct = false) + val aliasA = Alias(jsonFunc, "array")() + val project = Project(seq(UnresolvedStar(None), aliasA), relation) + val generate = Generate( + Explode(UnresolvedAttribute("array")), + seq(), + false, + None, + seq(UnresolvedAttribute("uid")), + project) + val dropSourceColumn = + DataFrameDropColumns(Seq(UnresolvedAttribute("array")), generate) + val expectedPlan = Project(seq(UnresolvedAttribute("uid")), dropSourceColumn) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + test("test expand only field with alias") { val context = new CatalystPlanContext val logPlan = @@ -48,12 +74,10 @@ class PPLLogicalPlanExpandCommandTranslatorTestSuite context) val relation = UnresolvedRelation(Seq("relation")) - val generator = Explode(UnresolvedAttribute("field_with_array")) - val outerGenerator = GeneratorOuter(generator) val generate = Generate( - outerGenerator, + Explode(UnresolvedAttribute("field_with_array")), seq(), - true, + false, None, seq(UnresolvedAttribute("array_list")), relation) @@ -70,13 +94,8 @@ class PPLLogicalPlanExpandCommandTranslatorTestSuite val logPlan = planTransformer.visit(plan(pplParser, query), context) val table = UnresolvedRelation(Seq("table")) - val generate = Generate( - GeneratorOuter(Explode(UnresolvedAttribute("employee"))), - seq(), - true, - None, - seq(), - table) + val generate = + Generate(Explode(UnresolvedAttribute("employee")), seq(), false, None, seq(), table) val average = Alias( UnresolvedFunction(seq("MAX"), seq(UnresolvedAttribute("salary")), false, None, false), "max")() @@ -99,14 +118,13 @@ class PPLLogicalPlanExpandCommandTranslatorTestSuite planTransformer.visit(plan(pplParser, query), context) val table = UnresolvedRelation(Seq("table")) val generate = Generate( - GeneratorOuter(Explode(UnresolvedAttribute("employee"))), + Explode(UnresolvedAttribute("employee")), seq(), - true, + false, None, seq(UnresolvedAttribute("workers")), table) val dropSourceColumn = DataFrameDropColumns(Seq(UnresolvedAttribute("employee")), generate) - val dropColumn = Project(seq(UnresolvedStar(None)), dropSourceColumn) val average = Alias( UnresolvedFunction(seq("MAX"), seq(UnresolvedAttribute("salary")), false, None, false), "max")() @@ -128,13 +146,8 @@ class PPLLogicalPlanExpandCommandTranslatorTestSuite val query = "source = table | expand employee | eval bonus = salary * 3" val logPlan = planTransformer.visit(plan(pplParser, query), context) val table = UnresolvedRelation(Seq("table")) - val generate = Generate( - GeneratorOuter(Explode(UnresolvedAttribute("employee"))), - seq(), - true, - None, - seq(), - table) + val generate = + Generate(Explode(UnresolvedAttribute("employee")), seq(), false, None, seq(), table) val bonusProject = Project( Seq( UnresolvedStar(None), @@ -156,9 +169,9 @@ class PPLLogicalPlanExpandCommandTranslatorTestSuite val logPlan = planTransformer.visit(plan(pplParser, query), context) val table = UnresolvedRelation(Seq("table")) val generate = Generate( - GeneratorOuter(Explode(UnresolvedAttribute("employee"))), + Explode(UnresolvedAttribute("employee")), seq(), - true, + false, None, seq(UnresolvedAttribute("worker")), table) @@ -188,13 +201,8 @@ class PPLLogicalPlanExpandCommandTranslatorTestSuite "source=table | expand employee | parse description '(?.+@.+)' | fields employee, email"), context) val table = UnresolvedRelation(Seq("table")) - val generator = Generate( - GeneratorOuter(Explode(UnresolvedAttribute("employee"))), - seq(), - true, - None, - seq(), - table) + val generator = + Generate(Explode(UnresolvedAttribute("employee")), seq(), false, None, seq(), table) val emailAlias = Alias( RegExpExtract(UnresolvedAttribute("description"), Literal("(?.+@.+)"), Literal(1)), @@ -216,13 +224,8 @@ class PPLLogicalPlanExpandCommandTranslatorTestSuite "source=relation | expand employee | parse description '(?.+@.+)' | flatten roles "), context) val table = UnresolvedRelation(Seq("relation")) - val generateEmployee = Generate( - GeneratorOuter(Explode(UnresolvedAttribute("employee"))), - seq(), - true, - None, - seq(), - table) + val generateEmployee = + Generate(Explode(UnresolvedAttribute("employee")), seq(), false, None, seq(), table) val emailAlias = Alias( RegExpExtract(UnresolvedAttribute("description"), Literal("(?.+@.+)"), Literal(1)),