diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLProjectStatementITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLProjectStatementITSuite.scala index 85fa5cf40..54d4aff1b 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLProjectStatementITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLProjectStatementITSuite.scala @@ -10,7 +10,8 @@ import java.nio.file.{Files, Paths} import org.apache.spark.sql.{AnalysisException, QueryTest, Row} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedIdentifier, UnresolvedRelation, UnresolvedStar} -import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Descending, EqualTo, IsNotNull, Literal, Not, SortOrder} +import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Descending, Divide, EqualTo, Floor, GreaterThan, IsNotNull, Literal, Multiply, Not, Or, SortOrder} +import org.apache.spark.sql.catalyst.plans.Inner import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.connector.expressions.{FieldReference, IdentityTransform} import org.apache.spark.sql.execution.ExplainMode @@ -23,12 +24,15 @@ class FlintSparkPPLProjectStatementITSuite with FlintPPLSuite with StreamTest { - /** Test table and index name */ + /** Test table */ private val testTable = "spark_catalog.default.flint_ppl_test" - private val t1 = "`spark_catalog`.`default`.`flint_ppl_test1`" - private val t2 = "`spark_catalog`.default.`flint_ppl_test2`" - private val t3 = "spark_catalog.`default`.`flint_ppl_test3`" - private val t4 = "`spark_catalog`.`default`.flint_ppl_test4" + private val testTable1 = "spark_catalog.default.flint_ppl_test1" + private val testTable2 = "spark_catalog.default.flint_ppl_test2" + + private val t1 = "`spark_catalog`.`default`.`flint_ppl_t1`" + private val t2 = "`spark_catalog`.default.`flint_ppl_t2`" + + /* view projection */ private val viewName = "simpleView" // location of the projected view private val viewFolderLocation = Paths.get(".", "spark-warehouse", "student_partition_bucket") @@ -37,11 +41,22 @@ class FlintSparkPPLProjectStatementITSuite super.beforeAll() // Create test table + createPartitionedStateCountryTable(testTable1) + // Update data insertion + sql(s""" + | INSERT INTO $testTable1 + | PARTITION (year=2023, month=4) + | VALUES ('Jim', 27, 'B.C', 'Canada'), + | ('Peter', 57, 'B.C', 'Canada'), + | ('Rick', 70, 'B.C', 'Canada'), + | ('David', 40, 'Washington', 'USA') + | """.stripMargin) + + createOccupationTable(testTable2) + // none join tables createPartitionedStateCountryTable(testTable) createPartitionedStateCountryTable(t1) createPartitionedStateCountryTable(t2) - createPartitionedStateCountryTable(t3) - createPartitionedStateCountryTable(t4) } protected override def afterEach(): Unit = { @@ -542,4 +557,105 @@ class FlintSparkPPLProjectStatementITSuite assert(compareByString(logicalPlan) == expectedPlan.toString) } + test("test inner join with relation subquery") { + val viewLocation = viewFolderLocation.toAbsolutePath.toString + val frame = sql(s""" + | project $viewName using parquet OPTIONS('parquet.bloom.filter.enabled'='true') + | partitioned by (age_span) location '$viewLocation' + | | source = $testTable1 + | | where country = 'USA' OR country = 'England' + | | inner join left=a, right=b + | ON a.name = b.name + | [ + | source = $testTable2 + | | where salary > 0 + | | fields name, country, salary + | | sort salary + | | head 3 + | ] + | | stats avg(salary) by span(age, 10) as age_span, b.country + | """.stripMargin) + // verify new view was created correctly + frame.collect() + val results = sql(s""" + | source = $viewName + | """.stripMargin).collect() + + val expectedResults: Array[Row] = Array(Row(70000.0, "USA", 30), Row(100000.0, "England", 70)) + + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Double](_.getAs[Double](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + val table1 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val table2 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val filterExpr = Or( + EqualTo(UnresolvedAttribute("country"), Literal("USA")), + EqualTo(UnresolvedAttribute("country"), Literal("England"))) + val plan1 = SubqueryAlias("a", Filter(filterExpr, table1)) + val rightSubquery = + GlobalLimit( + Literal(3), + LocalLimit( + Literal(3), + Sort( + Seq(SortOrder(UnresolvedAttribute("salary"), Ascending)), + global = true, + Project( + Seq( + UnresolvedAttribute("name"), + UnresolvedAttribute("country"), + UnresolvedAttribute("salary")), + Filter(GreaterThan(UnresolvedAttribute("salary"), Literal(0)), table2))))) + val plan2 = SubqueryAlias("b", rightSubquery) + + val joinCondition = EqualTo(UnresolvedAttribute("a.name"), UnresolvedAttribute("b.name")) + val joinPlan = Join(plan1, plan2, Inner, Some(joinCondition), JoinHint.NONE) + + val salaryField = UnresolvedAttribute("salary") + val countryField = UnresolvedAttribute("b.country") + val countryAlias = Alias(countryField, "b.country")() + val star = Seq(UnresolvedStar(None)) + val aggregateExpressions = + Alias(UnresolvedFunction(Seq("AVG"), Seq(salaryField), isDistinct = false), "avg(salary)")() + val span = Alias( + Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)), + "age_span")() + val aggregatePlan = + Aggregate(Seq(countryAlias, span), Seq(aggregateExpressions, countryAlias, span), joinPlan) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val expectedPlan: LogicalPlan = + CreateTableAsSelect( + UnresolvedIdentifier(Seq("simpleView")), + Seq(), + // Seq(IdentityTransform.apply(FieldReference.apply("age_span"))), + Project(star, aggregatePlan), + UnresolvedTableSpec( + Map.empty, + Option("PARQUET"), + OptionList(Seq(("parquet.bloom.filter.enabled", Literal("true")))), + Option(viewLocation), + Option.empty, + Option.empty, + external = false), + Map.empty, + ignoreIfExists = true, + isAnalyzed = false) + + // Compare the two plans + comparePlans( + logicalPlan.asInstanceOf[CreateTableAsSelect].query, + expectedPlan.asInstanceOf[CreateTableAsSelect].query, + checkAnalysis = false) + comparePlans( + logicalPlan.asInstanceOf[CreateTableAsSelect].name, + expectedPlan.asInstanceOf[CreateTableAsSelect].name, + checkAnalysis = false) + assert( + logicalPlan.asInstanceOf[CreateTableAsSelect].tableSpec.toString == expectedPlan + .asInstanceOf[CreateTableAsSelect] + .tableSpec + .toString) + } + }