Skip to content

Commit

Permalink
Support more builtin functions by adding a name mapping
Browse files Browse the repository at this point in the history
Signed-off-by: Lantao Jin <[email protected]>
  • Loading branch information
LantaoJin committed Aug 1, 2024
1 parent 24d3b81 commit 8843a87
Show file tree
Hide file tree
Showing 9 changed files with 517 additions and 185 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,28 @@ trait FlintSparkSuite extends QueryTest with FlintSuite with OpenSearchSuite wit
| """.stripMargin)
}

/**
 * Creates `testTable` with `name`, `age`, `state` and `country` columns and loads five rows
 * into it. The last row carries a null `name` and a null `state`, so the table can be used
 * by tests that exercise null handling.
 *
 * @param testTable
 *   name of the table to create and populate
 */
protected def createNullableStateCountryTable(testTable: String): Unit = {
  // The table has to exist before rows can be inserted, so the DDL is issued first.
  val createTableStatement =
    s"""
       | CREATE TABLE $testTable
       | (
       | name STRING,
       | age INT,
       | state STRING,
       | country STRING
       | )
       | USING $tableType $tableOptions
       |""".stripMargin
  sql(createTableStatement)

  // Four fully populated rows plus one row whose name and state are null.
  val insertRowsStatement =
    s"""
       | INSERT INTO $testTable
       | VALUES ('Jake', 70, 'California', 'USA'),
       | ('Hello', 30, 'New York', 'USA'),
       | ('John', 25, 'Ontario', 'Canada'),
       | ('Jane', 20, 'Quebec', 'Canada'),
       | (null, 10, null, 'Canada')
       | """.stripMargin
  sql(insertRowsStatement)
}

protected def createOccupationTable(testTable: String): Unit = {
sql(s"""
| CREATE TABLE $testTable
Expand Down

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
Expand Up @@ -5,25 +5,57 @@

package org.opensearch.sql.ppl.utils;

import com.google.common.collect.ImmutableMap;
import org.apache.spark.sql.catalyst.analysis.UnresolvedFunction;
import org.apache.spark.sql.catalyst.expressions.Expression;
import org.opensearch.sql.expression.function.BuiltinFunctionName;

import java.util.List;
import java.util.Locale;
import java.util.Map;

import static org.opensearch.sql.ppl.utils.DataTypeTransformer.seq;
import static scala.Option.empty;

public interface BuiltinFunctionTranslator {

/**
 * Lookup table from lower-cased PPL builtin function names to the Spark builtin function name
 * (or operator symbol) used in their place. A PPL name without an entry here is not renamed:
 * {@code builtinFunction} falls back to the PPL name itself.
 */
Map<String, String> SPARK_BUILTIN_FUNCTION_NAME_MAPPING = ImmutableMap.<String, String>builder()
    // Arithmetic: PPL spells these as named functions, the mapping targets the operator symbols.
    .put(BuiltinFunctionName.ADD.name().toLowerCase(Locale.ROOT), "+")
    .put(BuiltinFunctionName.SUBTRACT.name().toLowerCase(Locale.ROOT), "-")
    .put(BuiltinFunctionName.MULTIPLY.name().toLowerCase(Locale.ROOT), "*")
    .put(BuiltinFunctionName.DIVIDE.name().toLowerCase(Locale.ROOT), "/")
    .put(BuiltinFunctionName.MODULUS.name().toLowerCase(Locale.ROOT), "%")
    // Date and time functions.
    .put(BuiltinFunctionName.DAY_OF_WEEK.name().toLowerCase(Locale.ROOT), "dayofweek")
    .put(BuiltinFunctionName.DAY_OF_MONTH.name().toLowerCase(Locale.ROOT), "dayofmonth")
    .put(BuiltinFunctionName.DAY_OF_YEAR.name().toLowerCase(Locale.ROOT), "dayofyear")
    .put(BuiltinFunctionName.WEEK_OF_YEAR.name().toLowerCase(Locale.ROOT), "weekofyear")
    .put(BuiltinFunctionName.WEEK.name().toLowerCase(Locale.ROOT), "weekofyear")
    .put(BuiltinFunctionName.MONTH_OF_YEAR.name().toLowerCase(Locale.ROOT), "month")
    .put(BuiltinFunctionName.HOUR_OF_DAY.name().toLowerCase(Locale.ROOT), "hour")
    .put(BuiltinFunctionName.MINUTE_OF_HOUR.name().toLowerCase(Locale.ROOT), "minute")
    .put(BuiltinFunctionName.SECOND_OF_MINUTE.name().toLowerCase(Locale.ROOT), "second")
    // Only the two-argument (date, days) forms of subdate/adddate are covered by these entries.
    .put(BuiltinFunctionName.SUBDATE.name().toLowerCase(Locale.ROOT), "date_sub")
    .put(BuiltinFunctionName.ADDDATE.name().toLowerCase(Locale.ROOT), "date_add")
    .put(BuiltinFunctionName.DATEDIFF.name().toLowerCase(Locale.ROOT), "datediff")
    .put(BuiltinFunctionName.LOCALTIME.name().toLowerCase(Locale.ROOT), "localtimestamp")
    // Null-check conditions.
    .put(BuiltinFunctionName.IS_NULL.name().toLowerCase(Locale.ROOT), "isnull")
    .put(BuiltinFunctionName.IS_NOT_NULL.name().toLowerCase(Locale.ROOT), "isnotnull")
    .build();

/**
 * Translates a PPL function call into an unresolved Spark function expression.
 *
 * The PPL function name must be known to {@link BuiltinFunctionName}; its lower-cased enum name
 * is then renamed through {@code SPARK_BUILTIN_FUNCTION_NAME_MAPPING} when an entry exists and
 * used as-is otherwise.
 *
 * @param function the PPL function node whose name is translated
 * @param args     the already-translated argument expressions
 * @return an {@link UnresolvedFunction} carrying the Spark-side name and {@code args}
 * @throws UnsupportedOperationException if the name is not a PPL builtin function
 */
static Expression builtinFunction(org.opensearch.sql.ast.expression.Function function, List<Expression> args) {
    final String pplName = function.getFuncName();
    if (BuiltinFunctionName.of(pplName).isEmpty()) {
        // TODO revisit once UDFs are supported.
        // TODO decide whether non-PPL functions (e.g. Spark builtin functions) should be accepted too.
        throw new UnsupportedOperationException(pplName + " is not a builtin function of PPL");
    }

    final String canonicalName = BuiltinFunctionName.of(pplName).get().name().toLowerCase(Locale.ROOT);
    // Fall back to the PPL name when there is no Spark-specific spelling for it.
    final String sparkName = SPARK_BUILTIN_FUNCTION_NAME_MAPPING.getOrDefault(canonicalName, canonicalName);
    // NOTE(review): the trailing false / empty() / false are presumably isDistinct, filter and
    // ignoreNulls - confirm against the UnresolvedFunction constructor of the Spark version in use.
    return new UnresolvedFunction(seq(sparkName), seq(args), false, empty(), false);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,25 +5,26 @@

package org.opensearch.flint.spark.ppl

import org.junit.Assert.assertEquals
import org.opensearch.flint.spark.ppl.PlaneUtils.plan
import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor}
import org.scalatest.matchers.should.Matchers

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Divide, EqualTo, Floor, GreaterThanOrEqual, Literal, Multiply, SortOrder, TimeWindow}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._

class PPLLogicalPlanAggregationQueriesTranslatorTestSuite
extends SparkFunSuite
with PlanTest
with LogicalPlanTestUtils
with Matchers {

private val planTransformer = new CatalystQueryPlanVisitor()
private val pplParser = new PPLSyntaxParser()

test("test average price ") {
test("test average price") {
// if successful build ppl logical plan and translate to catalyst logical plan
val context = new CatalystPlanContext
val logPlan =
Expand All @@ -38,10 +39,10 @@ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite
val aggregatePlan = Aggregate(Seq(), aggregateExpressions, tableRelation)
val expectedPlan = Project(star, aggregatePlan)

assertEquals(compareByString(expectedPlan), compareByString(logPlan))
comparePlans(expectedPlan, logPlan, false)
}

ignore("test average price with Alias") {
test("test average price with Alias") {
// if successful build ppl logical plan and translate to catalyst logical plan
val context = new CatalystPlanContext
val logPlan = planTransformer.visit(
Expand All @@ -57,7 +58,7 @@ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite
val aggregatePlan = Aggregate(Seq(), aggregateExpressions, tableRelation)
val expectedPlan = Project(star, aggregatePlan)

assertEquals(compareByString(expectedPlan), compareByString(logPlan))
comparePlans(expectedPlan, logPlan, false)
}

test("test average price group by product ") {
Expand All @@ -81,7 +82,7 @@ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite
Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), tableRelation)
val expectedPlan = Project(star, aggregatePlan)

assertEquals(compareByString(expectedPlan), compareByString(logPlan))
comparePlans(expectedPlan, logPlan, false)
}

test("test average price group by product and filter") {
Expand Down Expand Up @@ -109,7 +110,7 @@ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite
Aggregate(groupByAttributes, Seq(aggregateExpressions, productAlias), filterPlan)
val expectedPlan = Project(star, aggregatePlan)

assertEquals(compareByString(expectedPlan), compareByString(logPlan))
comparePlans(expectedPlan, logPlan, false)
}

test("test average price group by product and filter sorted") {
Expand Down Expand Up @@ -144,7 +145,7 @@ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite
global = true,
aggregatePlan)
val expectedPlan = Project(star, sortedPlan)
assertEquals(compareByString(expectedPlan), compareByString(logPlan))
comparePlans(expectedPlan, logPlan, false)
}
test("create ppl simple avg age by span of interval of 10 years query test ") {
val context = new CatalystPlanContext
Expand All @@ -164,7 +165,7 @@ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite
val aggregatePlan = Aggregate(Seq(span), Seq(aggregateExpressions, span), tableRelation)
val expectedPlan = Project(star, aggregatePlan)

assert(compareByString(expectedPlan) === compareByString(logPlan))
comparePlans(expectedPlan, logPlan, false)
}

test("create ppl simple avg age by span of interval of 10 years query with sort test ") {
Expand All @@ -190,7 +191,7 @@ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite
Sort(Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), global = true, aggregatePlan)
val expectedPlan = Project(star, sortedPlan)

assert(compareByString(expectedPlan) === compareByString(logPlan))
comparePlans(expectedPlan, logPlan, false)
}

test("create ppl simple avg age by span of interval of 10 years by country query test ") {
Expand Down Expand Up @@ -219,7 +220,7 @@ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite
tableRelation)
val expectedPlan = Project(star, aggregatePlan)

assert(compareByString(expectedPlan) === compareByString(logPlan))
comparePlans(expectedPlan, logPlan, false)
}
test("create ppl query count sales by weeks window and productId with sorting test") {
val context = new CatalystPlanContext
Expand Down Expand Up @@ -257,7 +258,7 @@ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite

val expectedPlan = Project(star, sortedPlan)
// Compare the two plans
assert(compareByString(expectedPlan) === compareByString(logPlan))
comparePlans(expectedPlan, logPlan, false)
}

test("create ppl query count sales by days window and productId with sorting test") {
Expand Down Expand Up @@ -296,7 +297,7 @@ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite
aggregatePlan)
val expectedPlan = Project(star, sortedPlan)
// Compare the two plans
assert(compareByString(expectedPlan) === compareByString(logPlan))
comparePlans(expectedPlan, logPlan, false)
}
test("create ppl query count status amount by day window and group by status test") {
val context = new CatalystPlanContext
Expand Down Expand Up @@ -331,7 +332,7 @@ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite
val planWithLimit = GlobalLimit(Literal(100), LocalLimit(Literal(100), aggregatePlan))
val expectedPlan = Project(star, planWithLimit)
// Compare the two plans
assert(compareByString(expectedPlan) === compareByString(logPlan))
comparePlans(expectedPlan, logPlan, false)
}
test(
"create ppl query count only error (status >= 400) status amount by day window and group by status test") {
Expand Down Expand Up @@ -368,7 +369,7 @@ class PPLLogicalPlanAggregationQueriesTranslatorTestSuite
val planWithLimit = GlobalLimit(Literal(100), LocalLimit(Literal(100), aggregatePlan))
val expectedPlan = Project(star, planWithLimit)
// Compare the two plans
assert(compareByString(expectedPlan) === compareByString(logPlan))
comparePlans(expectedPlan, logPlan, false)
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,19 @@

package org.opensearch.flint.spark.ppl

import org.junit.Assert.assertEquals
import org.opensearch.flint.spark.ppl.PlaneUtils.plan
import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor}
import org.scalatest.matchers.should.Matchers

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation, UnresolvedStar}
import org.apache.spark.sql.catalyst.expressions.{Ascending, AttributeReference, Descending, Literal, NamedExpression, SortOrder}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.types.IntegerType

class PPLLogicalPlanBasicQueriesTranslatorTestSuite
extends SparkFunSuite
with PlanTest
with LogicalPlanTestUtils
with Matchers {

Expand All @@ -31,7 +31,7 @@ class PPLLogicalPlanBasicQueriesTranslatorTestSuite

val projectList: Seq[NamedExpression] = Seq(UnresolvedStar(None))
val expectedPlan = Project(projectList, UnresolvedRelation(Seq("table")))
assertEquals(expectedPlan, logPlan)
comparePlans(expectedPlan, logPlan, false)
}

test("test simple search with escaped table name") {
Expand All @@ -41,7 +41,7 @@ class PPLLogicalPlanBasicQueriesTranslatorTestSuite

val projectList: Seq[NamedExpression] = Seq(UnresolvedStar(None))
val expectedPlan = Project(projectList, UnresolvedRelation(Seq("table")))
assertEquals(expectedPlan, logPlan)
comparePlans(expectedPlan, logPlan, false)
}

test("test simple search with schema.table and no explicit fields (defaults to all fields)") {
Expand All @@ -51,7 +51,7 @@ class PPLLogicalPlanBasicQueriesTranslatorTestSuite

val projectList: Seq[NamedExpression] = Seq(UnresolvedStar(None))
val expectedPlan = Project(projectList, UnresolvedRelation(Seq("schema", "table")))
assertEquals(expectedPlan, logPlan)
comparePlans(expectedPlan, logPlan, false)

}

Expand All @@ -62,7 +62,7 @@ class PPLLogicalPlanBasicQueriesTranslatorTestSuite

val projectList: Seq[NamedExpression] = Seq(UnresolvedAttribute("A"))
val expectedPlan = Project(projectList, UnresolvedRelation(Seq("schema", "table")))
assertEquals(expectedPlan, logPlan)
comparePlans(expectedPlan, logPlan, false)
}

test("test simple search with only one table with one field projected") {
Expand All @@ -72,7 +72,7 @@ class PPLLogicalPlanBasicQueriesTranslatorTestSuite

val projectList: Seq[NamedExpression] = Seq(UnresolvedAttribute("A"))
val expectedPlan = Project(projectList, UnresolvedRelation(Seq("table")))
assertEquals(expectedPlan, logPlan)
comparePlans(expectedPlan, logPlan, false)
}

test("test simple search with only one table with two fields projected") {
Expand All @@ -82,7 +82,7 @@ class PPLLogicalPlanBasicQueriesTranslatorTestSuite
val table = UnresolvedRelation(Seq("t"))
val projectList = Seq(UnresolvedAttribute("A"), UnresolvedAttribute("B"))
val expectedPlan = Project(projectList, table)
assertEquals(expectedPlan, logPlan)
comparePlans(expectedPlan, logPlan, false)
}

test("test simple search with one table with two fields projected sorted by one field") {
Expand All @@ -97,7 +97,7 @@ class PPLLogicalPlanBasicQueriesTranslatorTestSuite
val sorted = Sort(sortOrder, true, table)
val expectedPlan = Project(projectList, sorted)

assert(compareByString(expectedPlan) === compareByString(logPlan))
comparePlans(expectedPlan, logPlan, false)
}

test(
Expand All @@ -111,7 +111,7 @@ class PPLLogicalPlanBasicQueriesTranslatorTestSuite
val planWithLimit =
GlobalLimit(Literal(5), LocalLimit(Literal(5), Project(projectList, table)))
val expectedPlan = Project(Seq(UnresolvedStar(None)), planWithLimit)
assertEquals(expectedPlan, logPlan)
comparePlans(expectedPlan, logPlan, false)
}

test(
Expand All @@ -129,8 +129,7 @@ class PPLLogicalPlanBasicQueriesTranslatorTestSuite

val planWithLimit = GlobalLimit(Literal(5), LocalLimit(Literal(5), projectAB))
val expectedPlan = Project(Seq(UnresolvedStar(None)), planWithLimit)

assertEquals(compareByString(expectedPlan), compareByString(logPlan))
comparePlans(expectedPlan, logPlan, false)
}

test(
Expand All @@ -152,7 +151,7 @@ class PPLLogicalPlanBasicQueriesTranslatorTestSuite
val expectedPlan =
Union(Seq(projectedTable1, projectedTable2), byName = true, allowMissingCol = true)

assertEquals(expectedPlan, logPlan)
comparePlans(expectedPlan, logPlan, false)
}

test("Search multiple tables - translated into union call with fields") {
Expand All @@ -172,6 +171,6 @@ class PPLLogicalPlanBasicQueriesTranslatorTestSuite
val expectedPlan =
Union(Seq(projectedTable1, projectedTable2), byName = true, allowMissingCol = true)

assertEquals(expectedPlan, logPlan)
comparePlans(expectedPlan, logPlan, false)
}
}
Loading

0 comments on commit 8843a87

Please sign in to comment.