diff --git a/docs/ppl-lang/PPL-Example-Commands.md b/docs/ppl-lang/PPL-Example-Commands.md index 5d4f68cb6..4a70ff610 100644 --- a/docs/ppl-lang/PPL-Example-Commands.md +++ b/docs/ppl-lang/PPL-Example-Commands.md @@ -457,6 +457,7 @@ _- **Limitation: another command usage of (relation) subquery is in `appendcols` - `source = table | expand employee as worker | eval bonus = salary * 3 | fields worker, bonus` - `source = table | expand employee | parse description '(?<email>.+@.+)' | fields employee, email` - `source = table | eval array=json_array(1, 2, 3) | expand array as uid | fields name, occupation, uid` + - `source = table | expand multi_valueA as multiA | expand multi_valueB as multiB` ``` #### Correlation Commands: diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala index c53eee548..68d370791 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala @@ -559,6 +559,28 @@ trait FlintSparkSuite extends QueryTest with FlintSuite with OpenSearchSuite wit |""".stripMargin) } + protected def createMultiColumnArrayTable(testTable: String): Unit = { + // CSV doesn't support struct field + sql(s""" + | CREATE TABLE $testTable + | ( + | int_col INT, + | multi_valueA Array>, + | multi_valueB Array> + | ) + | USING JSON + |""".stripMargin) + + sql(s""" + | INSERT INTO $testTable + | VALUES + | ( 1, array(STRUCT("1_one", 1), STRUCT(null, 11), STRUCT("1_three", null)), array(STRUCT("2_Monday", 2), null) ), + | ( 2, array(STRUCT("2_Monday", 2), null) , array(STRUCT("3_third", 3), STRUCT("3_4th", 4)) ), + | ( 3, array(STRUCT("3_third", 3), STRUCT("3_4th", 4)) , array(STRUCT("1_one", 1))), + | ( 4, null, array(STRUCT("1_one", 1))) + |""".stripMargin) + } + protected def createTableIssue112(testTable: String): Unit = { sql(s""" | CREATE 
TABLE $testTable ( diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLExpandITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLExpandITSuite.scala index a2b780c59..f0404bf7b 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLExpandITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLExpandITSuite.scala @@ -27,6 +27,7 @@ class FlintSparkPPLExpandITSuite private val structNestedTable = "spark_catalog.default.flint_ppl_struct_nested_test" private val structTable = "spark_catalog.default.flint_ppl_struct_test" private val multiValueTable = "spark_catalog.default.flint_ppl_multi_value_test" + private val multiArraysTable = "spark_catalog.default.flint_ppl_multi_array_test" private val tempFile = Files.createTempFile("jsonTestData", ".json") override def beforeAll(): Unit = { @@ -38,6 +39,7 @@ class FlintSparkPPLExpandITSuite createStructNestedTable(structNestedTable) createStructTable(structTable) createMultiValueStructTable(multiValueTable) + createMultiColumnArrayTable(multiArraysTable) } protected override def afterEach(): Unit = { @@ -205,101 +207,49 @@ class FlintSparkPPLExpandITSuite comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) } - ignore("expand struct table") { + test("expand multi columns array table") { val frame = sql(s""" - | source = $structTable - | | expand struct_col - | | expand field1 - | """.stripMargin) - - assert(frame.columns.sameElements(Array("int_col", "field2", "subfield"))) - val results: Array[Row] = frame.collect() - val expectedResults: Array[Row] = - Array(Row(30, 123, "value1"), Row(40, 456, "value2"), Row(50, 789, "value3")) - // Compare the results - implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Int](_.getAs[Int](0)) - assert(results.sorted.sameElements(expectedResults.sorted)) - - val logicalPlan: LogicalPlan = frame.queryExecution.logical - 
val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_struct_test")) - val generate = Generate( - Explode(UnresolvedAttribute("bridges")), - seq(), - outer = false, - None, - seq(UnresolvedAttribute("britishBridges")), - table) - val expectedPlan = Project(Seq(UnresolvedStar(None)), generate) - comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) - } - - ignore("expand struct nested table") { - val frame = sql(s""" - | source = $structNestedTable - | | expand struct_col - | | expand field1 - | | expand struct_col2 - | | expand field1 - | """.stripMargin) - - assert( - frame.columns.sameElements(Array("int_col", "field2", "subfield", "field2", "subfield"))) - val results: Array[Row] = frame.collect() - val expectedResults: Array[Row] = - Array( - Row(30, 123, "value1", 23, "valueA"), - Row(40, 123, "value5", 33, "valueB"), - Row(30, 823, "value4", 83, "valueC"), - Row(40, 456, "value2", 46, "valueD"), - Row(50, 789, "value3", 89, "valueE")) - // Compare the results - implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Int](_.getAs[Int](0)) - assert(results.sorted.sameElements(expectedResults.sorted)) - - val logicalPlan: LogicalPlan = frame.queryExecution.logical - val table = - UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_struct_nested_test")) -// val expandStructCol = generator("struct_col", table) -// val expandField1 = generator("field1", expandStructCol) -// val expandStructCol2 = generator("struct_col2", expandField1) -// val expandField1Again = generator("field1", expandStructCol2) - val expectedPlan = Project(Seq(UnresolvedStar(None)), table) - comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) - } - - ignore("expand multi value nullable") { - val frame = sql(s""" - | source = $multiValueTable - | | expand multi_value as expand_field - | | fields expand_field + | source = $multiArraysTable + | | expand multi_valueA as multiA + | | expand multi_valueB as multiB | """.stripMargin) - 
assert(frame.columns.sameElements(Array("expand_field"))) val results: Array[Row] = frame.collect() - val expectedResults: Array[Row] = - Array( - Row(1, "1_one", 1), - Row(1, null, 11), - Row(1, "1_three", null), - Row(2, "2_Monday", 2), - Row(2, null, null), - Row(3, "3_third", 3), - Row(3, "3_4th", 4), - Row(4, null, null)) + val expectedResults: Array[Row] = Array( + Row(1, Row("1_one", 1), Row("2_Monday", 2)), + Row(1, Row("1_one", 1), null), + Row(1, Row(null, 11), Row("2_Monday", 2)), + Row(1, Row(null, 11), null), + Row(1, Row("1_three", null), Row("2_Monday", 2)), + Row(1, Row("1_three", null), null), + Row(2, Row("2_Monday", 2), Row("3_third", 3)), + Row(2, Row("2_Monday", 2), Row("3_4th", 4)), + Row(2, null, Row("3_third", 3)), + Row(2, null, Row("3_4th", 4)), + Row(3, Row("3_third", 3), Row("1_one", 1)), + Row(3, Row("3_4th", 4), Row("1_one", 1))) // Compare the results - implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Int](_.getAs[Int](0)) - assert(results.sorted.sameElements(expectedResults.sorted)) + assert(results.toSet == expectedResults.toSet) val logicalPlan: LogicalPlan = frame.queryExecution.logical - val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_multi_value_test")) - val generate = Generate( - Explode(UnresolvedAttribute("bridges")), + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_multi_array_test")) + val generatorA = Explode(UnresolvedAttribute("multi_valueA")) + val generateA = + Generate(generatorA, seq(), false, None, seq(UnresolvedAttribute("multiA")), table) + val dropSourceColumnA = + DataFrameDropColumns(Seq(UnresolvedAttribute("multi_valueA")), generateA) + val generatorB = Explode(UnresolvedAttribute("multi_valueB")) + val generateB = Generate( + generatorB, seq(), - outer = false, + false, None, - seq(UnresolvedAttribute("britishBridges")), - table) - val expectedPlan = Project(Seq(UnresolvedStar(None)), generate) - comparePlans(logicalPlan, expectedPlan, 
checkAnalysis = false) + seq(UnresolvedAttribute("multiB")), + dropSourceColumnA) + val dropSourceColumnB = + DataFrameDropColumns(Seq(UnresolvedAttribute("multi_valueB")), generateB) + val expectedPlan = Project(seq(UnresolvedStar(None)), dropSourceColumnB) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } } diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanExpandCommandTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanExpandCommandTranslatorTestSuite.scala index 5f15c93e4..2acaac529 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanExpandCommandTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanExpandCommandTranslatorTestSuite.scala @@ -39,6 +39,38 @@ class PPLLogicalPlanExpandCommandTranslatorTestSuite comparePlans(expectedPlan, logPlan, checkAnalysis = false) } + test("expand multi columns array table") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit( + plan( + pplParser, + s""" + | source = table + | | expand multi_valueA as multiA + | | expand multi_valueB as multiB + | """.stripMargin), + context) + + val relation = UnresolvedRelation(Seq("table")) + val generatorA = Explode(UnresolvedAttribute("multi_valueA")) + val generateA = + Generate(generatorA, seq(), false, None, seq(UnresolvedAttribute("multiA")), relation) + val dropSourceColumnA = + DataFrameDropColumns(Seq(UnresolvedAttribute("multi_valueA")), generateA) + val generatorB = Explode(UnresolvedAttribute("multi_valueB")) + val generateB = Generate( + generatorB, + seq(), + false, + None, + seq(UnresolvedAttribute("multiB")), + dropSourceColumnA) + val dropSourceColumnB = + DataFrameDropColumns(Seq(UnresolvedAttribute("multi_valueB")), generateB) + val expectedPlan = Project(seq(UnresolvedStar(None)), dropSourceColumnB) + 
comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + test("test expand on array field which is eval array=json_array") { val context = new CatalystPlanContext val logPlan =