From 997b583824661a171e9191a2ec3fd00e3d7e20d7 Mon Sep 17 00:00:00 2001 From: "opensearch-trigger-bot[bot]" <98922864+opensearch-trigger-bot[bot]@users.noreply.github.com> Date: Wed, 11 Sep 2024 13:34:03 -0700 Subject: [PATCH] Ppl patterns command (#627) (#639) * add patterns support & tests * update tests * remove unrelated Dockerfile * sbt format * fix ParseUtils and simplify different pase expressions according to PR comments feedback --------- (cherry picked from commit fd3f82fc16273d47f17aa47e699aeaa8b4c3679d) Signed-off-by: YANGDB Signed-off-by: github-actions[bot] Co-authored-by: github-actions[bot] --- .../ppl/FlintSparkPPLPatternsITSuite.scala | 166 ++++++++++++++++ ppl-spark-integration/README.md | 6 + .../sql/ppl/utils/ParseStrategy.java | 14 +- .../opensearch/sql/ppl/utils/ParseUtils.java | 181 ++++++------------ ...gicalPlanPatternsTranslatorTestSuite.scala | 181 ++++++++++++++++++ 5 files changed, 416 insertions(+), 132 deletions(-) create mode 100644 integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLPatternsITSuite.scala create mode 100644 ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanPatternsTranslatorTestSuite.scala diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLPatternsITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLPatternsITSuite.scala new file mode 100644 index 000000000..422ef66c3 --- /dev/null +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLPatternsITSuite.scala @@ -0,0 +1,166 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.apache.spark.sql.{QueryTest, Row} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Alias, Descending, GreaterThan, Literal, NullsLast, RegExpExtract, RegExpReplace, SortOrder} +import org.apache.spark.sql.catalyst.plans.logical._ +import org.apache.spark.sql.streaming.StreamTest + +class FlintSparkPPLPatternsITSuite + extends QueryTest + with LogicalPlanTestUtils + with FlintPPLSuite + with StreamTest { + + /** Test table and index name */ + private val testTable = "spark_catalog.default.flint_ppl_test" + + override def beforeAll(): Unit = { + super.beforeAll() + + // Create test table + createPartitionedGrokEmailTable(testTable) + } + + protected override def afterEach(): Unit = { + super.afterEach() + // Stop all streaming jobs if any + spark.streams.active.foreach { job => + job.stop() + job.awaitTermination() + } + } + + test("test patterns email & host expressions") { + val frame = sql(s""" + | source = $testTable| patterns email | fields email, patterns_field + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array( + Row("charlie@domain.net", "@."), + Row("david@anotherdomain.com", "@."), + Row("hank@demonstration.com", "@."), + Row("alice@example.com", "@."), + Row("frank@sample.org", "@."), + Row("grace@demo.net", "@."), + Row("jack@sample.net", "@."), + Row("eve@examples.com", "@."), + Row("ivy@examples.com", "@."), + Row("bob@test.org", "@.")) + + // Compare the results + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // 
Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val emailAttribute = UnresolvedAttribute("email") + val patterns_field = UnresolvedAttribute("patterns_field") + val hostExpression = Alias( + RegExpReplace(emailAttribute, Literal("[a-zA-Z0-9]"), Literal("")), + "patterns_field")() + val expectedPlan = Project( + Seq(emailAttribute, patterns_field), + Project( + Seq(emailAttribute, hostExpression, UnresolvedStar(None)), + UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")))) + assert(compareByString(expectedPlan) === compareByString(logicalPlan)) + } + + test("test patterns email expressions parsing filter & sort by age") { + val frame = sql(s""" + | source = $testTable| patterns email | where age > 45 | sort - age | fields age, email, patterns_field; + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array( + Row(76, "frank@sample.org", "@."), + Row(65, "charlie@domain.net", "@."), + Row(55, "bob@test.org", "@.")) + + // Compare the results + assert(results.sameElements(expectedResults)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val emailAttribute = UnresolvedAttribute("email") + val patterns_fieldAttribute = UnresolvedAttribute("patterns_field") + val ageAttribute = UnresolvedAttribute("age") + val patternExpression = Alias( + RegExpReplace(emailAttribute, Literal("[a-zA-Z0-9]"), Literal("")), + "patterns_field")() + + // Define the corrected expected plan + val expectedPlan = Project( + Seq(ageAttribute, emailAttribute, patterns_fieldAttribute), + Sort( + Seq(SortOrder(ageAttribute, Descending, NullsLast, Seq.empty)), + global = true, + Filter( + GreaterThan(ageAttribute, Literal(45)), + Project( + Seq(emailAttribute, patternExpression, UnresolvedStar(None)), + UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")))))) + assert(compareByString(expectedPlan) === compareByString(logicalPlan)) + } + + test("test patterns email expressions and top count_host ") { + val frame = sql( + "source=spark_catalog.default.flint_ppl_test | patterns new_field='dot_com' pattern='(.com|.net|.org)' email | stats count() by dot_com ") + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array( + Row(1L, "charlie@domain"), + Row(1L, "david@anotherdomain"), + Row(1L, "hank@demonstration"), + Row(1L, "alice@example"), + Row(1L, "frank@sample"), + Row(1L, "grace@demo"), + Row(1L, "jack@sample"), + Row(1L, "eve@examples"), + Row(1L, "ivy@examples"), + Row(1L, "bob@test")) + + // Sort both the results and the expected results + implicit val rowOrdering: Ordering[Row] = Ordering.by(r => (r.getLong(0), r.getString(1))) + assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val messageAttribute = UnresolvedAttribute("email") + val noNumbersAttribute = UnresolvedAttribute("dot_com") + val hostExpression = Alias( + RegExpReplace(messageAttribute, Literal("(.com|.net|.org)"), Literal("")), + "dot_com")() + + // Define the corrected expected plan + val expectedPlan = Project( + Seq(UnresolvedStar(None)), // Matches the '*' in the Project + Aggregate( + Seq(Alias(noNumbersAttribute, "dot_com")()), // Group by 'no_numbers' + Seq( + Alias( + 
UnresolvedFunction(Seq("COUNT"), Seq(UnresolvedStar(None)), isDistinct = false), + "count()")(), + Alias(noNumbersAttribute, "dot_com")()), + Project( + Seq(messageAttribute, hostExpression, UnresolvedStar(None)), + UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))))) + + // Compare the logical plans + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } +} diff --git a/ppl-spark-integration/README.md b/ppl-spark-integration/README.md index 979cb712d..0c34cbbc3 100644 --- a/ppl-spark-integration/README.md +++ b/ppl-spark-integration/README.md @@ -329,6 +329,12 @@ Limitation: Overriding existing field is unsupported, following queries throw ex - `source=accounts | grok street_address '%{NUMBER} %{GREEDYDATA:address}' | fields address ` - `source=logs | grok message '%{COMMONAPACHELOG}' | fields COMMONAPACHELOG, timestamp, response, bytes` +**Patterns** +- `source=accounts | patterns email | fields email, patterns_field ` +- `source=accounts | patterns email | where age > 45 | sort - age | fields email, patterns_field` +- `source=apache | patterns new_field='no_numbers' pattern='[0-9]' message | fields message, no_numbers` +- `source=apache | patterns new_field='no_numbers' pattern='[0-9]' message | stats count() by no_numbers` + _- **Limitation: Overriding existing field is unsupported:**_ - `source=accounts | grok address '%{NUMBER} %{GREEDYDATA:address}' | fields address` diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/ParseStrategy.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/ParseStrategy.java index 45766e588..6cdb2f6b2 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/ParseStrategy.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/ParseStrategy.java @@ -45,26 +45,26 @@ static LogicalPlan visitParseCommand(Parse node, Expression sourceField, ParseMe if(field instanceof AllFields) { for (int i = 0; i < namedGroupCandidates.size(); i++) { namedGroupNumbers.put(namedGroupCandidates.get(i), - ParseUtils.getNamedGroupIndex(parseMethod, pattern, namedGroupCandidates.get(i))); + ParseUtils.getNamedGroupIndex(parseMethod, pattern, namedGroupCandidates.get(i), arguments)); } // in specific field case - match to the namedGroupCandidates group } else for (int i = 0; i < namedGroupCandidates.size(); i++) { if (((Field)field).getField().toString().equals(namedGroupCandidates.get(i))) { namedGroupNumbers.put(namedGroupCandidates.get(i), - ParseUtils.getNamedGroupIndex(parseMethod, pattern, namedGroupCandidates.get(i))); + ParseUtils.getNamedGroupIndex(parseMethod, pattern, namedGroupCandidates.get(i), arguments)); } } }); //list the group numbers of these projected fields // match the regExpExtract group identifier with its number namedGroupNumbers.forEach((group, index) -> { - //first create the regExp - RegExpExtract regExpExtract = new RegExpExtract(sourceField, - org.apache.spark.sql.catalyst.expressions.Literal.create(cleanedPattern, StringType), - org.apache.spark.sql.catalyst.expressions.Literal.create(index + 1, IntegerType)); + //first create the regExp + org.apache.spark.sql.catalyst.expressions.Literal patternLiteral = org.apache.spark.sql.catalyst.expressions.Literal.create(cleanedPattern, StringType); + org.apache.spark.sql.catalyst.expressions.Literal groupIndexLiteral = org.apache.spark.sql.catalyst.expressions.Literal.create(index + 1, IntegerType); + Expression regExp = ParseUtils.getRegExpCommand(parseMethod, sourceField, patternLiteral, 
groupIndexLiteral); //next Alias the extracted fields context.getNamedParseExpressions().push( - org.apache.spark.sql.catalyst.expressions.Alias$.MODULE$.apply(regExpExtract, + org.apache.spark.sql.catalyst.expressions.Alias$.MODULE$.apply(regExp, group, NamedExpression.newExprId(), seq(new java.util.ArrayList()), diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/ParseUtils.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/ParseUtils.java index 128463df1..a463767f0 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/ParseUtils.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/ParseUtils.java @@ -6,52 +6,29 @@ package org.opensearch.sql.ppl.utils; import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; -import com.google.common.collect.ImmutableSet; +import org.apache.spark.sql.catalyst.expressions.Expression; +import org.apache.spark.sql.catalyst.expressions.RegExpExtract; +import org.apache.spark.sql.catalyst.expressions.RegExpReplace; import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.ast.expression.ParseMethod; import org.opensearch.sql.common.grok.Grok; import org.opensearch.sql.common.grok.GrokCompiler; -import org.opensearch.sql.common.grok.GrokUtils; import org.opensearch.sql.common.grok.Match; import java.util.ArrayList; -import java.util.Collections; import java.util.List; import java.util.Map; -import java.util.Objects; -import java.util.Optional; -import java.util.Set; -import java.util.concurrent.atomic.AtomicReference; import java.util.regex.Matcher; import java.util.regex.Pattern; import java.util.stream.Collectors; +import static org.apache.spark.sql.types.DataTypes.StringType; import static org.opensearch.sql.common.grok.GrokUtils.getGroupPatternName; public class ParseUtils { private static final Pattern GROUP_PATTERN = Pattern.compile("\\(\\?<([a-zA-Z][a-zA-Z0-9]*)>"); private static final String NEW_FIELD_KEY = "new_field"; - /** - * Construct corresponding ParseExpression by {@link ParseMethod}. - * - * @param parseMethod method used to parse - * @param pattern pattern used for parsing - * @param identifier derived field - * @return {@link ParseExpression} - */ - public static ParseExpression createParseExpression( - ParseMethod parseMethod, String pattern, String identifier) { - switch (parseMethod) { - case GROK: - return new GrokExpression(pattern, identifier); - case PATTERNS: - return new PatternsExpression(pattern, identifier); - default: - return new RegexExpression(pattern, identifier); - } - } /** * Get list of derived fields based on parse pattern. @@ -63,16 +40,14 @@ public static List getNamedGroupCandidates( ParseMethod parseMethod, String pattern, Map arguments) { switch (parseMethod) { case REGEX: - return RegexExpression.getNamedGroupCandidates(pattern); + return RegexExpression.getNamedGroupCandidates(pattern, arguments); case GROK: - return GrokExpression.getNamedGroupCandidates(pattern); + return GrokExpression.getNamedGroupCandidates(pattern, arguments); default: - return GrokExpression.getNamedGroupCandidates( - arguments.containsKey(NEW_FIELD_KEY) - ? (String) arguments.get(NEW_FIELD_KEY).getValue() - : null); + return PatternsExpression.getNamedGroupCandidates(pattern, arguments); } } + /** * Get list of derived fields based on parse pattern. 
* @@ -80,14 +55,14 @@ public static List getNamedGroupCandidates( * @return list of names of the derived fields */ public static int getNamedGroupIndex( - ParseMethod parseMethod, String pattern, String namedGroup) { + ParseMethod parseMethod, String pattern, String namedGroup, Map arguments) { switch (parseMethod) { case REGEX: - return RegexExpression.getNamedGroupIndex(pattern, namedGroup); + return RegexExpression.getNamedGroupIndex(pattern, namedGroup, arguments); case GROK: - return GrokExpression.getNamedGroupIndex(pattern, namedGroup); + return GrokExpression.getNamedGroupIndex(pattern, namedGroup, arguments); default: - return PatternsExpression.getNamedGroupIndex(pattern, namedGroup); + return PatternsExpression.getNamedGroupIndex(pattern, namedGroup, arguments); } } @@ -111,26 +86,28 @@ public static String extractPatterns( } } - public static abstract class ParseExpression { - abstract String parseValue(String value); - } - - public static class RegexExpression extends ParseExpression { - private final Pattern regexPattern; - protected final String identifier; - - public RegexExpression(String patterns, String identifier) { - this.regexPattern = Pattern.compile(patterns); - this.identifier = identifier; + public static Expression getRegExpCommand(ParseMethod parseMethod, Expression sourceField, + org.apache.spark.sql.catalyst.expressions.Literal patternLiteral, + org.apache.spark.sql.catalyst.expressions.Literal groupIndexLiteral) { + switch (parseMethod) { + case REGEX: + return RegexExpression.getRegExpCommand(sourceField, patternLiteral, groupIndexLiteral); + case GROK: + return GrokExpression.getRegExpCommand(sourceField, patternLiteral, groupIndexLiteral); + default: + return PatternsExpression.getRegExpCommand(sourceField, patternLiteral, groupIndexLiteral); } + } + public static class RegexExpression { /** * Get list of derived fields based on parse pattern. 
* - * @param pattern pattern used for parsing + * @param pattern pattern used for parsing + * @param arguments * @return list of names of the derived fields */ - public static List getNamedGroupCandidates(String pattern) { + public static List getNamedGroupCandidates(String pattern, Map arguments) { ImmutableList.Builder namedGroups = ImmutableList.builder(); Matcher m = GROUP_PATTERN.matcher(pattern); while (m.find()) { @@ -139,21 +116,18 @@ public static List getNamedGroupCandidates(String pattern) { return namedGroups.build(); } - public static int getNamedGroupIndex(String pattern,String groupName) { - List groupCandidates = getNamedGroupCandidates(pattern); + public static int getNamedGroupIndex(String pattern, String groupName, Map arguments) { + List groupCandidates = getNamedGroupCandidates(pattern, arguments); for (int i = 0; i < groupCandidates.size(); i++) { - if(groupCandidates.get(i).equals(groupName)) return i; + if (groupCandidates.get(i).equals(groupName)) return i; } return -1; } - @Override - public String parseValue(String value) { - Matcher matcher = regexPattern.matcher(value); - if (matcher.matches()) { - return matcher.group(identifier); - } - return ""; + public static Expression getRegExpCommand(Expression sourceField, + org.apache.spark.sql.catalyst.expressions.Literal patternLiteral, + org.apache.spark.sql.catalyst.expressions.Literal groupIndexLiteral) { + return new RegExpExtract(sourceField, patternLiteral, groupIndexLiteral); } public static String extractPattern(String patterns, List columns) { @@ -161,51 +135,37 @@ public static String extractPattern(String patterns, List columns) { } } - public static class GrokExpression extends ParseExpression { + public static class GrokExpression { private static final GrokCompiler grokCompiler = GrokCompiler.newInstance(); static { grokCompiler.registerDefaultPatterns(); } - private final Grok grok; - private final String identifier; - - public GrokExpression(String pattern, String identifier) { - this.grok = grokCompiler.compile(pattern); - this.identifier = identifier; - } - - @Override - public String parseValue(String value) { - Match grokMatch = grok.match(value); - Map capture = grokMatch.capture(); - Object match = capture.get(identifier); - if (match != null) { - return match.toString(); - } - return ""; + public static Expression getRegExpCommand(Expression sourceField, org.apache.spark.sql.catalyst.expressions.Literal patternLiteral, org.apache.spark.sql.catalyst.expressions.Literal groupIndexLiteral) { + return new RegExpExtract(sourceField, patternLiteral, groupIndexLiteral); } /** * Get list of derived fields based on parse pattern. 
* - * @param pattern pattern used for parsing + * @param pattern pattern used for parsing + * @param arguments * @return list of names of the derived fields */ - public static List getNamedGroupCandidates(String pattern) { + public static List getNamedGroupCandidates(String pattern, Map arguments) { Grok grok = grokCompiler.compile(pattern); return grok.namedGroups.stream() .map(grok::getNamedRegexCollectionById) .filter(group -> !group.equals("UNWANTED")) .collect(Collectors.toUnmodifiableList()); } - - public static int getNamedGroupIndex(String pattern,String groupName) { + + public static int getNamedGroupIndex(String pattern, String groupName, Map arguments) { String name = getGroupPatternName(grokCompiler.compile(pattern), groupName); List namedGroups = new ArrayList<>(grokCompiler.compile(pattern).namedGroups); for (int i = 0; i < namedGroups.size(); i++) { - if(namedGroups.get(i).equals(name)) return i; + if (namedGroups.get(i).equals(name)) return i; } return -1; } @@ -216,63 +176,34 @@ public static String extractPattern(final String patterns, List columns) } } - public static class PatternsExpression extends ParseExpression { + public static class PatternsExpression { public static final String DEFAULT_NEW_FIELD = "patterns_field"; + private static final String DEFAULT_IGNORED_PATTERN = "[a-zA-Z0-9]"; - private static final ImmutableSet DEFAULT_IGNORED_CHARS = - ImmutableSet.copyOf( - "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" - .chars() - .mapToObj(c -> (char) c) - .toArray(Character[]::new)); - private final boolean useCustomPattern; - private Pattern pattern; - - /** - * PatternsExpression. - * - * @param pattern pattern used for parsing - * @param identifier derived field - */ - public PatternsExpression(String pattern, String identifier) { - useCustomPattern = !pattern.isEmpty(); - if (useCustomPattern) { - this.pattern = Pattern.compile(pattern); - } - } - public static int getNamedGroupIndex(String pattern, String namedGroup) { + public static int getNamedGroupIndex(String pattern, String namedGroup, Map arguments) { return 0; } - @Override - public String parseValue(String value) { - if (useCustomPattern) { - return pattern.matcher(value).replaceAll(""); - } - - char[] chars = value.toCharArray(); - int pos = 0; - for (int i = 0; i < chars.length; i++) { - if (!DEFAULT_IGNORED_CHARS.contains(chars[i])) { - chars[pos++] = chars[i]; - } - } - return new String(chars, 0, pos); + public static Expression getRegExpCommand(Expression sourceField, + org.apache.spark.sql.catalyst.expressions.Literal patternLiteral, + org.apache.spark.sql.catalyst.expressions.Literal groupIndexLiteral) { + return new RegExpReplace(sourceField, patternLiteral, org.apache.spark.sql.catalyst.expressions.Literal.create("", StringType)); } - + /** * Get list of derived fields. * - * @param identifier identifier used to generate the field name + * @param pattern + * @param arguments * @return list of names of the derived fields */ - public static List getNamedGroupCandidates(String identifier) { - return ImmutableList.of(Objects.requireNonNullElse(identifier, DEFAULT_NEW_FIELD)); + public static List getNamedGroupCandidates(String pattern, Map arguments) { + return ImmutableList.of(arguments.containsKey(NEW_FIELD_KEY) ? arguments.get(NEW_FIELD_KEY).toString() : DEFAULT_NEW_FIELD); } public static String extractPattern(String patterns, List columns) { - return patterns; + return patterns != null && !patterns.isEmpty() ? 
patterns : DEFAULT_IGNORED_PATTERN; } } diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanPatternsTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanPatternsTranslatorTestSuite.scala new file mode 100644 index 000000000..a26d365d2 --- /dev/null +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanPatternsTranslatorTestSuite.scala @@ -0,0 +1,181 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.opensearch.flint.spark.ppl.PlaneUtils.plan +import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor} +import org.scalatest.matchers.should.Matchers + +import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Alias, Descending, GreaterThan, Literal, NullsLast, RegExpExtract, RegExpReplace, SortOrder} +import org.apache.spark.sql.catalyst.plans.PlanTest +import org.apache.spark.sql.catalyst.plans.logical._ + +class PPLLogicalPlanPatternsTranslatorTestSuite + extends SparkFunSuite + with PlanTest + with LogicalPlanTestUtils + with Matchers { + + private val planTransformer = new CatalystQueryPlanVisitor() + private val pplParser = new PPLSyntaxParser() + + test("test patterns email & host expressions") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + "source=accounts | patterns email | fields email, patterns_field ", + isExplain = false), + context) + + val emailAttribute = UnresolvedAttribute("email") + val patterns_field = UnresolvedAttribute("patterns_field") + val hostExpression = Alias( + RegExpReplace(emailAttribute, Literal("[a-zA-Z0-9]"), Literal("")), + "patterns_field")() + val expectedPlan = Project( + Seq(emailAttribute, patterns_field), + Project( + Seq(emailAttribute, hostExpression, UnresolvedStar(None)), + UnresolvedRelation(Seq("accounts")))) + assert(compareByString(expectedPlan) === compareByString(logPlan)) + } + + test( + "test patterns extract punctuations from a raw log field using user defined patterns and a new field") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + "source=apache | patterns new_field='no_numbers' pattern='[0-9]' message | fields message, no_numbers", + false), + context) + + val emailAttribute = UnresolvedAttribute("message") + val patterns_field = UnresolvedAttribute("no_numbers") + val hostExpression = + Alias(RegExpReplace(emailAttribute, Literal("[0-9]"), Literal("")), "no_numbers")() + val expectedPlan = Project( + Seq(emailAttribute, patterns_field), + Project( + Seq(emailAttribute, hostExpression, UnresolvedStar(None)), + UnresolvedRelation(Seq("apache")))) + assert(compareByString(expectedPlan) === compareByString(logPlan)) + + } + + test("test patterns email & host expressions with filter by age and sort by age field") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + "source=accounts | patterns email | where age > 45 | sort - age | fields email, patterns_field", + isExplain = false), + context) + + // Define the expected logical plan + val emailAttribute = UnresolvedAttribute("email") + val patterns_fieldAttribute = UnresolvedAttribute("patterns_field") + val ageAttribute = 
UnresolvedAttribute("age") + val hostExpression = Alias( + RegExpReplace(emailAttribute, Literal("[a-zA-Z0-9]"), Literal("")), + "patterns_field")() + + // Define the corrected expected plan + val expectedPlan = Project( + Seq(emailAttribute, patterns_fieldAttribute), + Sort( + Seq(SortOrder(ageAttribute, Descending, NullsLast, Seq.empty)), + global = true, + Filter( + GreaterThan(ageAttribute, Literal(45)), + Project( + Seq(emailAttribute, hostExpression, UnresolvedStar(None)), + UnresolvedRelation(Seq("accounts")))))) + assert(compareByString(expectedPlan) === compareByString(logPlan)) + } + + test("test patterns email expressions and group by count host ") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + "source=apache | patterns new_field='no_numbers' pattern='[0-9]' message | stats count() by no_numbers", + false), + context) + + val messageAttribute = UnresolvedAttribute("message") + val noNumbersAttribute = UnresolvedAttribute("no_numbers") + val hostExpression = + Alias(RegExpReplace(messageAttribute, Literal("[0-9]"), Literal("")), "no_numbers")() + + // Define the corrected expected plan + val expectedPlan = Project( + Seq(UnresolvedStar(None)), // Matches the '*' in the Project + Aggregate( + Seq(Alias(noNumbersAttribute, "no_numbers")()), // Group by 'no_numbers' + Seq( + Alias( + UnresolvedFunction(Seq("COUNT"), Seq(UnresolvedStar(None)), isDistinct = false), + "count()")(), + Alias(noNumbersAttribute, "no_numbers")()), + Project( + Seq(messageAttribute, hostExpression, UnresolvedStar(None)), + UnresolvedRelation(Seq("apache"))))) + + // Compare the logical plans + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test patterns email expressions and top count_host ") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan( + pplParser, + "source=apache | patterns new_field='no_numbers' pattern='[0-9]' message | top 1 no_numbers", + false), + context) + + val messageAttribute = UnresolvedAttribute("message") + val noNumbersAttribute = UnresolvedAttribute("no_numbers") + val hostExpression = + Alias(RegExpReplace(messageAttribute, Literal("[0-9]"), Literal("")), "no_numbers")() + + val sortedPlan = Sort( + Seq( + SortOrder( + Alias( + UnresolvedFunction(Seq("COUNT"), Seq(noNumbersAttribute), isDistinct = false), + "count_no_numbers")(), + Descending, + NullsLast, + Seq.empty)), + global = true, + Aggregate( + Seq(noNumbersAttribute), + Seq( + Alias( + UnresolvedFunction(Seq("COUNT"), Seq(noNumbersAttribute), isDistinct = false), + "count_no_numbers")(), + noNumbersAttribute), + Project( + Seq(messageAttribute, hostExpression, UnresolvedStar(None)), + UnresolvedRelation(Seq("apache"))))) + // Define the corrected expected plan + val expectedPlan = Project( + Seq(UnresolvedStar(None)), // Matches the '*' in the Project + GlobalLimit(Literal(1), LocalLimit(Literal(1), sortedPlan))) + // Compare the logical plans + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } +}