Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support more PPL builtin functions by adding a name mapping #504

Merged
merged 2 commits into from
Aug 2, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,28 @@ trait FlintSparkSuite extends QueryTest with FlintSuite with OpenSearchSuite wit
| """.stripMargin)
}

protected def createNullableStateCountryTable(testTable: String): Unit = {
sql(s"""
| CREATE TABLE $testTable
| (
| name STRING,
| age INT,
| state STRING,
| country STRING
| )
| USING $tableType $tableOptions
|""".stripMargin)

sql(s"""
| INSERT INTO $testTable
| VALUES ('Jake', 70, 'California', 'USA'),
| ('Hello', 30, 'New York', 'USA'),
| ('John', 25, 'Ontario', 'Canada'),
| ('Jane', 20, 'Quebec', 'Canada'),
| (null, 10, null, 'Canada')
| """.stripMargin)
}

protected def createOccupationTable(testTable: String): Unit = {
sql(s"""
| CREATE TABLE $testTable
Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -5,25 +5,57 @@

package org.opensearch.sql.ppl.utils;

import com.google.common.collect.ImmutableMap;
import org.apache.spark.sql.catalyst.analysis.UnresolvedFunction;
import org.apache.spark.sql.catalyst.expressions.Expression;
import org.opensearch.sql.expression.function.BuiltinFunctionName;

import java.util.List;
import java.util.Locale;
import java.util.Map;

import static org.opensearch.sql.ppl.utils.DataTypeTransformer.seq;
import static scala.Option.empty;

public interface BuiltinFunctionTranslator {

/**
* The name mapping between PPL builtin functions to Spark builtin functions.
*/
static final Map<String, String> SPARK_BUILTIN_FUNCTION_NAME_MAPPING = new ImmutableMap.Builder<String, String>()
// arithmetic operators
.put(BuiltinFunctionName.ADD.name().toLowerCase(Locale.ROOT), "+")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

plz add static import for BuiltinFunctionName.* (shortens the map declaration)

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

.put(BuiltinFunctionName.SUBTRACT.name().toLowerCase(Locale.ROOT), "-")
.put(BuiltinFunctionName.MULTIPLY.name().toLowerCase(Locale.ROOT), "*")
.put(BuiltinFunctionName.DIVIDE.name().toLowerCase(Locale.ROOT), "/")
.put(BuiltinFunctionName.MODULUS.name().toLowerCase(Locale.ROOT), "%")
// time functions
.put(BuiltinFunctionName.DAY_OF_WEEK.name().toLowerCase(Locale.ROOT), "dayofweek")
.put(BuiltinFunctionName.DAY_OF_MONTH.name().toLowerCase(Locale.ROOT), "dayofmonth")
.put(BuiltinFunctionName.DAY_OF_YEAR.name().toLowerCase(Locale.ROOT), "dayofyear")
.put(BuiltinFunctionName.WEEK_OF_YEAR.name().toLowerCase(Locale.ROOT), "weekofyear")
.put(BuiltinFunctionName.WEEK.name().toLowerCase(Locale.ROOT), "weekofyear")
.put(BuiltinFunctionName.MONTH_OF_YEAR.name().toLowerCase(Locale.ROOT), "month")
.put(BuiltinFunctionName.HOUR_OF_DAY.name().toLowerCase(Locale.ROOT), "hour")
.put(BuiltinFunctionName.MINUTE_OF_HOUR.name().toLowerCase(Locale.ROOT), "minute")
.put(BuiltinFunctionName.SECOND_OF_MINUTE.name().toLowerCase(Locale.ROOT), "second")
.put(BuiltinFunctionName.SUBDATE.name().toLowerCase(Locale.ROOT), "date_sub") // only maps subdate(date, days)
.put(BuiltinFunctionName.ADDDATE.name().toLowerCase(Locale.ROOT), "date_add") // only maps adddate(date, days)
.put(BuiltinFunctionName.DATEDIFF.name().toLowerCase(Locale.ROOT), "datediff")
.put(BuiltinFunctionName.LOCALTIME.name().toLowerCase(Locale.ROOT), "localtimestamp")
// condition functions
.put(BuiltinFunctionName.IS_NULL.name().toLowerCase(Locale.ROOT), "isnull")
.put(BuiltinFunctionName.IS_NOT_NULL.name().toLowerCase(Locale.ROOT), "isnotnull")
.build();

static Expression builtinFunction(org.opensearch.sql.ast.expression.Function function, List<Expression> args) {
if (BuiltinFunctionName.of(function.getFuncName()).isEmpty()) {
// TODO change it when UDF is supported
// TODO should we support more functions which are not PPL builtin functions. E.g Spark builtin functions
throw new UnsupportedOperationException(function.getFuncName() + " is not a builtin function of PPL");
} else {
String name = BuiltinFunctionName.of(function.getFuncName()).get().name().toLowerCase(Locale.ROOT);
name = SPARK_BUILTIN_FUNCTION_NAME_MAPPING.getOrDefault(name, name);
return new UnresolvedFunction(seq(name), seq(args), false, empty(),false);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,25 +5,26 @@

package org.opensearch.flint.spark.ppl

import org.junit.Assert.assertEquals
import org.opensearch.flint.spark.ppl.PlaneUtils.plan
import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor}
import org.scalatest.matchers.should.Matchers

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Divide, EqualTo, Floor, GreaterThanOrEqual, Literal, Multiply, SortOrder, TimeWindow}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._

class PPLLogicalPlanAggregationQueriesTranslatorTestSuite
extends SparkFunSuite
with PlanTest
with LogicalPlanTestUtils
with Matchers {

private val planTransformer = new CatalystQueryPlanVisitor()
private val pplParser = new PPLSyntaxParser()

test("test average price ") {
test("test average price") {
// if successful build ppl logical plan and translate to catalyst logical plan
val context = new CatalystPlanContext
val logPlan =
Expand All @@ -38,10 +39,10 @@ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite
val aggregatePlan = Aggregate(Seq(), aggregateExpressions, tableRelation)
val expectedPlan = Project(star, aggregatePlan)

assertEquals(compareByString(expectedPlan), compareByString(logPlan))
comparePlans(expectedPlan, logPlan, false)
}

ignore("test average price with Alias") {
test("test average price with Alias") {
// if successful build ppl logical plan and translate to catalyst logical plan
val context = new CatalystPlanContext
val logPlan = planTransformer.visit(
Expand All @@ -57,7 +58,7 @@ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite
val aggregatePlan = Aggregate(Seq(), aggregateExpressions, tableRelation)
val expectedPlan = Project(star, aggregatePlan)

assertEquals(compareByString(expectedPlan), compareByString(logPlan))
comparePlans(expectedPlan, logPlan, false)
}

test("test average price group by product ") {
Expand All @@ -81,7 +82,7 @@ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite
Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), tableRelation)
val expectedPlan = Project(star, aggregatePlan)

assertEquals(compareByString(expectedPlan), compareByString(logPlan))
comparePlans(expectedPlan, logPlan, false)
}

test("test average price group by product and filter") {
Expand Down Expand Up @@ -109,7 +110,7 @@ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite
Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), filterPlan)
val expectedPlan = Project(star, aggregatePlan)

assertEquals(compareByString(expectedPlan), compareByString(logPlan))
comparePlans(expectedPlan, logPlan, false)
}

test("test average price group by product and filter sorted") {
Expand Down Expand Up @@ -144,7 +145,7 @@ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite
global = true,
aggregatePlan)
val expectedPlan = Project(star, sortedPlan)
assertEquals(compareByString(expectedPlan), compareByString(logPlan))
comparePlans(expectedPlan, logPlan, false)
}
test("create ppl simple avg age by span of interval of 10 years query test ") {
val context = new CatalystPlanContext
Expand All @@ -164,7 +165,7 @@ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite
val aggregatePlan = Aggregate(Seq(span), Seq(aggregateExpressions, span), tableRelation)
val expectedPlan = Project(star, aggregatePlan)

assert(compareByString(expectedPlan) === compareByString(logPlan))
comparePlans(expectedPlan, logPlan, false)
}

test("create ppl simple avg age by span of interval of 10 years query with sort test ") {
Expand All @@ -190,7 +191,7 @@ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite
Sort(Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), global = true, aggregatePlan)
val expectedPlan = Project(star, sortedPlan)

assert(compareByString(expectedPlan) === compareByString(logPlan))
comparePlans(expectedPlan, logPlan, false)
}

test("create ppl simple avg age by span of interval of 10 years by country query test ") {
Expand Down Expand Up @@ -219,7 +220,7 @@ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite
tableRelation)
val expectedPlan = Project(star, aggregatePlan)

assert(compareByString(expectedPlan) === compareByString(logPlan))
comparePlans(expectedPlan, logPlan, false)
}
test("create ppl query count sales by weeks window and productId with sorting test") {
val context = new CatalystPlanContext
Expand Down Expand Up @@ -257,7 +258,7 @@ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite

val expectedPlan = Project(star, sortedPlan)
// Compare the two plans
assert(compareByString(expectedPlan) === compareByString(logPlan))
comparePlans(expectedPlan, logPlan, false)
}

test("create ppl query count sales by days window and productId with sorting test") {
Expand Down Expand Up @@ -296,7 +297,7 @@ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite
aggregatePlan)
val expectedPlan = Project(star, sortedPlan)
// Compare the two plans
assert(compareByString(expectedPlan) === compareByString(logPlan))
comparePlans(expectedPlan, logPlan, false)
}
test("create ppl query count status amount by day window and group by status test") {
val context = new CatalystPlanContext
Expand Down Expand Up @@ -331,7 +332,7 @@ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite
val planWithLimit = GlobalLimit(Literal(100), LocalLimit(Literal(100), aggregatePlan))
val expectedPlan = Project(star, planWithLimit)
// Compare the two plans
assert(compareByString(expectedPlan) === compareByString(logPlan))
comparePlans(expectedPlan, logPlan, false)
}
test(
"create ppl query count only error (status >= 400) status amount by day window and group by status test") {
Expand Down Expand Up @@ -368,7 +369,7 @@ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite
val planWithLimit = GlobalLimit(Literal(100), LocalLimit(Literal(100), aggregatePlan))
val expectedPlan = Project(star, planWithLimit)
// Compare the two plans
assert(compareByString(expectedPlan) === compareByString(logPlan))
comparePlans(expectedPlan, logPlan, false)
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,19 @@

package org.opensearch.flint.spark.ppl

import org.junit.Assert.assertEquals
import org.opensearch.flint.spark.ppl.PlaneUtils.plan
import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor}
import org.scalatest.matchers.should.Matchers

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation, UnresolvedStar}
import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference, Descending, Literal, NamedExpression, SortOrder}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.types.IntegerType

class PPLLogicalPlanBasicQueriesTranslatorTestSuite
extends SparkFunSuite
with PlanTest
with LogicalPlanTestUtils
with Matchers {

Expand All @@ -31,7 +31,7 @@ class PPLLogicalPlanBasicQueriesTranslatorTestSuite

val projectList: Seq[NamedExpression] = Seq(UnresolvedStar(None))
val expectedPlan = Project(projectList, UnresolvedRelation(Seq("table")))
assertEquals(expectedPlan, logPlan)
comparePlans(expectedPlan, logPlan, false)
}

test("test simple search with escaped table name") {
Expand All @@ -41,7 +41,7 @@ class PPLLogicalPlanBasicQueriesTranslatorTestSuite

val projectList: Seq[NamedExpression] = Seq(UnresolvedStar(None))
val expectedPlan = Project(projectList, UnresolvedRelation(Seq("table")))
assertEquals(expectedPlan, logPlan)
comparePlans(expectedPlan, logPlan, false)
}

test("test simple search with schema.table and no explicit fields (defaults to all fields)") {
Expand All @@ -51,7 +51,7 @@ class PPLLogicalPlanBasicQueriesTranslatorTestSuite

val projectList: Seq[NamedExpression] = Seq(UnresolvedStar(None))
val expectedPlan = Project(projectList, UnresolvedRelation(Seq("schema", "table")))
assertEquals(expectedPlan, logPlan)
comparePlans(expectedPlan, logPlan, false)

}

Expand All @@ -62,7 +62,7 @@ class PPLLogicalPlanBasicQueriesTranslatorTestSuite

val projectList: Seq[NamedExpression] = Seq(UnresolvedAttribute("A"))
val expectedPlan = Project(projectList, UnresolvedRelation(Seq("schema", "table")))
assertEquals(expectedPlan, logPlan)
comparePlans(expectedPlan, logPlan, false)
}

test("test simple search with only one table with one field projected") {
Expand All @@ -72,7 +72,7 @@ class PPLLogicalPlanBasicQueriesTranslatorTestSuite

val projectList: Seq[NamedExpression] = Seq(UnresolvedAttribute("A"))
val expectedPlan = Project(projectList, UnresolvedRelation(Seq("table")))
assertEquals(expectedPlan, logPlan)
comparePlans(expectedPlan, logPlan, false)
}

test("test simple search with only one table with two fields projected") {
Expand All @@ -82,7 +82,7 @@ class PPLLogicalPlanBasicQueriesTranslatorTestSuite
val table = UnresolvedRelation(Seq("t"))
val projectList = Seq(UnresolvedAttribute("A"), UnresolvedAttribute("B"))
val expectedPlan = Project(projectList, table)
assertEquals(expectedPlan, logPlan)
comparePlans(expectedPlan, logPlan, false)
}

test("test simple search with one table with two fields projected sorted by one field") {
Expand All @@ -97,7 +97,7 @@ class PPLLogicalPlanBasicQueriesTranslatorTestSuite
val sorted = Sort(sortOrder, true, table)
val expectedPlan = Project(projectList, sorted)

assert(compareByString(expectedPlan) === compareByString(logPlan))
comparePlans(expectedPlan, logPlan, false)
}

test(
Expand All @@ -111,7 +111,7 @@ class PPLLogicalPlanBasicQueriesTranslatorTestSuite
val planWithLimit =
GlobalLimit(Literal(5), LocalLimit(Literal(5), Project(projectList, table)))
val expectedPlan = Project(Seq(UnresolvedStar(None)), planWithLimit)
assertEquals(expectedPlan, logPlan)
comparePlans(expectedPlan, logPlan, false)
}

test(
Expand All @@ -129,8 +129,7 @@ class PPLLogicalPlanBasicQueriesTranslatorTestSuite

val planWithLimit = GlobalLimit(Literal(5), LocalLimit(Literal(5), projectAB))
val expectedPlan = Project(Seq(UnresolvedStar(None)), planWithLimit)

assertEquals(compareByString(expectedPlan), compareByString(logPlan))
comparePlans(expectedPlan, logPlan, false)
}

test(
Expand All @@ -152,7 +151,7 @@ class PPLLogicalPlanBasicQueriesTranslatorTestSuite
val expectedPlan =
Union(Seq(projectedTable1, projectedTable2), byName = true, allowMissingCol = true)

assertEquals(expectedPlan, logPlan)
comparePlans(expectedPlan, logPlan, false)
}

test("Search multiple tables - translated into union call with fields") {
Expand All @@ -172,6 +171,6 @@ class PPLLogicalPlanBasicQueriesTranslatorTestSuite
val expectedPlan =
Union(Seq(projectedTable1, projectedTable2), byName = true, allowMissingCol = true)

assertEquals(expectedPlan, logPlan)
comparePlans(expectedPlan, logPlan, false)
}
}
Loading
Loading