Skip to content

Commit

Permalink
Unquote text and identifiers in PPL parsing (#393)
Browse files Browse the repository at this point in the history
* unquote text and identifiers in PPL parsing

Signed-off-by: Sean Kao <[email protected]>

* clean PPL suite comments

Signed-off-by: Sean Kao <[email protected]>

* fix PPL suite typo

Signed-off-by: Sean Kao <[email protected]>

* parameterize test cases

Signed-off-by: Sean Kao <[email protected]>

* add UT for StringUtils

Signed-off-by: Sean Kao <[email protected]>

* use JUnit 4

Signed-off-by: Sean Kao <[email protected]>

---------

Signed-off-by: Sean Kao <[email protected]>
  • Loading branch information
seankao-az authored Jun 27, 2024
1 parent 0c1ec6b commit 9fad78e
Show file tree
Hide file tree
Showing 10 changed files with 260 additions and 110 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,6 @@ class FlintSparkPPLAggregationsITSuite
// Define the expected results
val expectedResults: Array[Row] = Array(Row(36.25))

// Compare the results
// Compare the results
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Double](_.getAs[Double](0))
assert(results.sorted.sameElements(expectedResults.sorted))
Expand Down Expand Up @@ -76,7 +75,6 @@ class FlintSparkPPLAggregationsITSuite
// Define the expected results
val expectedResults: Array[Row] = Array(Row(25))

// Compare the results
// Compare the results
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Double](_.getAs[Double](0))
assert(results.sorted.sameElements(expectedResults.sorted))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,32 +37,34 @@ class FlintSparkPPLBasicITSuite
}

test("create ppl simple query test") {
val frame = sql(s"""
| source = $testTable
| """.stripMargin)

// Retrieve the results
val results: Array[Row] = frame.collect()
// Define the expected results
val expectedResults: Array[Row] = Array(
Row("Jake", 70, "California", "USA", 2023, 4),
Row("Hello", 30, "New York", "USA", 2023, 4),
Row("John", 25, "Ontario", "Canada", 2023, 4),
Row("Jane", 20, "Quebec", "Canada", 2023, 4))
// Compare the results
// Compare the results
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0))
assert(results.sorted.sameElements(expectedResults.sorted))

// Retrieve the logical plan
val logicalPlan: LogicalPlan = frame.queryExecution.logical
// Define the expected logical plan
val expectedPlan: LogicalPlan =
Project(
Seq(UnresolvedStar(None)),
UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")))
// Compare the two plans
assert(expectedPlan === logicalPlan)
val testTableQuoted = "`spark_catalog`.`default`.`flint_ppl_test`"
Seq(testTable, testTableQuoted).foreach { table =>
val frame = sql(s"""
| source = $table
| """.stripMargin)

// Retrieve the results
val results: Array[Row] = frame.collect()
// Define the expected results
val expectedResults: Array[Row] = Array(
Row("Jake", 70, "California", "USA", 2023, 4),
Row("Hello", 30, "New York", "USA", 2023, 4),
Row("John", 25, "Ontario", "Canada", 2023, 4),
Row("Jane", 20, "Quebec", "Canada", 2023, 4))
// Compare the results
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0))
assert(results.sorted.sameElements(expectedResults.sorted))

// Retrieve the logical plan
val logicalPlan: LogicalPlan = frame.queryExecution.logical
// Define the expected logical plan
val expectedPlan: LogicalPlan =
Project(
Seq(UnresolvedStar(None)),
UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")))
// Compare the two plans
assert(expectedPlan === logicalPlan)
}
}

test("create ppl simple query with head (limit) 3 test") {
Expand Down Expand Up @@ -90,7 +92,6 @@ class FlintSparkPPLBasicITSuite
| source = $testTable| sort name | head 2
| """.stripMargin)

// Retrieve the results
// Retrieve the results
val results: Array[Row] = frame.collect()
assert(results.length == 2)
Expand Down Expand Up @@ -187,27 +188,29 @@ class FlintSparkPPLBasicITSuite
}

test("create ppl simple query two with fields and head (limit) with sorting test") {
val frame = sql(s"""
| source = $testTable| fields name, age | head 1 | sort age
| """.stripMargin)

// Retrieve the results
val results: Array[Row] = frame.collect()
assert(results.length == 1)

// Retrieve the logical plan
val logicalPlan: LogicalPlan = frame.queryExecution.logical
val project = Project(
Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")),
UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")))
// Define the expected logical plan
val limitPlan: LogicalPlan = Limit(Literal(1), project)
val sortedPlan: LogicalPlan =
Sort(Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), global = true, limitPlan)

val expectedPlan = Project(Seq(UnresolvedStar(None)), sortedPlan);
// Compare the two plans
assert(compareByString(expectedPlan) === compareByString(logicalPlan))
Seq(("name, age", "age"), ("`name`, `age`", "`age`")).foreach {
case (selectFields, sortField) =>
val frame = sql(s"""
| source = $testTable| fields $selectFields | head 1 | sort $sortField
| """.stripMargin)

// Retrieve the results
val results: Array[Row] = frame.collect()
assert(results.length == 1)

// Retrieve the logical plan
val logicalPlan: LogicalPlan = frame.queryExecution.logical
val project = Project(
Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")),
UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")))
// Define the expected logical plan
val limitPlan: LogicalPlan = Limit(Literal(1), project)
val sortedPlan: LogicalPlan =
Sort(Seq(SortOrder(UnresolvedAttribute("age"), Ascending)), global = true, limitPlan)

val expectedPlan = Project(Seq(UnresolvedStar(None)), sortedPlan);
// Compare the two plans
assert(compareByString(expectedPlan) === compareByString(logicalPlan))
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ class FlintSparkPPLFiltersITSuite
// Define the expected results
val expectedResults: Array[Row] = Array(Row("John", 25))
// Compare the results
// Compare the results
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0))
assert(results.sorted.sameElements(expectedResults.sorted))

Expand All @@ -72,7 +71,6 @@ class FlintSparkPPLFiltersITSuite
// Define the expected results
val expectedResults: Array[Row] = Array(Row("John", 25), Row("Jane", 20))
// Compare the results
// Compare the results
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0))
assert(results.sorted.sameElements(expectedResults.sorted))

Expand Down Expand Up @@ -182,7 +180,6 @@ class FlintSparkPPLFiltersITSuite
// Define the expected results
val expectedResults: Array[Row] = Array(Row("Jake", 70), Row("Hello", 30))
// Compare the results
// Compare the results
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0))
assert(results.sorted.sameElements(expectedResults.sorted))

Expand All @@ -209,7 +206,6 @@ class FlintSparkPPLFiltersITSuite
// Define the expected results
val expectedResults: Array[Row] = Array(Row("Hello", 30), Row("John", 25), Row("Jane", 20))
// Compare the results
// Compare the results
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0))
assert(results.sorted.sameElements(expectedResults.sorted))

Expand Down Expand Up @@ -287,7 +283,6 @@ class FlintSparkPPLFiltersITSuite
// Define the expected results
val expectedResults: Array[Row] = Array(Row("Hello", 30), Row("John", 25), Row("Jane", 20))

// Compare the results
// Compare the results
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0))
assert(results.sorted.sameElements(expectedResults.sorted))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ class FlintSparkPPLTimeWindowITSuite
override def beforeAll(): Unit = {
super.beforeAll()
// Create test table
// Update table creation
createTimeSeriesTransactionTable(testTable)
}

Expand All @@ -39,16 +38,6 @@ class FlintSparkPPLTimeWindowITSuite
}

test("create ppl query count sales by days window test") {
/*
val dataFrame = spark.read.table(testTable)
val query = dataFrame
.groupBy(
window(
col("transactionDate"), " 1 days")
).agg(sum(col("productsAmount")))
query.show(false)
*/
val frame = sql(s"""
| source = $testTable| stats sum(productsAmount) by span(transactionDate, 1d) as age_date
| """.stripMargin)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.common.utils;

import com.google.common.base.Strings;

import java.util.IllegalFormatException;
import java.util.Locale;

public class StringUtils {
/**
* Unquote Identifier which has " or ' as mark. Strings quoted by ' or " with two of these quotes
* appearing next to each other in the quote acts as an escape<br>
* Example: 'Test''s' will result in 'Test's', similar with those single quotes being replaced
* with double quote. Supports escaping quotes (single/double) and escape characters using the `\`
* characters.
*
* @param text string
* @return An unquoted string whose outer pair of (single/double) quotes have been removed
*/
public static String unquoteText(String text) {
if (text.length() < 2) {
return text;
}

char enclosingQuote = 0;
char firstChar = text.charAt(0);
char lastChar = text.charAt(text.length() - 1);

if (firstChar != lastChar) {
return text;
}

if (firstChar == '`') {
return text.substring(1, text.length() - 1);
}

if (firstChar == lastChar && (firstChar == '\'' || firstChar == '"')) {
enclosingQuote = firstChar;
} else {
return text;
}

char currentChar;
char nextChar;

StringBuilder textSB = new StringBuilder();

// Ignores first and last character as they are the quotes that should be removed
for (int chIndex = 1; chIndex < text.length() - 1; chIndex++) {
currentChar = text.charAt(chIndex);
nextChar = text.charAt(chIndex + 1);

if ((currentChar == '\\' && (nextChar == '"' || nextChar == '\\' || nextChar == '\''))
|| (currentChar == nextChar && currentChar == enclosingQuote)) {
chIndex++;
currentChar = nextChar;
}
textSB.append(currentChar);
}
return textSB.toString();
}

/**
* Unquote Identifier which has ` as mark.
*
* @param identifier identifier that possibly enclosed by backticks
* @return An unquoted string whose outer pair of backticks have been removed
*/
public static String unquoteIdentifier(String identifier) {
if (isQuoted(identifier, "`")) {
return identifier.substring(1, identifier.length() - 1);
} else {
return identifier;
}
}

/**
* Returns a formatted string using the specified format string and arguments, as well as the
* {@link Locale#ROOT} locale.
*
* @param format format string
* @param args arguments referenced by the format specifiers in the format string
* @return A formatted string
* @throws IllegalFormatException If a format string contains an illegal syntax, a format
* specifier that is incompatible with the given arguments, insufficient arguments given the
* format string, or other illegal conditions.
* @see String#format(Locale, String, Object...)
*/
public static String format(final String format, Object... args) {
return String.format(Locale.ROOT, format, args);
}

private static boolean isQuoted(String text, String mark) {
return !Strings.isNullOrEmpty(text) && text.startsWith(mark) && text.endsWith(mark);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import org.opensearch.sql.ast.expression.UnresolvedArgument;
import org.opensearch.sql.ast.expression.UnresolvedExpression;
import org.opensearch.sql.ast.expression.Xor;
import org.opensearch.sql.common.utils.StringUtils;
import org.opensearch.sql.ppl.utils.ArgumentFactory;

import java.util.Arrays;
Expand Down Expand Up @@ -322,7 +323,7 @@ public UnresolvedExpression visitIntervalLiteral(OpenSearchPPLParser.IntervalLit

@Override
public UnresolvedExpression visitStringLiteral(OpenSearchPPLParser.StringLiteralContext ctx) {
return new Literal(ctx.getText(), DataType.STRING);
return new Literal(StringUtils.unquoteText(ctx.getText()), DataType.STRING);
}

@Override
Expand All @@ -349,7 +350,7 @@ public UnresolvedExpression visitBySpanClause(OpenSearchPPLParser.BySpanClauseCo
String name = ctx.spanClause().getText();
return ctx.alias != null
? new Alias(
name, visit(ctx.spanClause()), ctx.alias.getText())
name, visit(ctx.spanClause()), StringUtils.unquoteIdentifier(ctx.alias.getText()))
: new Alias(name, visit(ctx.spanClause()));
}

Expand All @@ -363,6 +364,7 @@ private QualifiedName visitIdentifiers(List<? extends ParserRuleContext> ctx) {
return new QualifiedName(
ctx.stream()
.map(RuleContext::getText)
.map(StringUtils::unquoteIdentifier)
.collect(Collectors.toList()));
}

Expand All @@ -373,18 +375,18 @@ private List<UnresolvedExpression> singleFieldRelevanceArguments(
ImmutableList.Builder<UnresolvedExpression> builder = ImmutableList.builder();
builder.add(
new UnresolvedArgument(
"field", new QualifiedName(ctx.field.getText())));
"field", new QualifiedName(StringUtils.unquoteText(ctx.field.getText()))));
builder.add(
new UnresolvedArgument(
"query", new Literal(ctx.query.getText(), DataType.STRING)));
"query", new Literal(StringUtils.unquoteText(ctx.query.getText()), DataType.STRING)));
ctx.relevanceArg()
.forEach(
v ->
builder.add(
new UnresolvedArgument(
v.relevanceArgName().getText().toLowerCase(),
new Literal(
v.relevanceArgValue().getText(),
StringUtils.unquoteText(v.relevanceArgValue().getText()),
DataType.STRING))));
return builder.build();
}
Expand Down
Loading

0 comments on commit 9fad78e

Please sign in to comment.