From 076ae342d82cabb20281afe3039b91e5519244fd Mon Sep 17 00:00:00 2001 From: YANGDB Date: Mon, 21 Oct 2024 12:15:49 -0700 Subject: [PATCH] update documentation with tablesample(50 percent) option Signed-off-by: YANGDB --- docs/ppl-lang/PPL-Example-Commands.md | 33 +- .../{planning => }/ppl-fillnull-command.md | 0 docs/ppl-lang/ppl-rare-command.md | 16 + docs/ppl-lang/ppl-search-command.md | 37 +++ docs/ppl-lang/ppl-top-command.md | 19 ++ .../ppl/FlintSparkPPLTopAndRareITSuite.scala | 39 ++- .../sql/ppl/utils/RelationUtils.java | 2 +- ...ggregationQueriesTranslatorTestSuite.scala | 288 ++++++++++++++++++ ...orrelationQueriesTranslatorTestSuite.scala | 2 +- ...PLLogicalPlanEvalTranslatorTestSuite.scala | 22 +- ...calPlanInSubqueryTranslatorTestSuite.scala | 102 ++++++- ...PLLogicalPlanJoinTranslatorTestSuite.scala | 59 +++- ...PlanNestedQueriesTranslatorTestSuite.scala | 23 ++ ...lanScalarSubqueryTranslatorTestSuite.scala | 118 +++++++ ...TopAndRareQueriesTranslatorTestSuite.scala | 85 +++++- 15 files changed, 813 insertions(+), 32 deletions(-) rename docs/ppl-lang/{planning => }/ppl-fillnull-command.md (100%) diff --git a/docs/ppl-lang/PPL-Example-Commands.md b/docs/ppl-lang/PPL-Example-Commands.md index 96eeef726..c50056638 100644 --- a/docs/ppl-lang/PPL-Example-Commands.md +++ b/docs/ppl-lang/PPL-Example-Commands.md @@ -136,6 +136,7 @@ source = table | where ispresent(a) | [See additional command details](ppl-stats-command.md) - `source = table | stats avg(a) ` +- `source = table tablesample(50 percent) | stats avg(a) ` - `source = table | where a < 50 | stats avg(c) ` - `source = table | stats max(c) by b` - `source = table | stats count(c) by b | head 5` @@ -148,6 +149,7 @@ source = table | where ispresent(a) | **Aggregations With Span** - `source = table | stats count(a) by span(a, 10) as a_span` - `source = table | stats sum(age) by span(age, 5) as age_span | head 2` +- `source = table tablesample(50 percent) | stats sum(age) by span(age, 5) as age_span | head 
2` - `source = table | stats avg(age) by span(age, 20) as age_span, country | sort - age_span | head 2` **Aggregations With TimeWindow Span (tumble windowing function)** @@ -181,6 +183,7 @@ source = table | where ispresent(a) | - `source=accounts | rare gender` - `source=accounts | rare age by gender` +- `source=accounts tablesample(50 percent) | rare age by gender` #### **Top** [See additional command details](ppl-top-command.md) @@ -188,6 +191,7 @@ source = table | where ispresent(a) | - `source=accounts | top gender` - `source=accounts | top 1 gender` - `source=accounts | top 1 age by gender` +- `source=accounts tablesample(50 percent) | top 1 age by gender` #### **Parse** [See additional command details](ppl-parse-command.md) @@ -234,6 +238,9 @@ source = table | where ispresent(a) | [See additional command details](ppl-join-command.md) - `source = table1 | inner join left = l right = r on l.a = r.a table2 | fields l.a, r.a, b, c` +- `source = table1 tablesample(50 percent) | inner join left = l right = r on l.a = r.a table2 | fields l.a, r.a, b, c` +- `source = table1 | inner join left = l right = r on l.a = r.a table2 tablesample(50 percent) | fields l.a, r.a, b, c` +- `source = table1 tablesample(50 percent) | inner join left = l right = r on l.a = r.a table2 tablesample(50 percent) | fields l.a, r.a, b, c` - `source = table1 | left join left = l right = r on l.a = r.a table2 | fields l.a, r.a, b, c` - `source = table1 | right join left = l right = r on l.a = r.a table2 | fields l.a, r.a, b, c` - `source = table1 | full left = l right = r on l.a = r.a table2 | fields l.a, r.a, b, c` @@ -262,11 +269,14 @@ _- **Limitation: "REPLACE" or "APPEND" clause must contain "AS"**_ [See additional command details](ppl-subquery-command.md) - `source = outer | where a in [ source = inner | fields b ]` +- `source = outer tablesample(50 percent) | where a in [ source = inner | fields b ]` - `source = outer | where (a) in [ source = inner | fields b ]` +- `source = outer | 
where (a) in [ source = inner tablesample(50 percent) | fields b ]` - `source = outer | where (a,b,c) in [ source = inner | fields d,e,f ]` - `source = outer | where a not in [ source = inner | fields b ]` - `source = outer | where (a) not in [ source = inner | fields b ]` - `source = outer | where (a,b,c) not in [ source = inner | fields d,e,f ]` +- `source = outer tablesample(50 percent) | where (a,b,c) not in [ source = inner tablesample(50 percent) | fields d,e,f ]` - `source = outer a in [ source = inner | fields b ]` (search filtering with subquery) - `source = outer a not in [ source = inner | fields b ]` (search filtering with subquery) - `source = outer | where a in [ source = inner1 | where b not in [ source = inner2 | fields c ] | fields b ]` (nested) @@ -368,10 +378,22 @@ Assumptions: `a`, `b` are fields of table outer, `c`, `d` are fields of table in `InSubquery`, `ExistsSubquery` and `ScalarSubquery` are all subquery expressions. But `RelationSubquery` is not a subquery expression, it is a subquery plan which is common used in Join or Search clause. 
- `source = table1 | join left = l right = r [ source = table2 | where d > 10 | head 5 ]` (subquery in join right side) +- `source = table1 tablesample(50 percent) | join left = l right = r [ source = table2 | where d > 10 | head 5 ]` (subquery in join right side) - `source = [ source = table1 | join left = l right = r [ source = table2 | where d > 10 | head 5 ] | stats count(a) by b ] as outer | head 1` _- **Limitation: another command usage of (relation) subquery is in `appendcols` commands which is unsupported**_ +#### **fillnull** +[See additional command details](ppl-fillnull-command.md) + +```sql + - `source=accounts | fillnull fields status_code=101` + - `source=accounts | fillnull fields request_path='/not_found', timestamp='*'` + - `source=accounts | fillnull using field1=101` + - `source=accounts | fillnull using field1=concat(field2, field3), field4=2*pi()*field5` + - `source=accounts | fillnull using field1=concat(field2, field3), field4=2*pi()*field5, field6 = 'N/A'` +``` + --- #### Experimental Commands: [See additional command details](ppl-correlation-command.md) @@ -385,15 +407,4 @@ _- **Limitation: another command usage of (relation) subquery is in `appendcols` > ppl-correlation-command is an experimental command - it may be removed in future versions --- -### Planned Commands: - -#### **fillnull** - -```sql - - `source=accounts | fillnull fields status_code=101` - - `source=accounts | fillnull fields request_path='/not_found', timestamp='*'` - - `source=accounts | fillnull using field1=101` - - `source=accounts | fillnull using field1=concat(field2, field3), field4=2*pi()*field5` - - `source=accounts | fillnull using field1=concat(field2, field3), field4=2*pi()*field5, field6 = 'N/A'` -``` [See additional command details](planning/ppl-fillnull-command.md) diff --git a/docs/ppl-lang/planning/ppl-fillnull-command.md b/docs/ppl-lang/ppl-fillnull-command.md similarity index 100% rename from docs/ppl-lang/planning/ppl-fillnull-command.md rename to 
docs/ppl-lang/ppl-fillnull-command.md diff --git a/docs/ppl-lang/ppl-rare-command.md b/docs/ppl-lang/ppl-rare-command.md index 5645382f8..3b25bd1db 100644 --- a/docs/ppl-lang/ppl-rare-command.md +++ b/docs/ppl-lang/ppl-rare-command.md @@ -44,3 +44,19 @@ PPL query: | M | 33 | | M | 36 | +----------+-------+ + +### Example 3: Find the rare address using only 50% of the actual data (sampling) + +PPL query: + + os> source = accounts TABLESAMPLE(50 percent) | rare address + +The logical plan outcome of the rare queries: + +```sql +'Sort ['COUNT('address) AS count_address#91 ASC NULLS FIRST], true ++- 'Aggregate ['address], ['COUNT('address) AS count_address#90, 'address] + +- 'Sample 0.0, 0.5, false, 0 + +- 'UnresolvedRelation [accounts], [], false + +``` \ No newline at end of file diff --git a/docs/ppl-lang/ppl-search-command.md b/docs/ppl-lang/ppl-search-command.md index bccfd04f0..d6ca3aa92 100644 --- a/docs/ppl-lang/ppl-search-command.md +++ b/docs/ppl-lang/ppl-search-command.md @@ -40,3 +40,40 @@ PPL query: | 13 | Nanette | 789 Madison Street | 32838 | F | Nogal | Quility | VA | 28 | null | Bates | +------------------+-------------+--------------------+-----------+----------+--------+------------+---------+-------+----------------------+------------+ +### Example 3: Fetch data with a sampling percentage (including an aggregation) +The following example demonstrates how to sample 75% of the data from the table and then perform an aggregation (finding the top countries by occupation). 
+ +PPL query: + + os> source = account TABLESAMPLE(75 percent) | top 3 country by occupation + +This query samples 75% of the records from the account table, then retrieves the top 3 countries grouped by occupation. + +```sql +SELECT * +FROM ( + SELECT country, occupation, COUNT(country) AS count_country + FROM account + TABLESAMPLE(75 PERCENT) + GROUP BY country, occupation + ORDER BY COUNT(country) DESC NULLS LAST + LIMIT 3 + ) AS subquery + LIMIT 3; +``` +Logical Plan Equivalent: + +```sql +'Project [*] ++- 'GlobalLimit 3 + +- 'LocalLimit 3 + +- 'Sort ['COUNT('country) AS count_country#68 DESC NULLS LAST], true + +- 'Aggregate ['country, 'occupation AS occupation#67], ['COUNT('country) AS count_country#66, 'country, 'occupation AS occupation#67] + +- 'Sample 0.0, 0.75, false, 0 + +- 'UnresolvedRelation [account], [], false + +``` + +By introducing the `TABLESAMPLE` instruction into the source command, one can now sample data as part of a query and reduce the amount of data being scanned, thereby trading precision for performance. + +The `percent` parameter yields an approximation of the true value, with the expected trade-off between accuracy and performance. 
\ No newline at end of file diff --git a/docs/ppl-lang/ppl-top-command.md b/docs/ppl-lang/ppl-top-command.md index 4ba56f692..3dae4dfcb 100644 --- a/docs/ppl-lang/ppl-top-command.md +++ b/docs/ppl-lang/ppl-top-command.md @@ -56,3 +56,22 @@ PPL query: | M | 32 | +----------+-------+ + +### Example 3: Find the top country by occupation using only 75% of the actual data (sampling) + +PPL query: + + os> source = account TABLESAMPLE(75 percent) | top 3 country by occupation + +The logical plan outcome of the top queries: + +```sql +'Project [*] ++- 'GlobalLimit 3 + +- 'LocalLimit 3 + +- 'Sort ['COUNT('country) AS count_country#68 DESC NULLS LAST], true + +- 'Aggregate ['country, 'occupation AS occupation#67], ['COUNT('country) AS count_country#66, 'country, 'occupation AS occupation#67] + +- 'Sample 0.0, 0.75, false, 0 + +- 'UnresolvedRelation [account], [], false + +``` \ No newline at end of file diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTopAndRareITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTopAndRareITSuite.scala index 7fb0f0b33..a96ddfbea 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTopAndRareITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTopAndRareITSuite.scala @@ -114,7 +114,7 @@ class FlintSparkPPLTopAndRareITSuite Seq(addressField), aggregateExpressions, Sample( - 0.5, + 0, 0.5, withReplacement = false, 0, @@ -274,18 +274,18 @@ class FlintSparkPPLTopAndRareITSuite val expectedPlan = Project(Seq(UnresolvedStar(None)), planWithLimit) comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) } - - test("create ppl top 3 countries query test with tablesample 50%") { + + test("create ppl top 2 countries query test with tablesample 50%") { val frame = sql(s""" - | source = $newTestTable TABLESAMPLE(50 percent) | top 3 country + | source = $newTestTable TABLESAMPLE(50 
percent) | top 2 country | """.stripMargin) // Retrieve the results val results: Array[Row] = frame.collect() - assert(results.length == 3) + assert(results.length == 1) - val expectedRows = Set(Row(6, "Canada"), Row(3, "USA"), Row(1, "England")) - val actualRows = results.take(3).toSet + val expectedRows = Set(Row(4, "Canada")) + val actualRows = results.take(1).toSet // Compare the sets assert( @@ -303,7 +303,12 @@ class FlintSparkPPLTopAndRareITSuite Aggregate( Seq(countryField), aggregateExpressions, - UnresolvedRelation(Seq("spark_catalog", "default", "new_flint_ppl_test"))) + Sample( + 0, + 0.5, + withReplacement = false, + 0, + UnresolvedRelation(Seq("spark_catalog", "default", "new_flint_ppl_test")))) val sortedPlan: LogicalPlan = Sort( @@ -317,12 +322,12 @@ class FlintSparkPPLTopAndRareITSuite aggregatePlan) val planWithLimit = - GlobalLimit(Literal(3), LocalLimit(Literal(3), sortedPlan)) + GlobalLimit(Literal(2), LocalLimit(Literal(2), sortedPlan)) val expectedPlan = Project(Seq(UnresolvedStar(None)), planWithLimit) comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) } - test("create ppl top 2 countries by occupation field query test") { + test("create ppl top 3 countries by occupation field query test") { val frame = sql(s""" | source = $newTestTable| top 3 country by occupation | """.stripMargin) @@ -373,9 +378,10 @@ class FlintSparkPPLTopAndRareITSuite comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) } - test("create ppl top 2 countries by occupation field query test with tablesample 50%") { + + test("create ppl top 3 countries by occupation field query test with tablesample 75%") { val frame = sql(s""" - | source = $newTestTable TABLESAMPLE(50 percent) | top 3 country by occupation + | source = $newTestTable TABLESAMPLE(75 percent) | top 3 country by occupation | """.stripMargin) // Retrieve the results @@ -383,7 +389,7 @@ class FlintSparkPPLTopAndRareITSuite assert(results.length == 3) val expectedRows = - Set(Row(3, 
"Canada", "Doctor"), Row(2, "Canada", "Scientist"), Row(2, "USA", "Engineer")) + Set(Row(2, "Canada", "Doctor"), Row(2, "Canada", "Scientist"), Row(1, "USA", "Engineer")) val actualRows = results.take(3).toSet // Compare the sets @@ -405,7 +411,12 @@ class FlintSparkPPLTopAndRareITSuite Aggregate( Seq(countryField, occupationFieldAlias), aggregateExpressions, - UnresolvedRelation(Seq("spark_catalog", "default", "new_flint_ppl_test"))) + Sample( + 0, + 0.75, + withReplacement = false, + 0, + UnresolvedRelation(Seq("spark_catalog", "default", "new_flint_ppl_test")))) val sortedPlan: LogicalPlan = Sort( diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/RelationUtils.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/RelationUtils.java index 988742f88..b84d8ef77 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/RelationUtils.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/RelationUtils.java @@ -40,7 +40,7 @@ static Optional resolveField(List relations, } static Optional tablesampleBuilder(OpenSearchPPLParser.TablesampleClauseContext context) { - if(context.percentage != null) + if(context != null && context.percentage != null) return Optional.of(new TablesampleContext(Integer.parseInt(context.percentage.getText()))); return Optional.empty(); } diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanAggregationQueriesTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanAggregationQueriesTranslatorTestSuite.scala index 03d7f0ab0..dc7940056 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanAggregationQueriesTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanAggregationQueriesTranslatorTestSuite.scala @@ -42,6 +42,29 @@ class 
PPLLogicalPlanAggregationQueriesTranslatorTestSuite comparePlans(expectedPlan, logPlan, false) } + test("test average price with tablesample(50 percent)") { + // if successful build ppl logical plan and translate to catalyst logical plan + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan(pplParser, "source = table tablesample(50 percent)| stats avg(price) "), + context) + // SQL: SELECT avg(price) as avg_price FROM table + val star = Seq(UnresolvedStar(None)) + + val priceField = UnresolvedAttribute("price") + val tableRelation = UnresolvedRelation(Seq("table")) + val aggregateExpressions = Seq( + Alias(UnresolvedFunction(Seq("AVG"), Seq(priceField), isDistinct = false), "avg(price)")()) + val aggregatePlan = Aggregate( + Seq(), + aggregateExpressions, + Sample(0, 0.5, withReplacement = false, 0, tableRelation)) + val expectedPlan = Project(star, aggregatePlan) + + comparePlans(expectedPlan, logPlan, false) + } + test("test average price with Alias") { // if successful build ppl logical plan and translate to catalyst logical plan val context = new CatalystPlanContext @@ -85,6 +108,33 @@ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite comparePlans(expectedPlan, logPlan, false) } + test("test average price group by product with tablesample(50 percent)") { + // if successful build ppl logical plan and translate to catalyst logical plan + val context = new CatalystPlanContext + val logPlan = planTransformer.visit( + plan(pplParser, "source = table tablesample(50 percent) | stats avg(price) by product"), + context) + // SQL: SELECT product, AVG(price) AS avg_price FROM table GROUP BY product + val star = Seq(UnresolvedStar(None)) + val productField = UnresolvedAttribute("product") + val priceField = UnresolvedAttribute("price") + val tableRelation = UnresolvedRelation(Seq("table")) + + val groupByAttributes = Seq(Alias(productField, "product")()) + val aggregateExpressions = + Alias(UnresolvedFunction(Seq("AVG"), 
Seq(priceField), isDistinct = false), "avg(price)")() + val productAlias = Alias(productField, "product")() + + val aggregatePlan = + Aggregate( + groupByAttributes, + Seq(aggregateExpressions, productAlias), + Sample(0, 0.5, withReplacement = false, 0, tableRelation)) + val expectedPlan = Project(star, aggregatePlan) + + comparePlans(expectedPlan, logPlan, false) + } + test("test average price group by product and filter") { // if successful build ppl logical plan and translate to catalyst logical plan val context = new CatalystPlanContext @@ -146,6 +196,41 @@ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite val expectedPlan = Project(star, sortedPlan) comparePlans(expectedPlan, logPlan, false) } + + test("test average price group by product and filter sorted with tablesample(50 percent)") { + // if successful build ppl logical plan and translate to catalyst logical plan + val context = new CatalystPlanContext + val logPlan = planTransformer.visit( + plan( + pplParser, + "source = table tablesample(50 percent) | where country ='USA' | stats avg(price) by product | sort product"), + context) + // SQL: SELECT product, AVG(price) AS avg_price FROM table GROUP BY product + val star = Seq(UnresolvedStar(None)) + val productField = UnresolvedAttribute("product") + val priceField = UnresolvedAttribute("price") + val countryField = UnresolvedAttribute("country") + val table = UnresolvedRelation(Seq("table")) + + val groupByAttributes = Seq(Alias(productField, "product")()) + val aggregateExpressions = + Alias(UnresolvedFunction(Seq("AVG"), Seq(priceField), isDistinct = false), "avg(price)")() + val productAlias = Alias(productField, "product")() + + val filterExpr = EqualTo(countryField, Literal("USA")) + val filterPlan = Filter(filterExpr, Sample(0, 0.5, withReplacement = false, 0, table)) + + val aggregatePlan = + Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), filterPlan) + val sortedPlan: LogicalPlan = + Sort( + 
Seq(SortOrder(UnresolvedAttribute("product"), Ascending)), + global = true, + aggregatePlan) + val expectedPlan = Project(star, sortedPlan) + comparePlans(expectedPlan, logPlan, false) + } + test("create ppl simple avg age by span of interval of 10 years query test ") { val context = new CatalystPlanContext val logPlan = planTransformer.visit( @@ -215,6 +300,36 @@ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite comparePlans(expectedPlan, logPlan, false) } + + test( + "create ppl simple avg age by span of interval of 10 years by country query test with tablesample(50 percent)") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit( + plan( + pplParser, + "source = table tablesample(50 percent) | stats avg(age) by span(age, 10) as age_span, country"), + context) + // Define the expected logical plan + val star = Seq(UnresolvedStar(None)) + val ageField = UnresolvedAttribute("age") + val tableRelation = UnresolvedRelation(Seq("table")) + val countryField = UnresolvedAttribute("country") + val countryAlias = Alias(countryField, "country")() + + val aggregateExpressions = + Alias(UnresolvedFunction(Seq("AVG"), Seq(ageField), isDistinct = false), "avg(age)")() + val span = Alias( + Multiply(Floor(Divide(UnresolvedAttribute("age"), Literal(10))), Literal(10)), + "age_span")() + val aggregatePlan = Aggregate( + Seq(countryAlias, span), + Seq(aggregateExpressions, countryAlias, span), + Sample(0, 0.5, withReplacement = false, 0, tableRelation)) + val expectedPlan = Project(star, aggregatePlan) + + comparePlans(expectedPlan, logPlan, false) + } + test("create ppl query count sales by weeks window and productId with sorting test") { val context = new CatalystPlanContext val logPlan = planTransformer.visit( @@ -290,6 +405,7 @@ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite // Compare the two plans comparePlans(expectedPlan, logPlan, false) } + test("create ppl query count status amount by day window and group by status test") { 
val context = new CatalystPlanContext val logPlan = planTransformer.visit( @@ -324,6 +440,43 @@ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite // Compare the two plans comparePlans(expectedPlan, logPlan, false) } + + test( + "create ppl query count status amount by day window and group by status test with tablesample(50 percent)") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit( + plan( + pplParser, + "source = table tablesample(50 percent) | stats sum(status) by span(@timestamp, 1d) as status_count_by_day, status | head 100"), + context) + // Define the expected logical plan + val star = Seq(UnresolvedStar(None)) + val status = Alias(UnresolvedAttribute("status"), "status")() + val statusAmount = UnresolvedAttribute("status") + val table = UnresolvedRelation(Seq("table")) + + val windowExpression = Alias( + TimeWindow( + UnresolvedAttribute("`@timestamp`"), + TimeWindow.parseExpression(Literal("1 day")), + TimeWindow.parseExpression(Literal("1 day")), + 0), + "status_count_by_day")() + + val aggregateExpressions = + Alias( + UnresolvedFunction(Seq("SUM"), Seq(statusAmount), isDistinct = false), + "sum(status)")() + val aggregatePlan = Aggregate( + Seq(status, windowExpression), + Seq(aggregateExpressions, status, windowExpression), + Sample(0, 0.5, withReplacement = false, 0, table)) + val planWithLimit = GlobalLimit(Literal(100), LocalLimit(Literal(100), aggregatePlan)) + val expectedPlan = Project(star, planWithLimit) + // Compare the two plans + comparePlans(expectedPlan, logPlan, false) + } + test( "create ppl query count only error (status >= 400) status amount by day window and group by status test") { val context = new CatalystPlanContext @@ -598,6 +751,38 @@ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite comparePlans(expectedPlan, logPlan, false) } + test("test price 50th percentile group by product sorted with tablesample(50 percent)") { + val context = new CatalystPlanContext + val logPlan = 
planTransformer.visit( + plan( + pplParser, + "source = table tablesample(50 percent) | stats percentile(price, 50) by product | sort product"), + context) + val star = Seq(UnresolvedStar(None)) + val priceField = UnresolvedAttribute("price") + val productField = UnresolvedAttribute("product") + val percentage = Literal(0.5) + val tableRelation = UnresolvedRelation(Seq("table")) + + val groupByAttributes = Seq(Alias(productField, "product")()) + val aggregateExpressions = + Alias( + UnresolvedFunction(Seq("PERCENTILE"), Seq(priceField, percentage), isDistinct = false), + "percentile(price, 50)")() + val productAlias = Alias(productField, "product")() + + val aggregatePlan = + Aggregate( + groupByAttributes, + Seq(aggregateExpressions, productAlias), + Sample(0, 0.5, withReplacement = false, 0, tableRelation)) + val sortedPlan: LogicalPlan = + Sort(Seq(SortOrder(productField, Ascending)), global = true, aggregatePlan) + val expectedPlan = Project(star, sortedPlan) + + comparePlans(expectedPlan, logPlan, false) + } + test("test price 20th percentile with alias and filter") { val context = new CatalystPlanContext val logPlan = planTransformer.visit( @@ -776,6 +961,30 @@ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite comparePlans(expectedPlan, logPlan, false) } + test("test distinct count product with alias and filter with tablesample(50 percent)") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit( + plan( + pplParser, + "source = table tablesample(50 percent)| where price > 100 | stats distinct_count(product) as dc_product"), + context) + val star = Seq(UnresolvedStar(None)) + val productField = UnresolvedAttribute("product") + val priceField = UnresolvedAttribute("price") + val tableRelation = UnresolvedRelation(Seq("table")) + + val aggregateExpressions = Seq( + Alias( + UnresolvedFunction(Seq("COUNT"), Seq(productField), isDistinct = true), + "dc_product")()) + val filterExpr = GreaterThan(priceField, Literal(100)) + val 
filterPlan = Filter(filterExpr, Sample(0, 0.5, withReplacement = false, 0, tableRelation)) + val aggregatePlan = Aggregate(Seq(), aggregateExpressions, filterPlan) + val expectedPlan = Project(star, aggregatePlan) + + comparePlans(expectedPlan, logPlan, false) + } + test("test distinct count age by span of interval of 10 years query with sort ") { val context = new CatalystPlanContext val logPlan = planTransformer.visit( @@ -838,6 +1047,42 @@ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite comparePlans(expectedPlan, logPlan, false) } + test( + "test distinct count status by week window and group by status with limit with tablesample(50 percent)") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit( + plan( + pplParser, + "source = table tablesample(50 percent) | stats distinct_count(status) by span(@timestamp, 1w) as status_count_by_week, status | head 100"), + context) + // Define the expected logical plan + val star = Seq(UnresolvedStar(None)) + val status = Alias(UnresolvedAttribute("status"), "status")() + val statusCount = UnresolvedAttribute("status") + val table = UnresolvedRelation(Seq("table")) + + val windowExpression = Alias( + TimeWindow( + UnresolvedAttribute("`@timestamp`"), + TimeWindow.parseExpression(Literal("1 week")), + TimeWindow.parseExpression(Literal("1 week")), + 0), + "status_count_by_week")() + + val aggregateExpressions = + Alias( + UnresolvedFunction(Seq("COUNT"), Seq(statusCount), isDistinct = true), + "distinct_count(status)")() + val aggregatePlan = Aggregate( + Seq(status, windowExpression), + Seq(aggregateExpressions, status, windowExpression), + Sample(0, 0.5, withReplacement = false, 0, table)) + val planWithLimit = GlobalLimit(Literal(100), LocalLimit(Literal(100), aggregatePlan)) + val expectedPlan = Project(star, planWithLimit) + // Compare the two plans + comparePlans(expectedPlan, logPlan, false) + } + test("multiple stats - test average price and average age") { val context = new 
CatalystPlanContext val logPlan = @@ -959,4 +1204,47 @@ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite comparePlans(expectedPlan, logPlan, false) } + + test("multiple levels stats with tablesample(50 percent)") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit( + plan( + pplParser, + "source = table tablesample(50 percent) | stats avg(response_time) as avg_response_time by host, service | stats avg(avg_response_time) as avg_host_response_time by service"), + context) + val star = Seq(UnresolvedStar(None)) + val hostField = UnresolvedAttribute("host") + val serviceField = UnresolvedAttribute("service") + val ageField = UnresolvedAttribute("age") + val responseTimeField = UnresolvedAttribute("response_time") + val tableRelation = UnresolvedRelation(Seq("table")) + val hostAlias = Alias(hostField, "host")() + val serviceAlias = Alias(serviceField, "service")() + + val groupByAttributes1 = Seq(Alias(hostField, "host")(), Alias(serviceField, "service")()) + val aggregateExpressions1 = + Alias( + UnresolvedFunction(Seq("AVG"), Seq(responseTimeField), isDistinct = false), + "avg_response_time")() + val responseTimeAlias = Alias(responseTimeField, "response_time")() + val aggregatePlan1 = + Aggregate( + groupByAttributes1, + Seq(aggregateExpressions1, hostAlias, serviceAlias), + Sample(0, 0.5, withReplacement = false, 0, tableRelation)) + + val avgResponseTimeField = UnresolvedAttribute("avg_response_time") + val groupByAttributes2 = Seq(Alias(serviceField, "service")()) + val aggregateExpressions2 = + Alias( + UnresolvedFunction(Seq("AVG"), Seq(avgResponseTimeField), isDistinct = false), + "avg_host_response_time")() + + val aggregatePlan2 = + Aggregate(groupByAttributes2, Seq(aggregateExpressions2, serviceAlias), aggregatePlan1) + + val expectedPlan = Project(star, aggregatePlan2) + + comparePlans(expectedPlan, logPlan, false) + } } diff --git 
a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanCorrelationQueriesTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanCorrelationQueriesTranslatorTestSuite.scala index ea3a8cf39..12bac9c25 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanCorrelationQueriesTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanCorrelationQueriesTranslatorTestSuite.scala @@ -98,7 +98,7 @@ class PPLLogicalPlanCorrelationQueriesTranslatorTestSuite plan( pplParser, s""" - | source = $testTable1, $testTable2| where year = 2023 AND month = 4 | correlate exact fields(name) scope(month, 1W) mapping($testTable1.name = $testTable2.name) + | source = $testTable1, $testTable2 | where year = 2023 AND month = 4 | correlate exact fields(name) scope(month, 1W) mapping($testTable1.name = $testTable2.name) | """.stripMargin), context) diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanEvalTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanEvalTranslatorTestSuite.scala index 3e2b3cc30..e09429049 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanEvalTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanEvalTranslatorTestSuite.scala @@ -14,7 +14,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} import org.apache.spark.sql.catalyst.expressions.{Alias, Descending, ExprId, Literal, NamedExpression, SortOrder} import org.apache.spark.sql.catalyst.plans.PlanTest -import org.apache.spark.sql.catalyst.plans.logical.{Project, Sort} +import org.apache.spark.sql.catalyst.plans.logical.{Project, 
Sample, Sort} class PPLLogicalPlanEvalTranslatorTestSuite extends SparkFunSuite @@ -80,6 +80,26 @@ class PPLLogicalPlanEvalTranslatorTestSuite comparePlans(expectedPlan, logPlan, checkAnalysis = false) } + test("test eval expressions with sort and with tablesample(50 percent)") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + "source=t tablesample(50 percent) | eval a = 1, b = 1 | sort - a | fields b"), + context) + + val evalProjectList: Seq[NamedExpression] = + Seq(UnresolvedStar(None), Alias(Literal(1), "a")(), Alias(Literal(1), "b")()) + val evalProject = Project( + evalProjectList, + Sample(0, 0.5, withReplacement = false, 0, UnresolvedRelation(Seq("t")))) + val sortOrder = SortOrder(UnresolvedAttribute("a"), Descending, Seq.empty) + val sort = Sort(seq(sortOrder), global = true, evalProject) + val expectedPlan = Project(seq(UnresolvedAttribute("b")), sort) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + test("test eval expressions with multiple recursive sort") { val context = new CatalystPlanContext val logPlan = diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanInSubqueryTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanInSubqueryTranslatorTestSuite.scala index 03bcdd623..400ad510f 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanInSubqueryTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanInSubqueryTranslatorTestSuite.scala @@ -14,7 +14,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} import org.apache.spark.sql.catalyst.expressions.{Alias, And, Ascending, Descending, EqualTo, GreaterThanOrEqual, InSubquery, LessThan, ListQuery, Literal, Not, SortOrder} 
import org.apache.spark.sql.catalyst.plans.{Inner, PlanTest} -import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, Join, JoinHint, LogicalPlan, Project, Sort, SubqueryAlias} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, Join, JoinHint, LogicalPlan, Project, Sample, Sort, SubqueryAlias} class PPLLogicalPlanInSubqueryTranslatorTestSuite extends SparkFunSuite @@ -56,6 +56,106 @@ class PPLLogicalPlanInSubqueryTranslatorTestSuite comparePlans(expectedPlan, logPlan, false) } + test("test where a in (select b from c) with only outer tablesample(50 percent)") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + s""" + | source = spark_catalog.default.outer tablesample(50 percent) + | | where a in [ + | source = spark_catalog.default.inner | fields b + | ] + | | sort - a + | | fields a, c + | """.stripMargin), + context) + val outer = UnresolvedRelation(Seq("spark_catalog", "default", "outer")) + val inner = UnresolvedRelation(Seq("spark_catalog", "default", "inner")) + val inSubquery = + Filter( + InSubquery( + Seq(UnresolvedAttribute("a")), + ListQuery(Project(Seq(UnresolvedAttribute("b")), inner))), + Sample(0, 0.5, withReplacement = false, 0, outer)) + val sortedPlan: LogicalPlan = + Sort(Seq(SortOrder(UnresolvedAttribute("a"), Descending)), global = true, inSubquery) + val expectedPlan = + Project(Seq(UnresolvedAttribute("a"), UnresolvedAttribute("c")), sortedPlan) + + comparePlans(expectedPlan, logPlan, false) + } + + test("test where a in (select b from c) with only inner tablesample(50 percent)") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + s""" + | source = spark_catalog.default.outer + | | where a in [ + | source = spark_catalog.default.inner tablesample(50 percent) | fields b + | ] + | | sort - a + | | fields a, c + | """.stripMargin), + context) + val outer = UnresolvedRelation(Seq("spark_catalog", 
"default", "outer")) + val inner = UnresolvedRelation(Seq("spark_catalog", "default", "inner")) + val inSubquery = + Filter( + InSubquery( + Seq(UnresolvedAttribute("a")), + ListQuery( + Project( + Seq(UnresolvedAttribute("b")), + Sample(0, 0.5, withReplacement = false, 0, inner)))), + outer) + val sortedPlan: LogicalPlan = + Sort(Seq(SortOrder(UnresolvedAttribute("a"), Descending)), global = true, inSubquery) + val expectedPlan = + Project(Seq(UnresolvedAttribute("a"), UnresolvedAttribute("c")), sortedPlan) + + comparePlans(expectedPlan, logPlan, false) + } + + test( + "test where a in (select b from c) with both inner & outer tables tablesample(50 percent)") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + s""" + | source = spark_catalog.default.outer tablesample(50 percent) + | | where a in [ + | source = spark_catalog.default.inner tablesample(50 percent) | fields b + | ] + | | sort - a + | | fields a, c + | """.stripMargin), + context) + val outer = UnresolvedRelation(Seq("spark_catalog", "default", "outer")) + val inner = UnresolvedRelation(Seq("spark_catalog", "default", "inner")) + val inSubquery = + Filter( + InSubquery( + Seq(UnresolvedAttribute("a")), + ListQuery( + Project( + Seq(UnresolvedAttribute("b")), + Sample(0, 0.5, withReplacement = false, 0, inner)))), + Sample(0, 0.5, withReplacement = false, 0, outer)) + val sortedPlan: LogicalPlan = + Sort(Seq(SortOrder(UnresolvedAttribute("a"), Descending)), global = true, inSubquery) + val expectedPlan = + Project(Seq(UnresolvedAttribute("a"), UnresolvedAttribute("c")), sortedPlan) + + comparePlans(expectedPlan, logPlan, false) + } + test("test where (a) in (select b from c)") { val context = new CatalystPlanContext val logPlan = diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanJoinTranslatorTestSuite.scala 
b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanJoinTranslatorTestSuite.scala index 3ceff7735..6f3fc78cc 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanJoinTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanJoinTranslatorTestSuite.scala @@ -13,7 +13,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} import org.apache.spark.sql.catalyst.expressions.{Alias, And, Ascending, Descending, EqualTo, GreaterThan, LessThan, Literal, Not, SortOrder} import org.apache.spark.sql.catalyst.plans.{Cross, FullOuter, Inner, LeftAnti, LeftOuter, LeftSemi, PlanTest, RightOuter} -import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, GlobalLimit, Join, JoinHint, LocalLimit, Project, Sort, SubqueryAlias} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, GlobalLimit, Join, JoinHint, LocalLimit, Project, Sample, Sort, SubqueryAlias} class PPLLogicalPlanJoinTranslatorTestSuite extends SparkFunSuite @@ -48,6 +48,63 @@ class PPLLogicalPlanJoinTranslatorTestSuite comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) } + test( + "test two-tables inner join: join condition with aliases with left side tablesample(50 percent)") { + val context = new CatalystPlanContext + val logPlan = plan( + pplParser, + s""" + | source = $testTable1 tablesample(50 percent)| JOIN left = l right = r ON l.id = r.id $testTable2 + | """.stripMargin) + val logicalPlan = planTransformer.visit(logPlan, context) + val table1 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val table2 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val leftPlan = SubqueryAlias("l", Sample(0, 0.5, withReplacement = false, 0, table1)) + val rightPlan = SubqueryAlias("r", table2) + 
val joinCondition = EqualTo(UnresolvedAttribute("l.id"), UnresolvedAttribute("r.id")) + val joinPlan = Join(leftPlan, rightPlan, Inner, Some(joinCondition), JoinHint.NONE) + val expectedPlan = Project(Seq(UnresolvedStar(None)), joinPlan) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + + test( + "test two-tables inner join: join condition with aliases with right side tablesample(50 percent)") { + val context = new CatalystPlanContext + val logPlan = plan( + pplParser, + s""" + | source = $testTable1 | JOIN left = l right = r ON l.id = r.id $testTable2 tablesample(50 percent) + | """.stripMargin) + val logicalPlan = planTransformer.visit(logPlan, context) + val table1 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val table2 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val leftPlan = SubqueryAlias("l", table1) + val rightPlan = SubqueryAlias("r", Sample(0, 0.5, withReplacement = false, 0, table2)) + val joinCondition = EqualTo(UnresolvedAttribute("l.id"), UnresolvedAttribute("r.id")) + val joinPlan = Join(leftPlan, rightPlan, Inner, Some(joinCondition), JoinHint.NONE) + val expectedPlan = Project(Seq(UnresolvedStar(None)), joinPlan) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + + test( + "test two-tables inner join: join condition with aliases with both sides tablesample(50 percent)") { + val context = new CatalystPlanContext + val logPlan = plan( + pplParser, + s""" + | source = $testTable1 tablesample(50 percent) | JOIN left = l right = r ON l.id = r.id $testTable2 tablesample(50 percent) + | """.stripMargin) + val logicalPlan = planTransformer.visit(logPlan, context) + val table1 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test1")) + val table2 = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test2")) + val leftPlan = SubqueryAlias("l", Sample(0, 0.5, withReplacement = false, 0, table1)) + val rightPlan = SubqueryAlias("r", 
Sample(0, 0.5, withReplacement = false, 0, table2)) + val joinCondition = EqualTo(UnresolvedAttribute("l.id"), UnresolvedAttribute("r.id")) + val joinPlan = Join(leftPlan, rightPlan, Inner, Some(joinCondition), JoinHint.NONE) + val expectedPlan = Project(Seq(UnresolvedStar(None)), joinPlan) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } + test("test two-tables inner join: join condition with table names") { val context = new CatalystPlanContext val logPlan = plan( diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanNestedQueriesTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanNestedQueriesTranslatorTestSuite.scala index 263c76612..f8d6746d4 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanNestedQueriesTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanNestedQueriesTranslatorTestSuite.scala @@ -158,6 +158,29 @@ class PPLLogicalPlanNestedQueriesTranslatorTestSuite comparePlans(expectedPlan, logPlan, false) } + test( + "Search multiple tables - translated into union call - nested fields expected to exist in both tables with table tablesample(50 percent)") { + val context = new CatalystPlanContext + val logPlan = planTransformer.visit( + plan( + pplParser, + "search source=table1, table2 tablesample(50 percent) | fields A.nested1, B.nested1"), + context) + + val table1 = UnresolvedRelation(Seq("table1")) + val table2 = UnresolvedRelation(Seq("table2")) + + val allFields1 = Seq(UnresolvedAttribute("A.nested1"), UnresolvedAttribute("B.nested1")) + val allFields2 = Seq(UnresolvedAttribute("A.nested1"), UnresolvedAttribute("B.nested1")) + + val projectedTable1 = Project(allFields1, Sample(0, 0.5, withReplacement = false, 0, table1)) + val projectedTable2 = Project(allFields2, Sample(0, 0.5, withReplacement = false, 0, table2)) + val 
expectedPlan = + Union(Seq(projectedTable1, projectedTable2), byName = true, allowMissingCol = true) + + comparePlans(expectedPlan, logPlan, false) + } + test( "Search multiple tables with FQN - translated into union call - nested fields expected to exist in both tables ") { val context = new CatalystPlanContext diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanScalarSubqueryTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanScalarSubqueryTranslatorTestSuite.scala index c76e7e538..2d2829e3a 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanScalarSubqueryTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanScalarSubqueryTranslatorTestSuite.scala @@ -132,6 +132,90 @@ class PPLLogicalPlanScalarSubqueryTranslatorTestSuite comparePlans(expectedPlan, logPlan, checkAnalysis = false) } + test( + "test uncorrelated scalar subquery in select and where with outer tablesample(50 percent)") { + // select (select max(c) from inner), a from outer where b > (select min(c) from inner) + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + s""" + | source = spark_catalog.default.outer tablesample(50 percent) + | | eval max_c = [ + | source = spark_catalog.default.inner | stats max(c) + | ] + | | where b > [ + | source = spark_catalog.default.inner | stats min(c) + | ] + | | fields max_c, a + | """.stripMargin), + context) + val outer = UnresolvedRelation(Seq("spark_catalog", "default", "outer")) + val inner = UnresolvedRelation(Seq("spark_catalog", "default", "inner")) + val maxAgg = Seq( + Alias( + UnresolvedFunction(Seq("MAX"), Seq(UnresolvedAttribute("c")), isDistinct = false), + "max(c)")()) + val minAgg = Seq( + Alias( + UnresolvedFunction(Seq("MIN"), Seq(UnresolvedAttribute("c")), isDistinct = false), + "min(c)")()) + val 
maxAggPlan = Aggregate(Seq(), maxAgg, inner) + val minAggPlan = Aggregate(Seq(), minAgg, inner) + val maxScalarSubqueryExpr = ScalarSubquery(maxAggPlan) + val minScalarSubqueryExpr = ScalarSubquery(minAggPlan) + + val evalProjectList = Seq(UnresolvedStar(None), Alias(maxScalarSubqueryExpr, "max_c")()) + val evalProject = Project(evalProjectList, Sample(0, 0.5, withReplacement = false, 0, outer)) + val filter = Filter(GreaterThan(UnresolvedAttribute("b"), minScalarSubqueryExpr), evalProject) + val expectedPlan = + Project(Seq(UnresolvedAttribute("max_c"), UnresolvedAttribute("a")), filter) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test( + "test uncorrelated scalar subquery in select and where with inner tablesample(50 percent) for max_c eval") { + // select (select max(c) from inner), a from outer where b > (select min(c) from inner) + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + s""" + | source = spark_catalog.default.outer + | | eval max_c = [ + | source = spark_catalog.default.inner tablesample(50 percent) | stats max(c) + | ] + | | where b > [ + | source = spark_catalog.default.inner | stats min(c) + | ] + | | fields max_c, a + | """.stripMargin), + context) + val outer = UnresolvedRelation(Seq("spark_catalog", "default", "outer")) + val inner = UnresolvedRelation(Seq("spark_catalog", "default", "inner")) + val maxAgg = Seq( + Alias( + UnresolvedFunction(Seq("MAX"), Seq(UnresolvedAttribute("c")), isDistinct = false), + "max(c)")()) + val minAgg = Seq( + Alias( + UnresolvedFunction(Seq("MIN"), Seq(UnresolvedAttribute("c")), isDistinct = false), + "min(c)")()) + val maxAggPlan = Aggregate(Seq(), maxAgg, Sample(0, 0.5, withReplacement = false, 0, inner)) + val minAggPlan = Aggregate(Seq(), minAgg, inner) + val maxScalarSubqueryExpr = ScalarSubquery(maxAggPlan) + val minScalarSubqueryExpr = ScalarSubquery(minAggPlan) + + val evalProjectList = Seq(UnresolvedStar(None), 
Alias(maxScalarSubqueryExpr, "max_c")()) + val evalProject = Project(evalProjectList, outer) + val filter = Filter(GreaterThan(UnresolvedAttribute("b"), minScalarSubqueryExpr), evalProject) + val expectedPlan = + Project(Seq(UnresolvedAttribute("max_c"), UnresolvedAttribute("a")), filter) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + test("test correlated scalar subquery in select") { // select (select max(c) from inner where b = d), a from outer val context = new CatalystPlanContext @@ -164,6 +248,40 @@ class PPLLogicalPlanScalarSubqueryTranslatorTestSuite comparePlans(expectedPlan, logPlan, checkAnalysis = false) } + test("test correlated scalar subquery in select with both tables tablesample(50 percent)") { + // select (select max(c) from inner where b = d), a from outer + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + s""" + | source = spark_catalog.default.outer tablesample(50 percent) + | | eval max_c = [ + | source = spark_catalog.default.inner tablesample(50 percent) | where b = d | stats max(c) + | ] + | | fields max_c, a + | """.stripMargin), + context) + val outer = UnresolvedRelation(Seq("spark_catalog", "default", "outer")) + val inner = UnresolvedRelation(Seq("spark_catalog", "default", "inner")) + val aggregateExpressions = Seq( + Alias( + UnresolvedFunction(Seq("MAX"), Seq(UnresolvedAttribute("c")), isDistinct = false), + "max(c)")()) + val filter = Filter( + EqualTo(UnresolvedAttribute("b"), UnresolvedAttribute("d")), + Sample(0, 0.5, withReplacement = false, 0, inner)) + val aggregatePlan = Aggregate(Seq(), aggregateExpressions, filter) + val scalarSubqueryExpr = ScalarSubquery(aggregatePlan) + val evalProjectList = Seq(UnresolvedStar(None), Alias(scalarSubqueryExpr, "max_c")()) + val evalProject = Project(evalProjectList, Sample(0, 0.5, withReplacement = false, 0, outer)) + val expectedPlan = + Project(Seq(UnresolvedAttribute("max_c"), UnresolvedAttribute("a")), 
evalProject) + + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + test("test correlated scalar subquery in select with non-equal") { // select (select max(c) from inner where b > d), a from outer val context = new CatalystPlanContext diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTopAndRareQueriesTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTopAndRareQueriesTranslatorTestSuite.scala index 90f90fde7..b8e64ff0a 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTopAndRareQueriesTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTopAndRareQueriesTranslatorTestSuite.scala @@ -64,7 +64,7 @@ class PPLLogicalPlanTopAndRareQueriesTranslatorTestSuite val context = new CatalystPlanContext val logPlan = planTransformer.visit( - plan(pplParser, "source=accounts | rare address tablesample(50 percent)"), + plan(pplParser, "source=accounts tablesample(50 percent) | rare address "), context) val addressField = UnresolvedAttribute("address") val tableRelation = UnresolvedRelation(Seq("accounts")) @@ -78,7 +78,10 @@ class PPLLogicalPlanTopAndRareQueriesTranslatorTestSuite addressField) val aggregatePlan = - Aggregate(Seq(addressField), aggregateExpressions, tableRelation) + Aggregate( + Seq(addressField), + aggregateExpressions, + Sample(0, 0.5, withReplacement = false, 0, tableRelation)) val sortedPlan: LogicalPlan = Sort( @@ -166,6 +169,44 @@ class PPLLogicalPlanTopAndRareQueriesTranslatorTestSuite comparePlans(expectedPlan, logPlan, checkAnalysis = false) } + test("test simple top command with a single field tablesample(50 percent) ") { + // if successful build ppl logical plan and translate to catalyst logical plan + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan(pplParser, "source=accounts tablesample(50 
percent) | top address"), + context) + val addressField = UnresolvedAttribute("address") + val tableRelation = UnresolvedRelation(Seq("accounts")) + + val projectList: Seq[NamedExpression] = Seq(UnresolvedStar(None)) + + val aggregateExpressions = Seq( + Alias( + UnresolvedFunction(Seq("COUNT"), Seq(addressField), isDistinct = false), + "count_address")(), + addressField) + + val aggregatePlan = + Aggregate( + Seq(addressField), + aggregateExpressions, + Sample(0, 0.5, withReplacement = false, 0, tableRelation)) + + val sortedPlan: LogicalPlan = + Sort( + Seq( + SortOrder( + Alias( + UnresolvedFunction(Seq("COUNT"), Seq(addressField), isDistinct = false), + "count_address")(), + Descending)), + global = true, + aggregatePlan) + val expectedPlan = Project(projectList, sortedPlan) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + test("test simple top 1 command by age field") { // if successful build ppl logical plan and translate to catalyst logical plan val context = new CatalystPlanContext @@ -242,4 +283,44 @@ class PPLLogicalPlanTopAndRareQueriesTranslatorTestSuite comparePlans(expectedPlan, logPlan, checkAnalysis = false) } + test("create ppl top 3 countries by occupation field query test with tablesample(25 percent)") { + // if successful build ppl logical plan and translate to catalyst logical plan + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan(pplParser, "source=accounts tablesample(25 percent) | top 3 country by occupation"), + context) + + val tableRelation = UnresolvedRelation(Seq("accounts")) + val countryField = UnresolvedAttribute("country") + val occupationField = UnresolvedAttribute("occupation") + val occupationFieldAlias = Alias(occupationField, "occupation")() + + val countExpr = Alias( + UnresolvedFunction(Seq("COUNT"), Seq(countryField), isDistinct = false), + "count_country")() + val aggregateExpressions = Seq(countExpr, countryField, occupationFieldAlias) + val aggregatePlan = + 
Aggregate( + Seq(countryField, occupationFieldAlias), + aggregateExpressions, + Sample(0, 0.25, withReplacement = false, 0, tableRelation)) + + val sortedPlan: LogicalPlan = + Sort( + Seq( + SortOrder( + Alias( + UnresolvedFunction(Seq("COUNT"), Seq(countryField), isDistinct = false), + "count_country")(), + Descending)), + global = true, + aggregatePlan) + + val planWithLimit = + GlobalLimit(Literal(3), LocalLimit(Literal(3), sortedPlan)) + val expectedPlan = Project(Seq(UnresolvedStar(None)), planWithLimit) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + }