Skip to content

Commit

Permalink
fix failing test and add more tests
Browse files Browse the repository at this point in the history
Signed-off-by: Kacper Trochimiak <[email protected]>
  • Loading branch information
kt-eliatra committed Sep 27, 2024
1 parent 85d8814 commit 2bc1d8d
Show file tree
Hide file tree
Showing 2 changed files with 162 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@ package org.opensearch.flint.spark.ppl

import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq

import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Descending, LessThan, Literal, SortOrder}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LogicalPlan, Project, Sort}
import org.apache.spark.sql.catalyst.expressions.Alias
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project}
import org.apache.spark.sql.streaming.StreamTest

class FlintSparkPPLRenameITSuite
Expand Down Expand Up @@ -97,25 +97,56 @@ class FlintSparkPPLRenameITSuite

test("test renamed fields without fields command") {
val frame = sql(s"""
| source = $testTable | rename state as _state, country as _country
| source = $testTable | rename age as user_age, country as user_country
| """.stripMargin)

val results: Array[Row] = frame.collect()
val expectedResults: Array[Row] = Array(
Row("Jake", 70, "California", "USA", 2023, 4, "California", "USA"),
Row("Hello", 30, "New York", "USA", 2023, 4, "New York", "USA"),
Row("John", 25, "Ontario", "Canada", 2023, 4, "Ontario", "Canada"),
Row("Jane", 20, "Quebec", "Canada", 2023, 4, "Quebec", "Canada"))
Row("Jake", 70, "California", "USA", 2023, 4, 70, "USA"),
Row("Hello", 30, "New York", "USA", 2023, 4, 30, "USA"),
Row("John", 25, "Ontario", "Canada", 2023, 4, 25, "Canada"),
Row("Jane", 20, "Quebec", "Canada", 2023, 4, 20, "Canada"))
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0))
assert(results.sorted.sameElements(expectedResults.sorted))

val logicalPlan: LogicalPlan = frame.queryExecution.logical
val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))
val renameProjectList = Seq(
UnresolvedStar(None),
Alias(UnresolvedAttribute("state"), "_state")(),
Alias(UnresolvedAttribute("country"), "_country")())
Alias(UnresolvedAttribute("age"), "user_age")(),
Alias(UnresolvedAttribute("country"), "user_country")())
val expectedPlan = Project(seq(UnresolvedStar(None)), Project(renameProjectList, table))
comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
}

test("test renamed field used in aggregation") {
val frame = sql(s"""
| source = $testTable | rename age as user_age | stats avg(user_age) by country
| """.stripMargin)

val results: Array[Row] = frame.collect()
val expectedResults: Array[Row] = Array(Row(22.5, "Canada"), Row(50.0, "USA"))
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Double](_.getAs[Double](0))
assert(results.sorted.sameElements(expectedResults.sorted))

val logicalPlan: LogicalPlan = frame.queryExecution.logical
val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))
val renameProjectList =
Seq(UnresolvedStar(None), Alias(UnresolvedAttribute("age"), "user_age")())
val aggregateExpressions =
Seq(
Alias(
UnresolvedFunction(
Seq("AVG"),
Seq(UnresolvedAttribute("user_age")),
isDistinct = false),
"avg(user_age)")(),
Alias(UnresolvedAttribute("country"), "country")())
val aggregatePlan = Aggregate(
Seq(Alias(UnresolvedAttribute("country"), "country")()),
aggregateExpressions,
Project(renameProjectList, table))
val expectedPlan = Project(seq(UnresolvedStar(None)), aggregatePlan)
comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.flint.spark.ppl

import org.opensearch.flint.spark.ppl.PlaneUtils.plan
import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor}
import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq
import org.scalatest.matchers.should.Matchers

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
import org.apache.spark.sql.catalyst.expressions.{Alias, Descending, ExprId, NamedExpression, SortOrder}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical.{Project, Sort}

class PPLLogicalPlanRenameTranslatorTestSuite
extends SparkFunSuite
with PlanTest
with LogicalPlanTestUtils
with Matchers {

private val planTransformer = new CatalystQueryPlanVisitor()
private val pplParser = new PPLSyntaxParser()

test("test renamed fields not included in fields expressions") {
val context = new CatalystPlanContext
val logPlan =
planTransformer.visit(
plan(pplParser, "source=t | rename a as r_a, b as r_b | fields c", false),
context)
val renameProjectList: Seq[NamedExpression] =
Seq(
UnresolvedStar(None),
Alias(UnresolvedAttribute("a"), "r_a")(),
Alias(UnresolvedAttribute("b"), "r_b")())
val expectedPlan = Project(
seq(UnresolvedAttribute("c")),
Project(renameProjectList, UnresolvedRelation(Seq("t"))))
comparePlans(expectedPlan, logPlan, checkAnalysis = false)
}

test("test renamed fields included in fields expression") {
val context = new CatalystPlanContext
val logPlan =
planTransformer.visit(
plan(pplParser, "source=t | rename a as r_a, b as r_b | fields r_a, r_b, c", false),
context)

val renameProjectList: Seq[NamedExpression] =
Seq(
UnresolvedStar(None),
Alias(UnresolvedAttribute("a"), "r_a")(),
Alias(UnresolvedAttribute("b"), "r_b")())
val expectedPlan = Project(
seq(UnresolvedAttribute("r_a"), UnresolvedAttribute("r_b"), UnresolvedAttribute("c")),
Project(renameProjectList, UnresolvedRelation(Seq("t"))))
comparePlans(expectedPlan, logPlan, checkAnalysis = false)
}

test("test renamed fields without fields command") {
val context = new CatalystPlanContext
val logPlan =
planTransformer.visit(
plan(pplParser, "source=t | rename a as r_a, b as r_b", false),
context)

val renameProjectList: Seq[NamedExpression] =
Seq(
UnresolvedStar(None),
Alias(UnresolvedAttribute("a"), "r_a")(),
Alias(UnresolvedAttribute("b"), "r_b")())
val expectedPlan =
Project(seq(UnresolvedStar(None)), Project(renameProjectList, UnresolvedRelation(Seq("t"))))
comparePlans(expectedPlan, logPlan, checkAnalysis = false)
}

test("test renamed fields with sort") {
val context = new CatalystPlanContext
val logPlan =
planTransformer.visit(
plan(pplParser, "source=t | rename a as r_a, b as r_b | sort - r_a | fields r_b", false),
context)

val renameProjectList: Seq[NamedExpression] =
Seq(
UnresolvedStar(None),
Alias(UnresolvedAttribute("a"), "r_a")(),
Alias(UnresolvedAttribute("b"), "r_b")())
val renameProject = Project(renameProjectList, UnresolvedRelation(Seq("t")))
val sortOrder = SortOrder(UnresolvedAttribute("r_a"), Descending, Seq.empty)
val sort = Sort(seq(sortOrder), global = true, renameProject)
val expectedPlan = Project(seq(UnresolvedAttribute("r_b")), sort)
comparePlans(expectedPlan, logPlan, checkAnalysis = false)
}

test("test rename eval expression output") {
val context = new CatalystPlanContext
val logPlan =
planTransformer.visit(
plan(
pplParser,
"source=t | eval a = RAND() | rename a as eval_rand | fields eval_rand",
false),
context)

val evalProjectList: Seq[NamedExpression] = Seq(
UnresolvedStar(None),
Alias(UnresolvedFunction("rand", Seq.empty, isDistinct = false), "a")(
exprId = ExprId(0),
qualifier = Seq.empty))
val evalProject = Project(evalProjectList, UnresolvedRelation(Seq("t")))
val renameProjectList: Seq[NamedExpression] =
Seq(UnresolvedStar(None), Alias(UnresolvedAttribute("a"), "eval_rand")())
val expectedPlan =
Project(seq(UnresolvedAttribute("eval_rand")), Project(renameProjectList, evalProject))
comparePlans(expectedPlan, logPlan, checkAnalysis = false)
}
}

0 comments on commit 2bc1d8d

Please sign in to comment.