
Implementation of case function.
Signed-off-by: Lukasz Soszynski <[email protected]>
lukasz-soszynski-eliatra committed Sep 24, 2024
1 parent a7fe6e6 commit 4496532
Showing 9 changed files with 158 additions and 2 deletions.
@@ -548,4 +548,27 @@ trait FlintSparkSuite extends QueryTest with FlintSuite with OpenSearchSuite wit
| )
|)""".stripMargin)
}

protected def createTableHttpLog(testTable: String): Unit = {
sql(s"""
| CREATE TABLE $testTable
|(
| id INT,
| status_code INT,
| request_path STRING,
| timestamp STRING
|)
| USING $tableType $tableOptions
|""".stripMargin)

sql(s"""
| INSERT INTO $testTable
| VALUES (1, 200, '/home', '2023-10-01 10:00:00'),
| (2, null, '/about', '2023-10-01 10:05:00'),
| (3, 500, '/contact', '2023-10-01 10:10:00'),
| (4, 301, '/home', '2023-10-01 10:15:00'),
| (5, 200, '/services', '2023-10-01 10:20:00'),
| (6, 403, '/home', '2023-10-01 10:25:00'),
| """.stripMargin)
}
}
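As a quick sanity check of the seeded fixture, a suite could assert on the data directly. A minimal sketch, not part of this commit — it assumes only the helper above and the suite's `sql`:

```scala
// Sketch: verify the six seeded HTTP log rows, including the null
// status_code on id = 2 that drives the 'Unknown' branch in the case() tests.
val testTable = "spark_catalog.default.flint_ppl_test_http_log"
createTableHttpLog(testTable)
val rows = sql(s"SELECT id, status_code FROM $testTable ORDER BY id").collect()
assert(rows.length == 6)
assert(rows(1).isNullAt(1)) // id = 2 was inserted with a null status_code
```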
@@ -9,7 +9,7 @@ import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq

import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Descending, LessThan, Literal, SortOrder}
import org.apache.spark.sql.catalyst.expressions.{Alias, And, Ascending, CaseWhen, Descending, EqualTo, GreaterThanOrEqual, LessThan, Literal, SortOrder}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LogicalPlan, Project, Sort}
import org.apache.spark.sql.streaming.StreamTest

@@ -21,12 +21,14 @@ class FlintSparkPPLEvalITSuite

/** Test table and index name */
private val testTable = "spark_catalog.default.flint_ppl_test"
private val testTableHttpLog = "spark_catalog.default.flint_ppl_test_http_log"

override def beforeAll(): Unit = {
super.beforeAll()

// Create test table
createPartitionedStateCountryTable(testTable)
createTableHttpLog(testTableHttpLog)
}

protected override def afterEach(): Unit = {
@@ -504,7 +506,63 @@ class FlintSparkPPLEvalITSuite
comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
}

test("eval case function") {
val frame = sql(s"""
| source = $testTableHttpLog |
| eval status_category =
| case(status_code >= 200 AND status_code < 300, 'Success',
| status_code >= 300 AND status_code < 400, 'Redirection',
| status_code >= 400 AND status_code < 500, 'Client Error',
| status_code >= 500, 'Server Error'
| else 'Unknown'
| )
| """.stripMargin)

val results: Array[Row] = frame.collect()
val expectedResults: Array[Row] =
Array(
Row(1, 200, "/home", "2023-10-01 10:00:00", "Success"),
Row(2, null, "/about", "2023-10-01 10:05:00", "Unknown"),
Row(3, 500, "/contact", "2023-10-01 10:10:00", "Server Error"),
Row(4, 301, "/home", "2023-10-01 10:15:00", "Redirection"),
Row(5, 200, "/services", "2023-10-01 10:20:00", "Success"),
Row(6, 403, "/home", "2023-10-01 10:25:00", "Client Error"))
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Int](_.getInt(0))
assert(results.sorted.sameElements(expectedResults.sorted))
val expectedColumns =
Array[String]("id", "status_code", "request_path", "timestamp", "status_category")
assert(frame.columns.sameElements(expectedColumns))

val logicalPlan: LogicalPlan = frame.queryExecution.logical

val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test_http_log"))
val conditionValueSequence = Seq(
(greaterOrEqualAndLessThan("status_code", 200, 300), Literal("Success")),
(greaterOrEqualAndLessThan("status_code", 300, 400), Literal("Redirection")),
(greaterOrEqualAndLessThan("status_code", 400, 500), Literal("Client Error")),
(
EqualTo(
Literal(true),
GreaterThanOrEqual(UnresolvedAttribute("status_code"), Literal(500))),
Literal("Server Error")))
val elseValue = Literal("Unknown")
val caseFunction = CaseWhen(conditionValueSequence, elseValue)
val aliasStatusCategory = Alias(caseFunction, "status_category")()
val evalProjectList = Seq(UnresolvedStar(None), aliasStatusCategory)
val evalProject = Project(evalProjectList, table)
val expectedPlan = Project(Seq(UnresolvedStar(None)), evalProject)
comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
}
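// Note the shape of the expected plan above: eval contributes a
// Project(UnresolvedStar +: alias), and the implicit final field selection
// wraps it in a second Project(Seq(UnresolvedStar(None))).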

// Builds the expected translation of one case() branch condition: the planner
// wraps each condition as `true = <condition>` (see visitCase in this commit),
// so the expected plan mirrors that shape here.
private def greaterOrEqualAndLessThan(fieldName: String, min: Int, max: Int) = {
val and = And(
GreaterThanOrEqual(UnresolvedAttribute(fieldName), Literal(min)),
LessThan(UnresolvedAttribute(fieldName), Literal(max)))
EqualTo(Literal(true), and)
}

// TODO: excluded fields are not supported yet

ignore("test single eval expression with excluded fields") {
val frame = sql(s"""
| source = $testTable | eval new_field = "New Field" | fields - age
@@ -7,7 +7,7 @@ package org.opensearch.flint.spark.ppl

import org.apache.spark.sql.{QueryTest, Row}
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
import org.apache.spark.sql.catalyst.expressions.{Alias, And, Ascending, Descending, Divide, EqualTo, Floor, GreaterThan, LessThanOrEqual, Literal, Multiply, Not, Or, SortOrder}
import org.apache.spark.sql.catalyst.expressions.{Alias, And, Ascending, CaseWhen, Descending, Divide, EqualTo, Floor, GreaterThan, LessThanOrEqual, Literal, Multiply, Not, Or, SortOrder}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.streaming.StreamTest

@@ -348,4 +348,38 @@ class FlintSparkPPLFiltersITSuite
assert(compareByString(expectedPlan) === compareByString(logicalPlan))
}

test("case function used as filter") {
val frame = sql(s"""
| source = $testTable case(country = 'USA', 'The United States of America' else 'Other country') = 'The United States of America'
| """.stripMargin)

// Retrieve the results
val results: Array[Row] = frame.collect()
// Define the expected results
val expectedResults: Array[Row] = Array(
Row("Jake", 70, "California", "USA", 2023, 4),
Row("Hello", 30, "New York", "USA", 2023, 4))

// Compare the results
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0))
val sorted = results.sorted
assert(sorted.sameElements(expectedResults.sorted))

// Retrieve the logical plan
val logicalPlan: LogicalPlan = frame.queryExecution.logical
// Define the expected logical plan
val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))
val conditionValueSequence = Seq(
(
EqualTo(Literal(true), EqualTo(UnresolvedAttribute("country"), Literal("USA"))),
Literal("The United States of America")))
val elseValue = Literal("Other country")
val caseFunction = CaseWhen(conditionValueSequence, elseValue)
val filterExpr = EqualTo(caseFunction, Literal("The United States of America"))
val filterPlan = Filter(filterExpr, table)
val projectList = Seq(UnresolvedStar(None))
val expectedPlan = Project(projectList, filterPlan)
// Compare the two plans
assert(compareByString(expectedPlan) === compareByString(logicalPlan))
}
}
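One detail worth noting in the filter test above: the `case(...) = '...'` comparison is applied directly after `source = $testTable`, with no explicit `| where`, yet it still translates to a plain `Filter` over the relation, with the branch condition encoded as `true = (country = 'USA')` — the same shape produced by `visitCase` further down.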
10 changes: 10 additions & 0 deletions ppl-spark-integration/README.md
@@ -245,6 +245,7 @@ See the next samples of PPL queries :
- `source = table | where ispresent(b)`
- `source = table | where isnull(coalesce(a, b)) | fields a,b,c | head 3`
- `source = table | where isempty(a)`
- `source = table | where case(length(a) > 6, 'True' else 'False') = 'True'`

**Filters With Logical Conditions**
- `source = table | where c = 'test' AND a = 1 | fields a,b,c`
@@ -265,6 +266,15 @@ Assumptions: `a`, `b`, `c` are existing fields in `table`
- `source = table | eval f = ispresent(a)`
- `source = table | eval r = coalesce(a, b, c) | fields r`
- `source = table | eval e = isempty(a) | fields e`
```
source = table | eval status_category =
case(a >= 200 AND a < 300, 'Success',
a >= 300 AND a < 400, 'Redirection',
a >= 400 AND a < 500, 'Client Error',
a >= 500, 'Server Error'
else 'Unknown'
)
```
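If no `else` branch is given, the parser substitutes a null literal, so rows matching none of the conditions evaluate to `null` (see `visitCaseExpr` below).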

Limitation: Overriding existing field is unsupported, following queries throw exceptions with "Reference 'a' is ambiguous"
- `source = table | eval a = 10 | fields a,b,c`
@@ -98,6 +98,7 @@ ANOMALY_SCORE_THRESHOLD: 'ANOMALY_SCORE_THRESHOLD';

// COMPARISON FUNCTION KEYWORDS
CASE: 'CASE';
ELSE: 'ELSE';
IN: 'IN';

// LOGICAL KEYWORDS
5 changes: 5 additions & 0 deletions ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4
@@ -312,6 +312,7 @@ valueExpression
| left = valueExpression binaryOperator = (PLUS | MINUS) right = valueExpression # binaryArithmetic
| primaryExpression # valueExpressionDefault
| positionFunction # positionFunctionCall
| caseFunction # caseExpr
| LT_PRTHS valueExpression RT_PRTHS # parentheticValueExpr
;

@@ -333,6 +334,10 @@ booleanExpression
: ISEMPTY LT_PRTHS functionArg RT_PRTHS
;

caseFunction
: CASE LT_PRTHS logicalExpression COMMA valueExpression (COMMA logicalExpression COMMA valueExpression)* (ELSE valueExpression)? RT_PRTHS
;

relevanceExpression
: singleFieldRelevanceFunction
| multiFieldRelevanceFunction
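Taken together with the new `ELSE` token above, the `caseFunction` rule pairs each `logicalExpression` with a `valueExpression` and allows one optional trailing `ELSE` value. A sketch of two query shapes it accepts, reusing the HTTP log test table from this commit:

```scala
// Both forms parse under the caseFunction rule:
// with an explicit else branch
sql(s"source = $testTableHttpLog | eval c = case(status_code = 200, 'OK' else 'Other')")
// without else — visitCaseExpr (below) falls back to a null literal
sql(s"source = $testTableHttpLog | eval c = case(status_code = 200, 'OK')")
```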
@@ -171,6 +171,8 @@ public T visitIsEmpty(IsEmpty node, C context) {
return visitChildren(node, context);
}

// TODO add case

public T visitWindowFunction(WindowFunction node, C context) {
return visitChildren(node, context);
}
@@ -611,6 +611,8 @@ public Expression visitKmeans(Kmeans node, CatalystPlanContext context) {

@Override
public Expression visitCase(Case node, CatalystPlanContext context) {
Stack<Expression> initialNameExpressions = new Stack<>();
initialNameExpressions.addAll(context.getNamedParseExpressions());
analyze(node.getElseClause(), context);
Expression elseValue = context.getNamedParseExpressions().pop();
List<Tuple2<Expression, Expression>> whens = new ArrayList<>();
@@ -633,6 +635,7 @@ }
}
context.retainAllNamedParseExpressions(e -> e);
}
context.setNamedParseExpressions(initialNameExpressions);
return context.getNamedParseExpressions().push(new CaseWhen(seq(whens), Option.apply(elseValue)));
}
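The two added lines in `visitCase` snapshot the named-parse-expression stack on entry and restore it just before pushing the final `CaseWhen`, so expressions produced while analyzing the condition and value branches do not leak into the surrounding context.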

@@ -44,6 +44,7 @@
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;

import static org.opensearch.sql.expression.function.BuiltinFunctionName.EQUAL;
@@ -199,6 +200,25 @@ public UnresolvedExpression visitBooleanFunctionCall(OpenSearchPPLParser.Boolean
ctx.functionArgs().functionArg());
}

@Override
public UnresolvedExpression visitCaseExpr(OpenSearchPPLParser.CaseExprContext ctx) {
List<When> whens = IntStream.range(0, ctx.caseFunction().logicalExpression().size())
.mapToObj(index -> {
OpenSearchPPLParser.LogicalExpressionContext logicalExpressionContext = ctx.caseFunction().logicalExpression(index);
OpenSearchPPLParser.ValueExpressionContext valueExpressionContext = ctx.caseFunction().valueExpression(index);
UnresolvedExpression condition = visit(logicalExpressionContext);
UnresolvedExpression result = visit(valueExpressionContext);
return new When(condition, result);
})
.collect(Collectors.toList());
UnresolvedExpression elseValue = new Literal(null, DataType.NULL);
if (ctx.caseFunction().valueExpression().size() > ctx.caseFunction().logicalExpression().size()) {
// else value is present
elseValue = visit(ctx.caseFunction().valueExpression(ctx.caseFunction().valueExpression().size() - 1));
}
return new Case(new Literal(true, DataType.BOOLEAN), whens, elseValue);
}
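// For case(c1, v1, c2, v2 else v3), logicalExpression() yields [c1, c2] and
// valueExpression() yields [v1, v2, v3]; the extra trailing value is the else
// branch. The Literal(true, DataType.BOOLEAN) passed to Case later surfaces as
// the `true = <condition>` comparison in the expected Catalyst plans.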

@Override
public UnresolvedExpression visitIsEmptyExpression(OpenSearchPPLParser.IsEmptyExpressionContext ctx) {
Function trimFunction = new Function(TRIM.getName().getFunctionName(), Collections.singletonList(this.visitFunctionArg(ctx.functionArg())));
