From 16d7cc0a2521d4f887cc55bfabc2ceb0cccd636c Mon Sep 17 00:00:00 2001 From: YANGDB Date: Tue, 6 Aug 2024 17:17:04 -0700 Subject: [PATCH 1/9] add parse regexp command for PPL Signed-off-by: YANGDB --- .../flint/spark/FlintSparkSuite.scala | 38 +++++++++++ .../spark/ppl/FlintSparkPPLParseITSuite.scala | 65 ++++++++++++++++++ .../sql/ppl/CatalystQueryPlanVisitor.java | 29 ++++++++ .../sql/ppl/parser/AstExpressionBuilder.java | 2 + ...LLogicalPlanParseTranslatorTestSuite.scala | 68 +++++++++++++++++++ 5 files changed, 202 insertions(+) create mode 100644 integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLParseITSuite.scala create mode 100644 ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanParseTranslatorTestSuite.scala diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala index 7e0b68376..bc5fe7999 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala @@ -98,6 +98,44 @@ trait FlintSparkSuite extends QueryTest with FlintSuite with OpenSearchSuite wit } } + protected def createPartitionedGrokEmailTable(testTable: String): Unit = { + spark.sql( + s""" + | CREATE TABLE $testTable + | ( + | name STRING, + | email STRING, + | street_address STRING + | ) + | USING $tableType $tableOptions + | PARTITIONED BY ( + | year INT, + | month INT + | ) + |""".stripMargin) + + val data = Seq( + ("Alice", "alice@example.com", "123 Main St, Seattle", 2023, 4), + ("Bob", "bob@test.org", "456 Elm St, Portland", 2023, 5), + ("Charlie", "charlie@domain.net", "789 Pine St, San Francisco", 2023, 4), + ("David", "david@anotherdomain.com", "101 Maple St, New York", 2023, 5), + ("Eve", "eve@examples.com", "202 Oak St, Boston", 2023, 4), + ("Frank", "frank@sample.org", "303 Cedar St, Austin", 2023, 5), + ("Grace", "grace@demo.net", "404 Birch St, Chicago", 2023, 4), + ("Hank", "hank@demonstration.com", "505 Spruce St, Miami", 2023, 5), + ("Ivy", "ivy@examples.org", "606 Fir St, Denver", 2023, 4), + ("Jack", "jack@sample.net", "707 Ash St, Seattle", 2023, 5) + ) + + data.foreach { case (name, email, street_address, year, month) => + spark.sql( + s""" + | INSERT INTO $testTable + | PARTITION (year=$year, month=$month) + | VALUES ('$name', '$email', '$street_address') + | """.stripMargin) + } + } protected def createPartitionedAddressTable(testTable: String): Unit = { sql(s""" | CREATE TABLE $testTable diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLParseITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLParseITSuite.scala new file mode 100644 index 000000000..d39bdc522 --- /dev/null +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLParseITSuite.scala @@ -0,0 +1,65 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.flint.spark.ppl + +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Descending, Literal, SortOrder} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project, Sort} +import org.apache.spark.sql.streaming.StreamTest +import 
org.apache.spark.sql.{AnalysisException, QueryTest, Row}
+import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq
+
+class FlintSparkPPLParseITSuite
+    extends QueryTest
+    with LogicalPlanTestUtils
+    with FlintPPLSuite
+    with StreamTest {
+
+  /** Test table and index name */
+  private val testTable = "spark_catalog.default.flint_ppl_test"
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+
+    // Create test table
+    createPartitionedGrokEmailTable(testTable)
+  }
+
+  protected override def afterEach(): Unit = {
+    super.afterEach()
+    // Stop all streaming jobs if any
+    spark.streams.active.foreach { job =>
+      job.stop()
+      job.awaitTermination()
+    }
+  }
+
+  test("test parse email expressions parsing") {
+    val frame = sql(s"""
+         | source = $testTable | parse email '.+@(?<host>.+)' | fields email, host ;
+         | """.stripMargin)
+
+    // Retrieve the results
+    val results: Array[Row] = frame.collect()
+    // Define the expected results: host is the domain part of each email in the test data
+    val expectedResults: Array[Row] = Array(
+      Row("alice@example.com", "example.com"),
+      Row("bob@test.org", "test.org"),
+      Row("charlie@domain.net", "domain.net"),
+      Row("david@anotherdomain.com", "anotherdomain.com"),
+      Row("eve@examples.com", "examples.com"),
+      Row("frank@sample.org", "sample.org"),
+      Row("grace@demo.net", "demo.net"),
+      Row("hank@demonstration.com", "demonstration.com"),
+      Row("ivy@examples.org", "examples.org"),
+      Row("jack@sample.net", "sample.net"))
+    // Compare the results
+    implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0))
+    assert(results.sorted.sameElements(expectedResults.sorted))
+
+    // Retrieve the logical plan
+    val logicalPlan: LogicalPlan = frame.queryExecution.logical
+    // Define the expected logical plan (expected shape: parse adds a regex-extracted column)
+    val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))
+    val fieldsProjectList = Seq(UnresolvedAttribute("email"), UnresolvedAttribute("host"))
+    val parseProjectList = Seq(
+      UnresolvedStar(None),
+      Alias(
+        org.apache.spark.sql.catalyst.expressions.RegExpExtract(
+          UnresolvedAttribute("email"),
+          Literal(".+@(?<host>.+)"),
+          Literal(1)),
+        "host")())
+    val expectedPlan = Project(fieldsProjectList, Project(parseProjectList, table))
+    // Compare the two plans
+    comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
+  }
+
+}
diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java
index fd8d81e5c..2f601e56f 100644
--- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java
+++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java
@@ -11,6 +11,7 @@
 import org.apache.spark.sql.catalyst.expressions.Expression;
 import org.apache.spark.sql.catalyst.expressions.NamedExpression;
 import org.apache.spark.sql.catalyst.expressions.Predicate;
+import org.apache.spark.sql.catalyst.expressions.RegExpExtract;
 import org.apache.spark.sql.catalyst.expressions.SortOrder;
 import org.apache.spark.sql.catalyst.plans.logical.Aggregate;
 import org.apache.spark.sql.catalyst.plans.logical.Limit;
@@ -50,6 +51,7 @@
 import org.opensearch.sql.ast.tree.Filter;
 import org.opensearch.sql.ast.tree.Head;
 import org.opensearch.sql.ast.tree.Kmeans;
+import org.opensearch.sql.ast.tree.Parse;
 import org.opensearch.sql.ast.tree.Project;
 import org.opensearch.sql.ast.tree.RareTopN;
 import org.opensearch.sql.ast.tree.Relation;
@@ -224,6 +226,33 @@ private Expression visitExpression(UnresolvedExpression expression, CatalystPlan
         return expressionAnalyzer.analyze(expression, context);
     }
 
+    @Override
+    public LogicalPlan visitParse(Parse node, CatalystPlanContext context) {
+        LogicalPlan child = node.getChild().get(0).accept(this, context);
+        List<UnresolvedExpression> aliases = new ArrayList<>();
+        switch (node.getParseMethod()) {
+            case GROK:
+                throw new IllegalStateException("Not Supported operation : GROK");
+            case PATTERNS:
+                throw new IllegalStateException("Not Supported operation : PATTERNS");
+            case REGEX:
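+                // A rough sketch of the intended REGEX branch (an assumption based on the
+                // RegExpExtract import added above, not the final implementation): compile the
+                // pattern, collect its named capture groups, and project each group as a
+                // derived column, e.g.
+                //   parse email '.+@(?<host>.+)'
+                //     -> Project(*, regexp_extract(email, '.+@(?<host>.+)', 1) AS host)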
+                //todo - implement the regex translation into a projection
+        }
+        UnresolvedExpression sourceField = node.getSourceField();
+        Literal pattern = node.getPattern();
+        Alias alias = new Alias(sourceField.toString(), sourceField);
+        aliases.add(alias);
+        if (context.getNamedParseExpressions().isEmpty()) {
+            // Create an UnresolvedStar for all-fields projection
+            context.getNamedParseExpressions().push(UnresolvedStar$.MODULE$.apply(Option.<Seq<String>>empty()));
+        }
+        List<Expression> expressionList = visitExpressionList(aliases, context);
+        Seq<NamedExpression> projectExpressions = context.retainAllNamedParseExpressions(p -> (NamedExpression) p);
+        // build the plan with the projection step
+        child = context.apply(p -> new org.apache.spark.sql.catalyst.plans.logical.Project(projectExpressions, p));
+        return child;
+    }
+
     @Override
     public LogicalPlan visitEval(Eval node, CatalystPlanContext context) {
         LogicalPlan child = node.getChild().get(0).accept(this, context);
diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java
index 71abb329f..47b364211 100644
--- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java
+++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java
@@ -26,12 +26,14 @@
 import org.opensearch.sql.ast.expression.Literal;
 import org.opensearch.sql.ast.expression.Not;
 import org.opensearch.sql.ast.expression.Or;
+import org.opensearch.sql.ast.expression.ParseMethod;
 import org.opensearch.sql.ast.expression.QualifiedName;
 import org.opensearch.sql.ast.expression.Span;
 import org.opensearch.sql.ast.expression.SpanUnit;
 import org.opensearch.sql.ast.expression.UnresolvedArgument;
 import org.opensearch.sql.ast.expression.UnresolvedExpression;
 import org.opensearch.sql.ast.expression.Xor;
+import org.opensearch.sql.ast.tree.Parse;
 import org.opensearch.sql.common.utils.StringUtils;
 import org.opensearch.sql.ppl.utils.ArgumentFactory;
diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanParseTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanParseTranslatorTestSuite.scala
new file mode 100644
index 000000000..469841f24
--- /dev/null
+++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanParseTranslatorTestSuite.scala
@@ -0,0 +1,68 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.flint.spark.ppl
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation, UnresolvedStar}
+import org.apache.spark.sql.catalyst.expressions.{Alias, Literal, NamedExpression, RegExpExtract}
+import org.apache.spark.sql.catalyst.plans.PlanTest
+import org.apache.spark.sql.catalyst.plans.logical.Project
+import org.opensearch.flint.spark.ppl.PlaneUtils.plan
+import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq
+import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor}
+import org.scalatest.matchers.should.Matchers
+
+class PPLLogicalPlanParseTranslatorTestSuite
+    extends SparkFunSuite
+    with PlanTest
+    with LogicalPlanTestUtils
+    with Matchers {
+
+  private val planTransformer = new CatalystQueryPlanVisitor()
+  private val pplParser = new PPLSyntaxParser()
+
+  test("test parse email & host expressions") {
+    val context = new CatalystPlanContext
+    val logPlan =
+      planTransformer.visit(
plan(pplParser, "source=t | parse email '.+@(?.+)' | fields email, host", false), + context) + val evalProjectList: Seq[NamedExpression] = + Seq(UnresolvedStar(None), Alias(Literal(1), "a")(), Alias(Literal(1), "b")()) + val expectedPlan = Project( + seq(UnresolvedAttribute("c")), + Project(evalProjectList, UnresolvedRelation(Seq("t")))) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test parse email expression") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan(pplParser, "source=t | parse email '.+@(?.+)' | fields email", false), + context) + val evalProjectList: Seq[NamedExpression] = + Seq(UnresolvedStar(None), Alias(Literal(1), "a")(), Alias(Literal(1), "b")()) + val expectedPlan = Project( + seq(UnresolvedAttribute("c")), + Project(evalProjectList, UnresolvedRelation(Seq("t")))) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test parse email expression, generate new host field and eval result") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan(pplParser, "source=t | parse email '.+@(?.+)' | eval eval_result=1 | fields host, eval_result", false), + context) + val evalProjectList: Seq[NamedExpression] = + Seq(UnresolvedStar(None), Alias(Literal(1), "a")(), Alias(Literal(1), "b")()) + val expectedPlan = Project( + seq(UnresolvedAttribute("c")), + Project(evalProjectList, UnresolvedRelation(Seq("t")))) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } +} From c54a47198bf8d6cee5ab5b69ab6e6311a8f44a25 Mon Sep 17 00:00:00 2001 From: YANGDB Date: Tue, 20 Aug 2024 14:49:46 -0700 Subject: [PATCH 2/9] add parse code & classes Signed-off-by: YANGDB Signed-off-by: YANGDB --- .../sql/expression/parse/GrokExpression.java | 73 +++++++++++++++++ .../sql/expression/parse/ParseExpression.java | 75 ++++++++++++++++++ .../expression/parse/PatternsExpression.java | 79 +++++++++++++++++++ .../sql/expression/parse/RegexExpression.java | 65 +++++++++++++++ .../sql/ppl/CatalystPlanContext.java | 13 +++ .../sql/ppl/CatalystQueryPlanVisitor.java | 47 ++++++----- .../opensearch/sql/ppl/utils/ParseUtils.java | 68 ++++++++++++++++ 7 files changed, 402 insertions(+), 18 deletions(-) create mode 100644 ppl-spark-integration/src/main/java/org/opensearch/sql/expression/parse/GrokExpression.java create mode 100644 ppl-spark-integration/src/main/java/org/opensearch/sql/expression/parse/ParseExpression.java create mode 100644 ppl-spark-integration/src/main/java/org/opensearch/sql/expression/parse/PatternsExpression.java create mode 100644 ppl-spark-integration/src/main/java/org/opensearch/sql/expression/parse/RegexExpression.java create mode 100644 ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/ParseUtils.java diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/parse/GrokExpression.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/parse/GrokExpression.java new file mode 100644 index 000000000..4cf79cbae --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/parse/GrokExpression.java @@ -0,0 +1,73 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.expression.parse; + +import lombok.EqualsAndHashCode; +import lombok.ToString; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.sql.common.grok.Grok; +import org.opensearch.sql.common.grok.GrokCompiler; 
+import org.opensearch.sql.common.grok.Match;
+import org.opensearch.sql.data.model.ExprStringValue;
+import org.opensearch.sql.data.model.ExprValue;
+import org.opensearch.sql.exception.ExpressionEvaluationException;
+import org.opensearch.sql.expression.Expression;
+
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
+
+/** GrokExpression with grok patterns. */
+public class GrokExpression extends ParseExpression {
+  private static final Logger log = LogManager.getLogger(GrokExpression.class);
+  private static final GrokCompiler grokCompiler = GrokCompiler.newInstance();
+
+  static {
+    grokCompiler.registerDefaultPatterns();
+  }
+
+  @EqualsAndHashCode.Exclude private final Grok grok;
+
+  /**
+   * GrokExpression.
+   *
+   * @param sourceField source text field
+   * @param pattern pattern used for parsing
+   * @param identifier derived field
+   */
+  public GrokExpression(Expression sourceField, Expression pattern, Expression identifier) {
+    super("grok", sourceField, pattern, identifier);
+    this.grok = grokCompiler.compile(pattern.valueOf().stringValue());
+  }
+
+  @Override
+  ExprValue parseValue(ExprValue value) throws ExpressionEvaluationException {
+    String rawString = value.stringValue();
+    Match grokMatch = grok.match(rawString);
+    Map<String, Object> capture = grokMatch.capture();
+    Object match = capture.get(identifierStr);
+    if (match != null) {
+      return new ExprStringValue(match.toString());
+    }
+    log.debug("failed to extract pattern {} from input ***", grok.getOriginalGrokPattern());
+    return new ExprStringValue("");
+  }
+
+  /**
+   * Get list of derived fields based on parse pattern.
+   *
+   * @param pattern pattern used for parsing
+   * @return list of names of the derived fields
+   */
+  public static List<String> getNamedGroupCandidates(String pattern) {
+    Grok grok = grokCompiler.compile(pattern);
+    return grok.namedGroups.stream()
+        .map(grok::getNamedRegexCollectionById)
+        .filter(group -> !group.equals("UNWANTED"))
+        .collect(Collectors.toUnmodifiableList());
+  }
+}
diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/parse/ParseExpression.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/parse/ParseExpression.java
new file mode 100644
index 000000000..6e2456ecc
--- /dev/null
+++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/parse/ParseExpression.java
@@ -0,0 +1,75 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.expression.parse;
+
+import com.google.common.collect.ImmutableList;
+import lombok.EqualsAndHashCode;
+import lombok.Getter;
+import lombok.ToString;
+import org.opensearch.sql.data.model.ExprValue;
+import org.opensearch.sql.data.model.ExprValueUtils;
+import org.opensearch.sql.data.type.ExprCoreType;
+import org.opensearch.sql.data.type.ExprType;
+import org.opensearch.sql.exception.ExpressionEvaluationException;
+import org.opensearch.sql.exception.SemanticCheckException;
+import org.opensearch.sql.expression.Expression;
+import org.opensearch.sql.expression.ExpressionNodeVisitor;
+import org.opensearch.sql.expression.FunctionExpression;
+import org.opensearch.sql.expression.env.Environment;
+import org.opensearch.sql.expression.function.FunctionName;
+
+/** ParseExpression.
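+ * <p>Abstract base for the parse-family expressions (regex, grok, patterns): it resolves the
+ * source field, short-circuits null/missing values, delegates the actual extraction to {@link
+ * #parseValue}, and always yields a STRING result.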
+ */
+@EqualsAndHashCode(callSuper = false)
+@ToString
+public abstract class ParseExpression extends FunctionExpression {
+  @Getter protected final Expression sourceField;
+  protected final Expression pattern;
+  @Getter protected final Expression identifier;
+  protected final String identifierStr;
+
+  /**
+   * ParseExpression.
+   *
+   * @param functionName name of function expression
+   * @param sourceField source text field
+   * @param pattern pattern used for parsing
+   * @param identifier derived field
+   */
+  public ParseExpression(
+      String functionName, Expression sourceField, Expression pattern, Expression identifier) {
+    super(FunctionName.of(functionName), ImmutableList.of(sourceField, pattern, identifier));
+    this.sourceField = sourceField;
+    this.pattern = pattern;
+    this.identifier = identifier;
+    this.identifierStr = identifier.valueOf().stringValue();
+  }
+
+  @Override
+  public ExprValue valueOf(Environment<Expression, ExprValue> valueEnv) {
+    ExprValue value = valueEnv.resolve(sourceField);
+    if (value.isNull() || value.isMissing()) {
+      return ExprValueUtils.nullValue();
+    }
+    try {
+      return parseValue(value);
+    } catch (ExpressionEvaluationException e) {
+      throw new SemanticCheckException(
+          String.format("failed to parse field \"%s\" with type [%s]", sourceField, value.type()));
+    }
+  }
+
+  @Override
+  public ExprType type() {
+    return ExprCoreType.STRING;
+  }
+
+  @Override
+  public <T, C> T accept(ExpressionNodeVisitor<T, C> visitor, C context) {
+    return visitor.visitParse(this, context);
+  }
+
+  abstract ExprValue parseValue(ExprValue value) throws ExpressionEvaluationException;
+}
diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/parse/PatternsExpression.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/parse/PatternsExpression.java
new file mode 100644
index 000000000..fedeeef7e
--- /dev/null
+++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/parse/PatternsExpression.java
@@ -0,0 +1,79 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.expression.parse;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableSet;
+import lombok.EqualsAndHashCode;
+import lombok.ToString;
+import org.opensearch.sql.data.model.ExprStringValue;
+import org.opensearch.sql.data.model.ExprValue;
+import org.opensearch.sql.exception.ExpressionEvaluationException;
+import org.opensearch.sql.expression.Expression;
+
+import java.util.List;
+import java.util.Objects;
+import java.util.regex.Pattern;
+
+/** PatternsExpression with regex filter. */
+@EqualsAndHashCode(callSuper = true)
+@ToString
+public class PatternsExpression extends ParseExpression {
+  /** Default name of the derived field. */
+  public static final String DEFAULT_NEW_FIELD = "patterns_field";
+
+  private static final ImmutableSet<Character> DEFAULT_IGNORED_CHARS =
+      ImmutableSet.copyOf(
+          "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
+              .chars()
+              .mapToObj(c -> (char) c)
+              .toArray(Character[]::new));
+  private final boolean useCustomPattern;
+  @EqualsAndHashCode.Exclude private Pattern pattern;
+
+  /**
+   * PatternsExpression.
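+   * <p>An empty pattern selects the default behavior: all letters and digits are stripped so only
+   * the punctuation "pattern" of the text remains; a non-empty pattern is compiled as a regex and
+   * every match is removed from the input.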
+   *
+   * @param sourceField source text field
+   * @param pattern pattern used for parsing
+   * @param identifier derived field
+   */
+  public PatternsExpression(Expression sourceField, Expression pattern, Expression identifier) {
+    super("patterns", sourceField, pattern, identifier);
+    String patternStr = pattern.valueOf().stringValue();
+    useCustomPattern = !patternStr.isEmpty();
+    if (useCustomPattern) {
+      this.pattern = Pattern.compile(patternStr);
+    }
+  }
+
+  @Override
+  ExprValue parseValue(ExprValue value) throws ExpressionEvaluationException {
+    String rawString = value.stringValue();
+    if (useCustomPattern) {
+      return new ExprStringValue(pattern.matcher(rawString).replaceAll(""));
+    }
+
+    char[] chars = rawString.toCharArray();
+    int pos = 0;
+    for (int i = 0; i < chars.length; i++) {
+      if (!DEFAULT_IGNORED_CHARS.contains(chars[i])) {
+        chars[pos++] = chars[i];
+      }
+    }
+    return new ExprStringValue(new String(chars, 0, pos));
+  }
+
+  /**
+   * Get list of derived fields.
+   *
+   * @param identifier identifier used to generate the field name
+   * @return list of names of the derived fields
+   */
+  public static List<String> getNamedGroupCandidates(String identifier) {
+    return ImmutableList.of(Objects.requireNonNullElse(identifier, DEFAULT_NEW_FIELD));
+  }
+}
diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/parse/RegexExpression.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/parse/RegexExpression.java
new file mode 100644
index 000000000..aab51cd4b
--- /dev/null
+++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/parse/RegexExpression.java
@@ -0,0 +1,65 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.expression.parse;
+
+import com.google.common.collect.ImmutableList;
+import lombok.EqualsAndHashCode;
+import lombok.Getter;
+import lombok.ToString;
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.opensearch.sql.data.model.ExprStringValue;
+import org.opensearch.sql.data.model.ExprValue;
+import org.opensearch.sql.exception.ExpressionEvaluationException;
+import org.opensearch.sql.expression.Expression;
+
+import java.util.List;
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
+
+/** RegexExpression with regex and named capture group. */
+public class RegexExpression extends ParseExpression {
+  private static final Logger log = LogManager.getLogger(RegexExpression.class);
+  private static final Pattern GROUP_PATTERN = Pattern.compile("\\(\\?<([a-zA-Z][a-zA-Z0-9]*)>");
+  private final Pattern regexPattern;
+
+  /**
+   * RegexExpression.
+   *
+   * @param sourceField source text field
+   * @param pattern pattern used for parsing
+   * @param identifier derived field
+   */
+  public RegexExpression(Expression sourceField, Expression pattern, Expression identifier) {
+    super("regex", sourceField, pattern, identifier);
+    this.regexPattern = Pattern.compile(pattern.valueOf().stringValue());
+  }
+
+  @Override
+  ExprValue parseValue(ExprValue value) throws ExpressionEvaluationException {
+    String rawString = value.stringValue();
+    Matcher matcher = regexPattern.matcher(rawString);
+    if (matcher.matches()) {
+      return new ExprStringValue(matcher.group(identifierStr));
+    }
+    log.debug("failed to extract pattern {} from input ***", regexPattern.pattern());
+    return new ExprStringValue("");
+  }
+
+  /**
+   * Get list of derived fields based on parse pattern.
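+   * <p>For example, the pattern ".+@(?<host>.+)" yields a single candidate: ["host"].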
+   *
+   * @param pattern pattern used for parsing
+   * @return list of names of the derived fields
+   */
+  public static List<String> getNamedGroupCandidates(String pattern) {
+    ImmutableList.Builder<String> namedGroups = ImmutableList.builder();
+    Matcher m = GROUP_PATTERN.matcher(pattern);
+    while (m.find()) {
+      namedGroups.add(m.group(1));
+    }
+    return namedGroups.build();
+  }
+}
diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystPlanContext.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystPlanContext.java
index 66ed765a3..ad7e615e2 100644
--- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystPlanContext.java
+++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystPlanContext.java
@@ -6,9 +6,13 @@
 package org.opensearch.sql.ppl;
 
 import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation;
+import org.apache.spark.sql.catalyst.expressions.AttributeReference;
 import org.apache.spark.sql.catalyst.expressions.Expression;
+import org.apache.spark.sql.catalyst.expressions.NamedExpression;
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan;
 import org.apache.spark.sql.catalyst.plans.logical.Union;
+import org.apache.spark.sql.types.Metadata;
+import org.opensearch.sql.data.type.ExprType;
 import scala.collection.Iterator;
 import scala.collection.Seq;
 
@@ -81,6 +85,15 @@ public Stack<Expression> getGroupingParseExpressions() {
         return groupingParseExpressions;
     }
 
+    /**
+     * define a new field symbol on the named-expressions stack
+     *
+     * @param symbol the expression that names the new field
+     * @return the current logical plan
+     */
+    public LogicalPlan define(Expression symbol) {
+        namedParseExpressions.push(symbol);
+        return getPlan();
+    }
+
     /**
      * append plan with evolving plans branches
      *
diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java
index f361f53d1..881051f15 100644
--- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java
+++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java
@@ -9,21 +9,20 @@
 import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute$;
 import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation;
 import org.apache.spark.sql.catalyst.analysis.UnresolvedStar$;
+import org.apache.spark.sql.catalyst.expressions.AttributeReference;
 import org.apache.spark.sql.catalyst.expressions.Expression;
 import org.apache.spark.sql.catalyst.expressions.NamedExpression;
 import org.apache.spark.sql.catalyst.expressions.Predicate;
-import org.apache.spark.sql.catalyst.expressions.RegExpExtract;
 import org.apache.spark.sql.catalyst.expressions.SortOrder;
 import org.apache.spark.sql.catalyst.plans.logical.Aggregate;
 import org.apache.spark.sql.catalyst.plans.logical.DescribeRelation$;
 import org.apache.spark.sql.catalyst.plans.logical.Deduplicate;
-import org.apache.spark.sql.catalyst.plans.logical.DescribeRelation$;
 import org.apache.spark.sql.catalyst.plans.logical.Limit;
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan;
 import org.apache.spark.sql.execution.command.DescribeTableCommand;
 import org.apache.spark.sql.catalyst.plans.logical.Union;
-import org.apache.spark.sql.execution.command.DescribeTableCommand;
 import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.types.Metadata;
 import org.apache.spark.sql.util.CaseInsensitiveStringMap;
 import org.opensearch.sql.ast.AbstractNodeVisitor;
 import org.opensearch.sql.ast.expression.AggregateFunction;
@@ -43,6 +42,7 @@
 import org.opensearch.sql.ast.expression.Literal;
 import org.opensearch.sql.ast.expression.Not;
 import org.opensearch.sql.ast.expression.Or;
+import org.opensearch.sql.ast.expression.ParseMethod;
 import org.opensearch.sql.ast.expression.QualifiedName;
 import org.opensearch.sql.ast.expression.Span;
 import org.opensearch.sql.ast.expression.UnresolvedExpression;
@@ -64,9 +64,11 @@
 import org.opensearch.sql.ast.tree.RareTopN;
 import org.opensearch.sql.ast.tree.Relation;
 import org.opensearch.sql.ast.tree.Sort;
+import org.opensearch.sql.expression.parse.ParseExpression;
 import org.opensearch.sql.ppl.utils.AggregatorTranslator;
 import org.opensearch.sql.ppl.utils.BuiltinFunctionTranslator;
 import org.opensearch.sql.ppl.utils.ComparatorTransformer;
+import org.opensearch.sql.ppl.utils.ParseUtils;
 import org.opensearch.sql.ppl.utils.SortUtils;
 import scala.Option;
 import scala.Option$;
@@ -74,6 +76,7 @@
 import java.util.ArrayList;
 import java.util.List;
+import java.util.Map;
 import java.util.Objects;
 import java.util.Optional;
 import java.util.function.BiFunction;
@@ -258,6 +261,11 @@ private Expression visitExpression(UnresolvedExpression expression, CatalystPlan
     @Override
     public LogicalPlan visitParse(Parse node, CatalystPlanContext context) {
         LogicalPlan child = node.getChild().get(0).accept(this, context);
+        Expression sourceField = visitExpression(node.getSourceField(), context);
+        ParseMethod parseMethod = node.getParseMethod();
+        Map<String, Literal> arguments = node.getArguments();
+        String pattern = (String) node.getPattern().getValue();
+
         List<UnresolvedExpression> aliases = new ArrayList<>();
         switch (node.getParseMethod()) {
             case GROK:
@@ -265,21 +273,24 @@ public LogicalPlan visitParse(Parse node, CatalystPlanContext context) {
             case PATTERNS:
                 throw new IllegalStateException("Not Supported operation : PATTERNS");
             case REGEX:
-                //todo - implement the regex translation into a projection
-        }
-        UnresolvedExpression sourceField = node.getSourceField();
-        Literal pattern = node.getPattern();
-        Alias alias = new Alias(sourceField.toString(), sourceField);
-        aliases.add(alias);
-        if (context.getNamedParseExpressions().isEmpty()) {
-            // Create an UnresolvedStar for all-fields projection
-            context.getNamedParseExpressions().push(UnresolvedStar$.MODULE$.apply(Option.<Seq<String>>empty()));
-        }
-        List<Expression> expressionList = visitExpressionList(aliases, context);
-        Seq<NamedExpression> projectExpressions = context.retainAllNamedParseExpressions(p -> (NamedExpression) p);
-        // build the plan with the projection step
-        child = context.apply(p -> new org.apache.spark.sql.catalyst.plans.logical.Project(projectExpressions, p));
-        return child;
+                return visitParseCommand(node, sourceField, parseMethod, arguments, pattern, context);
+            default:
+                throw new IllegalArgumentException("Invalid parse command name: " + node.getParseMethod()
+                        + " Syntax: [parse <sourceField> <pattern>]");
+        }
+    }
+
+    private LogicalPlan visitParseCommand(Parse node, Expression sourceField, ParseMethod parseMethod,
+                                          Map<String, Literal> arguments, String pattern, CatalystPlanContext context) {
+        ParseUtils.getNamedGroupCandidates(parseMethod, pattern, arguments)
+                .forEach(
+                        group -> {
+                            ParseExpression expr =
+                                    ParseUtils.createParseExpression(parseMethod, sourceField, pattern, group);
+                            // derived parse fields are projected as nullable string attributes
+                            context.define(new AttributeReference(group, DataTypes.StringType, true,
+                                    Metadata.empty(), NamedExpression.newExprId(), seq(emptyList())));
+                        });
+        return context.getPlan();
+    }
 
     @Override
diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/ParseUtils.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/ParseUtils.java
new file mode 100644
index 000000000..b907050a4
--- /dev/null
+++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/ParseUtils.java
@@ -0,0 +1,68 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.ppl.utils;
+
+import com.google.common.collect.ImmutableMap;
+import org.opensearch.sql.ast.expression.Literal;
+import org.opensearch.sql.ast.expression.ParseMethod;
+import org.opensearch.sql.expression.Expression;
+import org.opensearch.sql.expression.parse.GrokExpression;
+import org.opensearch.sql.expression.parse.ParseExpression;
+import org.opensearch.sql.expression.parse.PatternsExpression;
+import org.opensearch.sql.expression.parse.RegexExpression;
+
+import java.util.List;
+import java.util.Map;
+
+/** Utils for {@link ParseExpression}. */
+public class ParseUtils {
+  private static final String NEW_FIELD_KEY = "new_field";
+  private static final Map<ParseMethod, ParseExpressionFactory> FACTORY_MAP =
+      ImmutableMap.of(
+          ParseMethod.REGEX, RegexExpression::new,
+          ParseMethod.GROK, GrokExpression::new,
+          ParseMethod.PATTERNS, PatternsExpression::new);
+
+  /**
+   * Construct corresponding ParseExpression by {@link ParseMethod}.
+   *
+   * @param parseMethod method used to parse
+   * @param sourceField source text field
+   * @param pattern pattern used for parsing
+   * @param identifier derived field
+   * @return {@link ParseExpression}
+   */
+  public static ParseExpression createParseExpression(
+      ParseMethod parseMethod, Expression sourceField, Expression pattern, Expression identifier) {
+    return FACTORY_MAP.get(parseMethod).initialize(sourceField, pattern, identifier);
+  }
+
+  /**
+   * Get list of derived fields based on parse pattern.
+   *
+   * @param pattern pattern used for parsing
+   * @return list of names of the derived fields
+   */
+  public static List<String> getNamedGroupCandidates(
+      ParseMethod parseMethod, String pattern, Map<String, Literal> arguments) {
+    switch (parseMethod) {
+      case REGEX:
+        return RegexExpression.getNamedGroupCandidates(pattern);
+      case GROK:
+        return GrokExpression.getNamedGroupCandidates(pattern);
+      default:
+        return PatternsExpression.getNamedGroupCandidates(
+            arguments.containsKey(NEW_FIELD_KEY) ?
(String) arguments.get(NEW_FIELD_KEY).getValue() + : null); + } + } + + private interface ParseExpressionFactory { + ParseExpression initialize( + Expression sourceField, Expression expression, Expression identifier); + } +} From 1b6bebea8933c86d6c035a2a5432eeae2fe79d92 Mon Sep 17 00:00:00 2001 From: YANGDB Date: Wed, 21 Aug 2024 12:00:12 -0700 Subject: [PATCH 3/9] add parse / grok / patterns command Signed-off-by: YANGDB Signed-off-by: YANGDB --- .../opensearch/sql/common/grok/Converter.java | 165 ++++++++++++ .../org/opensearch/sql/common/grok/Grok.java | 171 +++++++++++++ .../sql/common/grok/GrokCompiler.java | 199 +++++++++++++++ .../opensearch/sql/common/grok/GrokUtils.java | 59 +++++ .../org/opensearch/sql/common/grok/Match.java | 241 ++++++++++++++++++ .../common/grok/exception/GrokException.java | 50 ++++ .../sql/expression/parse/GrokExpression.java | 73 ------ .../sql/expression/parse/ParseExpression.java | 75 ------ .../expression/parse/PatternsExpression.java | 79 ------ .../sql/expression/parse/RegexExpression.java | 65 ----- .../sql/ppl/CatalystQueryPlanVisitor.java | 20 +- .../opensearch/sql/ppl/utils/ParseUtils.java | 167 ++++++++++-- 12 files changed, 1044 insertions(+), 320 deletions(-) create mode 100644 ppl-spark-integration/src/main/java/org/opensearch/sql/common/grok/Converter.java create mode 100644 ppl-spark-integration/src/main/java/org/opensearch/sql/common/grok/Grok.java create mode 100644 ppl-spark-integration/src/main/java/org/opensearch/sql/common/grok/GrokCompiler.java create mode 100644 ppl-spark-integration/src/main/java/org/opensearch/sql/common/grok/GrokUtils.java create mode 100644 ppl-spark-integration/src/main/java/org/opensearch/sql/common/grok/Match.java create mode 100644 ppl-spark-integration/src/main/java/org/opensearch/sql/common/grok/exception/GrokException.java delete mode 100644 ppl-spark-integration/src/main/java/org/opensearch/sql/expression/parse/GrokExpression.java delete mode 100644 ppl-spark-integration/src/main/java/org/opensearch/sql/expression/parse/ParseExpression.java delete mode 100644 ppl-spark-integration/src/main/java/org/opensearch/sql/expression/parse/PatternsExpression.java delete mode 100644 ppl-spark-integration/src/main/java/org/opensearch/sql/expression/parse/RegexExpression.java diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/common/grok/Converter.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/common/grok/Converter.java new file mode 100644 index 000000000..ddd3a2bbb --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/common/grok/Converter.java @@ -0,0 +1,165 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.common.grok; + +import java.time.Instant; +import java.time.LocalDate; +import java.time.LocalDateTime; +import java.time.OffsetDateTime; +import java.time.ZoneId; +import java.time.ZoneOffset; +import java.time.ZonedDateTime; +import java.time.format.DateTimeFormatter; +import java.time.temporal.TemporalAccessor; +import java.util.AbstractMap; +import java.util.Arrays; +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.function.Function; +import java.util.regex.Pattern; +import java.util.stream.Collectors; + +/** Convert String argument to the right type. 
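+ * <p>Grok patterns reference these types through a conversion suffix, e.g. %{NUMBER:duration:int}
+ * routes the "duration" capture through the INT converter ("integer" is accepted as an alias).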
*/ +public class Converter { + + public enum Type { + BYTE(Byte::valueOf), + BOOLEAN(Boolean::valueOf), + SHORT(Short::valueOf), + INT(Integer::valueOf, "integer"), + LONG(Long::valueOf), + FLOAT(Float::valueOf), + DOUBLE(Double::valueOf), + DATETIME(new DateConverter(), "date"), + STRING(v -> v, "text"); + + public final IConverter converter; + public final List aliases; + + Type(IConverter converter, String... aliases) { + this.converter = converter; + this.aliases = Arrays.asList(aliases); + } + } + + private static final Pattern SPLITTER = Pattern.compile("[:;]"); + + private static final Map TYPES = + Arrays.stream(Type.values()).collect(Collectors.toMap(t -> t.name().toLowerCase(), t -> t)); + + private static final Map TYPE_ALIASES = + Arrays.stream(Type.values()) + .flatMap( + type -> + type.aliases.stream().map(alias -> new AbstractMap.SimpleEntry<>(alias, type))) + .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + + private static Type getType(String key) { + key = key.toLowerCase(); + Type type = TYPES.getOrDefault(key, TYPE_ALIASES.get(key)); + if (type == null) { + throw new IllegalArgumentException("Invalid data type :" + key); + } + return type; + } + + /** getConverters. */ + public static Map> getConverters( + Collection groupNames, Object... params) { + return groupNames.stream() + .filter(Converter::containsDelimiter) + .collect( + Collectors.toMap( + Function.identity(), + key -> { + String[] list = splitGrokPattern(key); + IConverter converter = getType(list[1]).converter; + if (list.length == 3) { + converter = converter.newConverter(list[2], params); + } + return converter; + })); + } + + /** getGroupTypes. */ + public static Map getGroupTypes(Collection groupNames) { + return groupNames.stream() + .filter(Converter::containsDelimiter) + .map(Converter::splitGrokPattern) + .collect(Collectors.toMap(l -> l[0], l -> getType(l[1]))); + } + + public static String extractKey(String key) { + return splitGrokPattern(key)[0]; + } + + private static boolean containsDelimiter(String string) { + return string.indexOf(':') >= 0 || string.indexOf(';') >= 0; + } + + private static String[] splitGrokPattern(String string) { + return SPLITTER.split(string, 3); + } + + interface IConverter { + + T convert(String value); + + default IConverter newConverter(String param, Object... 
params) { + return this; + } + } + + static class DateConverter implements IConverter { + + private final DateTimeFormatter formatter; + private final ZoneId timeZone; + + public DateConverter() { + this.formatter = DateTimeFormatter.ISO_DATE_TIME; + this.timeZone = ZoneOffset.UTC; + } + + private DateConverter(DateTimeFormatter formatter, ZoneId timeZone) { + this.formatter = formatter; + this.timeZone = timeZone; + } + + @Override + public Instant convert(String value) { + TemporalAccessor dt = + formatter.parseBest( + value.trim(), + ZonedDateTime::from, + LocalDateTime::from, + OffsetDateTime::from, + Instant::from, + LocalDate::from); + if (dt instanceof ZonedDateTime) { + return ((ZonedDateTime) dt).toInstant(); + } else if (dt instanceof LocalDateTime) { + return ((LocalDateTime) dt).atZone(timeZone).toInstant(); + } else if (dt instanceof OffsetDateTime) { + return ((OffsetDateTime) dt).atZoneSameInstant(timeZone).toInstant(); + } else if (dt instanceof Instant) { + return ((Instant) dt); + } else if (dt instanceof LocalDate) { + return ((LocalDate) dt).atStartOfDay(timeZone).toInstant(); + } else { + return null; + } + } + + @Override + public DateConverter newConverter(String param, Object... params) { + if (!(params.length == 1 && params[0] instanceof ZoneId)) { + throw new IllegalArgumentException("Invalid parameters"); + } + return new DateConverter(DateTimeFormatter.ofPattern(param), (ZoneId) params[0]); + } + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/common/grok/Grok.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/common/grok/Grok.java new file mode 100644 index 000000000..e0c37af99 --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/common/grok/Grok.java @@ -0,0 +1,171 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.common.grok; + +import org.opensearch.sql.common.grok.Converter.IConverter; + +import java.io.Serializable; +import java.time.ZoneId; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +/** + * {@code Grok} parse arbitrary text and structure it.
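+ * <p>Illustrative use (pattern and input are examples only): {@code
+ * grokCompiler.compile("%{IP:client}").match("10.0.0.1").capture()} returns a map containing
+ * {client=10.0.0.1}.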
+ * {@code Grok} is simple API that allows you to easily parse logs and other files (single line). + * With {@code Grok}, you can turn unstructured log and event data into structured data. + * + * @since 0.0.1 + */ +public class Grok implements Serializable { + /** Named regex of the originalGrokPattern. */ + private final String namedRegex; + + /** + * Map of the named regex of the originalGrokPattern with id = namedregexid and value = + * namedregex. + */ + private final Map namedRegexCollection; + + /** Original {@code Grok} pattern (expl: %{IP}). */ + private final String originalGrokPattern; + + /** Pattern of the namedRegex. */ + private final Pattern compiledNamedRegex; + + /** {@code Grok} patterns definition. */ + private final Map grokPatternDefinition; + + public final Set namedGroups; + + public final Map groupTypes; + + public final Map> converters; + + /** only use in grok discovery. */ + private String savedPattern = ""; + + /** Grok. */ + public Grok( + String pattern, + String namedRegex, + Map namedRegexCollection, + Map patternDefinitions, + ZoneId defaultTimeZone) { + this.originalGrokPattern = pattern; + this.namedRegex = namedRegex; + this.compiledNamedRegex = Pattern.compile(namedRegex); + this.namedRegexCollection = namedRegexCollection; + this.namedGroups = GrokUtils.getNameGroups(namedRegex); + this.groupTypes = Converter.getGroupTypes(namedRegexCollection.values()); + this.converters = Converter.getConverters(namedRegexCollection.values(), defaultTimeZone); + this.grokPatternDefinition = patternDefinitions; + } + + public String getSaved_pattern() { + return savedPattern; + } + + public void setSaved_pattern(String savedpattern) { + this.savedPattern = savedpattern; + } + + /** + * Get the current map of {@code Grok} pattern. + * + * @return Patterns (name, regular expression) + */ + public Map getPatterns() { + return grokPatternDefinition; + } + + /** + * Get the named regex from the {@code Grok} pattern.
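+   * <p>The named regex is the original pattern with every %{...} reference expanded into a plain
+   * Java (?<nameN>...) group, as produced by {@link GrokCompiler#compile}.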
+ * + * @return named regex + */ + public String getNamedRegex() { + return namedRegex; + } + + /** + * Original grok pattern used to compile to the named regex. + * + * @return String Original Grok pattern + */ + public String getOriginalGrokPattern() { + return originalGrokPattern; + } + + /** + * Get the named regex from the given id. + * + * @param id : named regex id + * @return String of the named regex + */ + public String getNamedRegexCollectionById(String id) { + return namedRegexCollection.get(id); + } + + /** + * Get the full collection of the named regex. + * + * @return named RegexCollection + */ + public Map getNamedRegexCollection() { + return namedRegexCollection; + } + + /** + * Match the given log with the named regex. And return the json representation of the + * matched element + * + * @param log : log to match + * @return map containing matches + */ + public Map capture(String log) { + Match match = match(log); + return match.capture(); + } + + /** + * Match the given list of log with the named regex and return the list of json + * representation of the matched elements. + * + * @param logs : list of log + * @return list of maps containing matches + */ + public ArrayList> capture(List logs) { + final ArrayList> matched = new ArrayList<>(); + for (String log : logs) { + matched.add(capture(log)); + } + return matched; + } + + /** + * Match the given text with the named regex {@code Grok} will extract data from the + * string and get an extence of {@link Match}. + * + * @param text : Single line of log + * @return Grok Match + */ + public Match match(CharSequence text) { + if (compiledNamedRegex == null || text == null) { + return Match.EMPTY; + } + + Matcher matcher = compiledNamedRegex.matcher(text); + if (matcher.find()) { + return new Match(text, this, matcher, matcher.start(0), matcher.end(0)); + } + + return Match.EMPTY; + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/common/grok/GrokCompiler.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/common/grok/GrokCompiler.java new file mode 100644 index 000000000..7d51038cd --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/common/grok/GrokCompiler.java @@ -0,0 +1,199 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.common.grok; + +import org.apache.commons.lang3.StringUtils; +import org.opensearch.sql.common.grok.exception.GrokException; + +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.io.Reader; +import java.io.Serializable; +import java.nio.charset.Charset; +import java.nio.charset.StandardCharsets; +import java.time.ZoneId; +import java.time.ZoneOffset; +import java.util.HashMap; +import java.util.Map; +import java.util.Objects; +import java.util.Set; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +import static java.lang.String.format; + +public class GrokCompiler implements Serializable { + + // We don't want \n and commented line + private static final Pattern patternLinePattern = Pattern.compile("^([A-z0-9_]+)\\s+(.*)$"); + + /** {@code Grok} patterns definitions. */ + private final Map grokPatternDefinitions = new HashMap<>(); + + private GrokCompiler() {} + + public static GrokCompiler newInstance() { + return new GrokCompiler(); + } + + public Map getPatternDefinitions() { + return grokPatternDefinitions; + } + + /** + * Registers a new pattern definition. 
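+   * <p>For example, {@code register("WORD", "\\b\\w+\\b")} (an illustrative definition) makes
+   * %{WORD} available to patterns compiled afterwards.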
+ * + * @param name : Pattern Name + * @param pattern : Regular expression Or {@code Grok} pattern + * @throws GrokException runtime expt + */ + public void register(String name, String pattern) { + name = Objects.requireNonNull(name).trim(); + pattern = Objects.requireNonNull(pattern).trim(); + + if (!name.isEmpty() && !pattern.isEmpty()) { + grokPatternDefinitions.put(name, pattern); + } + } + + /** Registers multiple pattern definitions. */ + public void register(Map patternDefinitions) { + Objects.requireNonNull(patternDefinitions); + patternDefinitions.forEach(this::register); + } + + /** + * Registers multiple pattern definitions from a given inputStream, and decoded as a UTF-8 source. + */ + public void register(InputStream input) throws IOException { + register(input, StandardCharsets.UTF_8); + } + + /** Registers multiple pattern definitions from a given inputStream. */ + public void register(InputStream input, Charset charset) throws IOException { + try (BufferedReader in = new BufferedReader(new InputStreamReader(input, charset))) { + in.lines() + .map(patternLinePattern::matcher) + .filter(Matcher::matches) + .forEach(m -> register(m.group(1), m.group(2))); + } + } + + /** Registers multiple pattern definitions from a given Reader. */ + public void register(Reader input) throws IOException { + new BufferedReader(input) + .lines() + .map(patternLinePattern::matcher) + .filter(Matcher::matches) + .forEach(m -> register(m.group(1), m.group(2))); + } + + public void registerDefaultPatterns() { + registerPatternFromClasspath("/patterns/patterns"); + } + + public void registerPatternFromClasspath(String path) throws GrokException { + registerPatternFromClasspath(path, StandardCharsets.UTF_8); + } + + /** registerPatternFromClasspath. */ + public void registerPatternFromClasspath(String path, Charset charset) throws GrokException { + final InputStream inputStream = this.getClass().getResourceAsStream(path); + try (Reader reader = new InputStreamReader(inputStream, charset)) { + register(reader); + } catch (IOException e) { + throw new GrokException(e.getMessage(), e); + } + } + + /** Compiles a given Grok pattern and returns a Grok object which can parse the pattern. */ + public Grok compile(String pattern) throws IllegalArgumentException { + return compile(pattern, false); + } + + public Grok compile(final String pattern, boolean namedOnly) throws IllegalArgumentException { + return compile(pattern, ZoneOffset.systemDefault(), namedOnly); + } + + /** + * Compiles a given Grok pattern and returns a Grok object which can parse the pattern. + * + * @param pattern : Grok pattern (ex: %{IP}) + * @param defaultTimeZone : time zone used to parse a timestamp when it doesn't contain the time + * zone + * @param namedOnly : Whether to capture named expressions only or not (i.e. 
+   *     %{IP:ip} but not %{IP})
+   * @return a compiled pattern
+   * @throws IllegalArgumentException when pattern definition is invalid
+   */
+  public Grok compile(final String pattern, ZoneId defaultTimeZone, boolean namedOnly)
+      throws IllegalArgumentException {
+
+    if (StringUtils.isBlank(pattern)) {
+      throw new IllegalArgumentException("{pattern} should not be empty or null");
+    }
+
+    String namedRegex = pattern;
+    int index = 0;
+    // flag for infinite recursion
+    int iterationLeft = 1000;
+    Boolean continueIteration = true;
+    Map<String, String> patternDefinitions = new HashMap<>(grokPatternDefinitions);
+
+    // output
+    Map<String, String> namedRegexCollection = new HashMap<>();
+
+    // Replace %{foo} with the regex (mostly group name regex)
+    // and then compile the regex
+    while (continueIteration) {
+      continueIteration = false;
+      if (iterationLeft <= 0) {
+        throw new IllegalArgumentException("Deep recursion pattern compilation of " + pattern);
+      }
+      iterationLeft--;
+
+      Set<String> namedGroups = GrokUtils.getNameGroups(GrokUtils.GROK_PATTERN.pattern());
+      Matcher matcher = GrokUtils.GROK_PATTERN.matcher(namedRegex);
+      // Match %{Foo:bar} -> pattern name and subname
+      // Match %{Foo=regex} -> add new regex definition
+      if (matcher.find()) {
+        continueIteration = true;
+        Map<String, String> group = GrokUtils.namedGroups(matcher, namedGroups);
+        if (group.get("definition") != null) {
+          patternDefinitions.put(group.get("pattern"), group.get("definition"));
+          group.put("name", group.get("name") + "=" + group.get("definition"));
+        }
+        int count = StringUtils.countMatches(namedRegex, "%{" + group.get("name") + "}");
+        for (int i = 0; i < count; i++) {
+          String definitionOfPattern = patternDefinitions.get(group.get("pattern"));
+          if (definitionOfPattern == null) {
+            throw new IllegalArgumentException(
+                format("No definition for key '%s' found, aborting", group.get("pattern")));
+          }
+          String replacement = String.format("(?<name%d>%s)", index, definitionOfPattern);
+          if (namedOnly && group.get("subname") == null) {
+            replacement = String.format("(?:%s)", definitionOfPattern);
+          }
+          namedRegexCollection.put(
+              "name" + index,
+              (group.get("subname") != null ? group.get("subname") : group.get("name")));
+          namedRegex =
+              StringUtils.replace(namedRegex, "%{" + group.get("name") + "}", replacement, 1);
+          index++;
+        }
+      }
+    }
+
+    if (namedRegex.isEmpty()) {
+      throw new IllegalArgumentException("Pattern not found");
+    }
+
+    return new Grok(pattern, namedRegex, namedRegexCollection, patternDefinitions, defaultTimeZone);
+  }
+}
diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/common/grok/GrokUtils.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/common/grok/GrokUtils.java
new file mode 100644
index 000000000..4b145bbbe
--- /dev/null
+++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/common/grok/GrokUtils.java
@@ -0,0 +1,59 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.common.grok;
+
+import java.util.LinkedHashMap;
+import java.util.LinkedHashSet;
+import java.util.Map;
+import java.util.Set;
+import java.util.regex.Matcher;
+import java.util.regex.Pattern;
+
+/**
+ * {@code GrokUtils} contain a set of useful tools and methods.
+ *
+ * @since 0.0.6
+ */
+public class GrokUtils {
+
+  /** Extract Grok patterns like %{FOO} to FOO; also handles Grok patterns with a semantic. */
+  public static final Pattern GROK_PATTERN =
+      Pattern.compile(
+          "%\\{"
+              + "(?<name>"
+              + "(?<pattern>[A-z0-9]+)"
+              + "(?::(?<subname>[A-z0-9_:;,\\-\\/\\s\\.']+))?"
+ + ")" + + "(?:=(?" + + "(?:" + + "(?:[^{}]+|\\.+)+" + + ")+" + + ")" + + ")?" + + "\\}"); + + public static final Pattern NAMED_REGEX = Pattern.compile("\\(\\?<([a-zA-Z][a-zA-Z0-9]*)>"); + + /** getNameGroups. */ + public static Set getNameGroups(String regex) { + Set namedGroups = new LinkedHashSet<>(); + Matcher matcher = NAMED_REGEX.matcher(regex); + while (matcher.find()) { + namedGroups.add(matcher.group(1)); + } + return namedGroups; + } + + /** namedGroups. */ + public static Map namedGroups(Matcher matcher, Set groupNames) { + Map namedGroups = new LinkedHashMap<>(); + for (String groupName : groupNames) { + String groupValue = matcher.group(groupName); + namedGroups.put(groupName, groupValue); + } + return namedGroups; + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/common/grok/Match.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/common/grok/Match.java new file mode 100644 index 000000000..1c02627c6 --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/common/grok/Match.java @@ -0,0 +1,241 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.common.grok; + +import org.opensearch.sql.common.grok.Converter.IConverter; +import org.opensearch.sql.common.grok.exception.GrokException; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.regex.Matcher; + +import static java.lang.String.format; + +/** + * {@code Match} is a representation in {@code Grok} world of your log. + * + * @since 0.0.1 + */ +public class Match { + private final CharSequence subject; + private final Grok grok; + private final Matcher match; + private final int start; + private final int end; + private boolean keepEmptyCaptures = true; + private Map capture = Collections.emptyMap(); + + /** Create a new {@code Match} object. */ + public Match(CharSequence subject, Grok grok, Matcher match, int start, int end) { + this.subject = subject; + this.grok = grok; + this.match = match; + this.start = start; + this.end = end; + } + + /** Create Empty grok matcher. */ + public static final Match EMPTY = new Match("", null, null, 0, 0); + + public Matcher getMatch() { + return match; + } + + public int getStart() { + return start; + } + + public int getEnd() { + return end; + } + + /** Ignore empty captures. */ + public void setKeepEmptyCaptures(boolean ignore) { + // clear any cached captures + if (capture.size() > 0) { + capture = new LinkedHashMap<>(); + } + this.keepEmptyCaptures = ignore; + } + + public boolean isKeepEmptyCaptures() { + return this.keepEmptyCaptures; + } + + /** + * Retrurn the single line of log. + * + * @return the single line of log + */ + public CharSequence getSubject() { + return subject; + } + + /** + * Match to the subject the regex and save the matched element into a map. + * + *

Multiple values for the same key are stored as a list.
+   */
+  public Map<String, Object> capture() {
+    return capture(false);
+  }
+
+  /**
+   * Private implementation of captureFlattened and capture.
+   *
+   * @param flattened will it flatten values.
+   * @return the matched elements.
+   * @throws GrokException if a key has multiple non-null values, but only if flattened is set to
+   *     true.
+   */
+  private Map<String, Object> capture(boolean flattened) throws GrokException {
+    if (match == null) {
+      return Collections.emptyMap();
+    }
+
+    if (!capture.isEmpty()) {
+      return capture;
+    }
+
+    capture = new LinkedHashMap<>();
+
+    Map<String, String> mappedw = GrokUtils.namedGroups(this.match, this.grok.namedGroups);
+
+    mappedw.forEach(
+        (key, valueString) -> {
+          String id = this.grok.getNamedRegexCollectionById(key);
+          if (id != null && !id.isEmpty()) {
+            key = id;
+          }
+
+          if ("UNWANTED".equals(key)) {
+            return;
+          }
+
+          Object value = valueString;
+          if (valueString != null) {
+            IConverter<? extends Object> converter = grok.converters.get(key);
+
+            if (converter != null) {
+              key = Converter.extractKey(key);
+              try {
+                value = converter.convert(valueString);
+              } catch (Exception e) {
+                capture.put(key + "_grokfailure", e.toString());
+              }
+
+              if (value instanceof String) {
+                value = cleanString((String) value);
+              }
+            } else {
+              value = cleanString(valueString);
+            }
+          } else if (!isKeepEmptyCaptures()) {
+            return;
+          }
+
+          if (capture.containsKey(key)) {
+            Object currentValue = capture.get(key);
+
+            if (flattened) {
+              if (currentValue == null && value != null) {
+                capture.put(key, value);
+              }
+              if (currentValue != null && value != null) {
+                throw new GrokException(
+                    format(
+                        "key '%s' has multiple non-null values, this is not allowed in flattened"
+                            + " mode, values:'%s', '%s'",
+                        key, currentValue, value));
+              }
+            } else {
+              if (currentValue instanceof List) {
+                @SuppressWarnings("unchecked")
+                List<Object> cvl = (List<Object>) currentValue;
+                cvl.add(value);
+              } else {
+                List<Object> list = new ArrayList<>();
+                list.add(currentValue);
+                list.add(value);
+                capture.put(key, list);
+              }
+            }
+          } else {
+            capture.put(key, value);
+          }
+        });
+
+    capture = Collections.unmodifiableMap(capture);
+
+    return capture;
+  }
+
+  /**
+   * Match to the subject the regex and save the matched element into a map.
+   *
>
Multiple values to the same key are flattened to one value: the sole non-null value will be + * captured. Should there be multiple non-null values a RuntimeException is being thrown. + * + *
>
This can be used in cases like: (foo (.*:message) bar|bar (.*:message) foo) where the regexp + * guarantees that only one value will be captured. + * + *
>
See also {@link #capture} which returns multiple values of the same key as list. + * + * @return the matched elements + * @throws GrokException if a keys has multiple non-null values. + */ + public Map captureFlattened() throws GrokException { + return capture(true); + } + + /** + * remove from the string the quote and double quote. + * + * @param value string to pure: "my/text" + * @return unquoted string: my/text + */ + private String cleanString(String value) { + if (value == null || value.isEmpty()) { + return value; + } + + char firstChar = value.charAt(0); + char lastChar = value.charAt(value.length() - 1); + + if (firstChar == lastChar && (firstChar == '"' || firstChar == '\'')) { + if (value.length() <= 2) { + return ""; + } else { + int found = 0; + for (int i = 1; i < value.length() - 1; i++) { + if (value.charAt(i) == firstChar) { + found++; + } + } + if (found == 0) { + return value.substring(1, value.length() - 1); + } + } + } + + return value; + } + + /** + * Util fct. + * + * @return boolean + */ + public Boolean isNull() { + return this.match == null; + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/common/grok/exception/GrokException.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/common/grok/exception/GrokException.java new file mode 100644 index 000000000..0e9d6d2dd --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/common/grok/exception/GrokException.java @@ -0,0 +1,50 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.common.grok.exception; + +/** + * Signals that an {@code Grok} exception of some sort has occurred. This class is the general class + * of exceptions produced by failed or interrupted Grok operations. + * + * @since 0.0.4 + */ +public class GrokException extends RuntimeException { + + private static final long serialVersionUID = 1L; + + /** Creates a new GrokException. */ + public GrokException() { + super(); + } + + /** + * Constructs a new GrokException. + * + * @param message the reason for the exception + * @param cause the underlying Throwable that caused this exception to be thrown. + */ + public GrokException(String message, Throwable cause) { + super(message, cause); + } + + /** + * Constructs a new GrokException. + * + * @param message the reason for the exception + */ + public GrokException(String message) { + super(message); + } + + /** + * Constructs a new GrokException. + * + * @param cause the underlying Throwable that caused this exception to be thrown. 
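A minimal sketch contrasting capture() with captureFlattened() on the alternation case described above, where both branches alias the same name and at most one can bind (WORD is a bundled default pattern; note that a Match caches its first capture, so a fresh Match is used for each call):

    import org.opensearch.sql.common.grok.Grok;
    import org.opensearch.sql.common.grok.GrokCompiler;
    import org.opensearch.sql.common.grok.Match;

    class FlattenDemo {
        public static void main(String[] args) {
            GrokCompiler compiler = GrokCompiler.newInstance();
            compiler.registerDefaultPatterns();
            // Both branches alias "msg"; only one branch can match a given input
            Grok grok = compiler.compile("(?:foo %{WORD:msg} bar|bar %{WORD:msg} foo)");
            Match m1 = grok.match("bar hello foo");
            System.out.println(m1.capture().get("msg"));          // [null, hello]
            Match m2 = grok.match("bar hello foo");
            System.out.println(m2.captureFlattened().get("msg")); // hello
        }
    }
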
+ */ + public GrokException(Throwable cause) { + super(cause); + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/parse/GrokExpression.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/parse/GrokExpression.java deleted file mode 100644 index 4cf79cbae..000000000 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/parse/GrokExpression.java +++ /dev/null @@ -1,73 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.sql.expression.parse; - -import lombok.EqualsAndHashCode; -import lombok.ToString; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.opensearch.sql.common.grok.Grok; -import org.opensearch.sql.common.grok.GrokCompiler; -import org.opensearch.sql.common.grok.Match; -import org.opensearch.sql.data.model.ExprStringValue; -import org.opensearch.sql.data.model.ExprValue; -import org.opensearch.sql.exception.ExpressionEvaluationException; -import org.opensearch.sql.expression.Expression; - -import java.util.List; -import java.util.Map; -import java.util.stream.Collectors; - -/** GrokExpression with grok patterns. */ -public class GrokExpression extends ParseExpression { - private static final Logger log = LogManager.getLogger(GrokExpression.class); - private static final GrokCompiler grokCompiler = GrokCompiler.newInstance(); - - static { - grokCompiler.registerDefaultPatterns(); - } - - @EqualsAndHashCode.Exclude private final Grok grok; - - /** - * GrokExpression. - * - * @param sourceField source text field - * @param pattern pattern used for parsing - * @param identifier derived field - */ - public GrokExpression(Expression sourceField, Expression pattern, Expression identifier) { - super("grok", sourceField, pattern, identifier); - this.grok = grokCompiler.compile(pattern.valueOf().stringValue()); - } - - @Override - ExprValue parseValue(ExprValue value) throws ExpressionEvaluationException { - String rawString = value.stringValue(); - Match grokMatch = grok.match(rawString); - Map capture = grokMatch.capture(); - Object match = capture.get(identifierStr); - if (match != null) { - return new ExprStringValue(match.toString()); - } - log.debug("failed to extract pattern {} from input ***", grok.getOriginalGrokPattern()); - return new ExprStringValue(""); - } - - /** - * Get list of derived fields based on parse pattern. 
- * - * @param pattern pattern used for parsing - * @return list of names of the derived fields - */ - public static List getNamedGroupCandidates(String pattern) { - Grok grok = grokCompiler.compile(pattern); - return grok.namedGroups.stream() - .map(grok::getNamedRegexCollectionById) - .filter(group -> !group.equals("UNWANTED")) - .collect(Collectors.toUnmodifiableList()); - } -} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/parse/ParseExpression.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/parse/ParseExpression.java deleted file mode 100644 index 6e2456ecc..000000000 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/parse/ParseExpression.java +++ /dev/null @@ -1,75 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.sql.expression.parse; - -import com.google.common.collect.ImmutableList; -import lombok.EqualsAndHashCode; -import lombok.Getter; -import lombok.ToString; -import org.opensearch.sql.data.model.ExprValue; -import org.opensearch.sql.data.model.ExprValueUtils; -import org.opensearch.sql.data.type.ExprCoreType; -import org.opensearch.sql.data.type.ExprType; -import org.opensearch.sql.exception.ExpressionEvaluationException; -import org.opensearch.sql.exception.SemanticCheckException; -import org.opensearch.sql.expression.Expression; -import org.opensearch.sql.expression.ExpressionNodeVisitor; -import org.opensearch.sql.expression.FunctionExpression; -import org.opensearch.sql.expression.env.Environment; -import org.opensearch.sql.expression.function.FunctionName; - -/** ParseExpression. */ -@EqualsAndHashCode(callSuper = false) -@ToString -public abstract class ParseExpression extends FunctionExpression { - @Getter protected final Expression sourceField; - protected final Expression pattern; - @Getter protected final Expression identifier; - protected final String identifierStr; - - /** - * ParseExpression. 
- * - * @param functionName name of function expression - * @param sourceField source text field - * @param pattern pattern used for parsing - * @param identifier derived field - */ - public ParseExpression( - String functionName, Expression sourceField, Expression pattern, Expression identifier) { - super(FunctionName.of(functionName), ImmutableList.of(sourceField, pattern, identifier)); - this.sourceField = sourceField; - this.pattern = pattern; - this.identifier = identifier; - this.identifierStr = identifier.valueOf().stringValue(); - } - - @Override - public ExprValue valueOf(Environment valueEnv) { - ExprValue value = valueEnv.resolve(sourceField); - if (value.isNull() || value.isMissing()) { - return ExprValueUtils.nullValue(); - } - try { - return parseValue(value); - } catch (ExpressionEvaluationException e) { - throw new SemanticCheckException( - String.format("failed to parse field \"%s\" with type [%s]", sourceField, value.type())); - } - } - - @Override - public ExprType type() { - return ExprCoreType.STRING; - } - - @Override - public T accept(ExpressionNodeVisitor visitor, C context) { - return visitor.visitParse(this, context); - } - - abstract ExprValue parseValue(ExprValue value) throws ExpressionEvaluationException; -} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/parse/PatternsExpression.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/parse/PatternsExpression.java deleted file mode 100644 index fedeeef7e..000000000 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/parse/PatternsExpression.java +++ /dev/null @@ -1,79 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.sql.expression.parse; - -import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableSet; -import lombok.EqualsAndHashCode; -import lombok.ToString; -import org.opensearch.sql.data.model.ExprStringValue; -import org.opensearch.sql.data.model.ExprValue; -import org.opensearch.sql.exception.ExpressionEvaluationException; -import org.opensearch.sql.expression.Expression; - -import java.util.List; -import java.util.Objects; -import java.util.regex.Pattern; - -/** PatternsExpression with regex filter. */ -@EqualsAndHashCode(callSuper = true) -@ToString -public class PatternsExpression extends ParseExpression { - /** Default name of the derived field. */ - public static final String DEFAULT_NEW_FIELD = "patterns_field"; - - private static final ImmutableSet DEFAULT_IGNORED_CHARS = - ImmutableSet.copyOf( - "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" - .chars() - .mapToObj(c -> (char) c) - .toArray(Character[]::new)); - private final boolean useCustomPattern; - @EqualsAndHashCode.Exclude private Pattern pattern; - - /** - * PatternsExpression. 
- * - * @param sourceField source text field - * @param pattern pattern used for parsing - * @param identifier derived field - */ - public PatternsExpression(Expression sourceField, Expression pattern, Expression identifier) { - super("patterns", sourceField, pattern, identifier); - String patternStr = pattern.valueOf().stringValue(); - useCustomPattern = !patternStr.isEmpty(); - if (useCustomPattern) { - this.pattern = Pattern.compile(patternStr); - } - } - - @Override - ExprValue parseValue(ExprValue value) throws ExpressionEvaluationException { - String rawString = value.stringValue(); - if (useCustomPattern) { - return new ExprStringValue(pattern.matcher(rawString).replaceAll("")); - } - - char[] chars = rawString.toCharArray(); - int pos = 0; - for (int i = 0; i < chars.length; i++) { - if (!DEFAULT_IGNORED_CHARS.contains(chars[i])) { - chars[pos++] = chars[i]; - } - } - return new ExprStringValue(new String(chars, 0, pos)); - } - - /** - * Get list of derived fields. - * - * @param identifier identifier used to generate the field name - * @return list of names of the derived fields - */ - public static List getNamedGroupCandidates(String identifier) { - return ImmutableList.of(Objects.requireNonNullElse(identifier, DEFAULT_NEW_FIELD)); - } -} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/parse/RegexExpression.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/parse/RegexExpression.java deleted file mode 100644 index aab51cd4b..000000000 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/expression/parse/RegexExpression.java +++ /dev/null @@ -1,65 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.sql.expression.parse; - -import com.google.common.collect.ImmutableList; -import lombok.EqualsAndHashCode; -import lombok.Getter; -import lombok.ToString; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.opensearch.sql.data.model.ExprStringValue; -import org.opensearch.sql.data.model.ExprValue; -import org.opensearch.sql.exception.ExpressionEvaluationException; -import org.opensearch.sql.expression.Expression; - -import java.util.List; -import java.util.regex.Matcher; -import java.util.regex.Pattern; - -/** RegexExpression with regex and named capture group. */ -public class RegexExpression extends ParseExpression { - private static final Logger log = LogManager.getLogger(RegexExpression.class); - private static final Pattern GROUP_PATTERN = Pattern.compile("\\(\\?<([a-zA-Z][a-zA-Z0-9]*)>"); - private final Pattern regexPattern; - - /** - * RegexExpression. - * - * @param sourceField source text field - * @param pattern pattern used for parsing - * @param identifier derived field - */ - public RegexExpression(Expression sourceField, Expression pattern, Expression identifier) { - super("regex", sourceField, pattern, identifier); - this.regexPattern = Pattern.compile(pattern.valueOf().stringValue()); - } - - @Override - public parseValue(String value) throws ExpressionEvaluationException { - Matcher matcher = regexPattern.matcher(value); - if (matcher.matches()) { - return new ExprStringValue(matcher.group(identifierStr)); - } - log.debug("failed to extract pattern {} from input ***", regexPattern.pattern()); - return new ExprStringValue(""); - } - - /** - * Get list of derived fields based on parse pattern. 
- * - * @param pattern pattern used for parsing - * @return list of names of the derived fields - */ - public static List getNamedGroupCandidates(String pattern) { - ImmutableList.Builder namedGroups = ImmutableList.builder(); - Matcher m = GROUP_PATTERN.matcher(pattern); - while (m.find()) { - namedGroups.add(m.group(1)); - } - return namedGroups.build(); - } -} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java index b2169726f..f4a96bc27 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java @@ -15,8 +15,10 @@ import org.apache.spark.sql.catalyst.expressions.Expression; import org.apache.spark.sql.catalyst.expressions.NamedExpression; import org.apache.spark.sql.catalyst.expressions.Predicate; +import org.apache.spark.sql.catalyst.expressions.RegExpExtract; import org.apache.spark.sql.catalyst.expressions.SortDirection; import org.apache.spark.sql.catalyst.expressions.SortOrder; +import org.apache.spark.sql.catalyst.expressions.StringRegexExpression; import org.apache.spark.sql.catalyst.plans.logical.Aggregate; import org.apache.spark.sql.catalyst.plans.logical.DescribeRelation$; import org.apache.spark.sql.catalyst.plans.logical.Deduplicate; @@ -68,7 +70,6 @@ import org.opensearch.sql.ast.tree.RareTopN; import org.opensearch.sql.ast.tree.Relation; import org.opensearch.sql.ast.tree.Sort; -import org.opensearch.sql.expression.parse.ParseExpression; import org.opensearch.sql.ast.tree.TopAggregation; import org.opensearch.sql.ppl.utils.AggregatorTranslator; import org.opensearch.sql.ppl.utils.BuiltinFunctionTranslator; @@ -89,6 +90,7 @@ import static java.util.Collections.emptyList; import static java.util.List.of; +import static org.apache.spark.sql.types.DataTypes.StringType; import static org.opensearch.sql.ppl.CatalystPlanContext.findRelation; import static org.opensearch.sql.ppl.utils.DataTypeTransformer.seq; import static org.opensearch.sql.ppl.utils.DataTypeTransformer.translate; @@ -306,15 +308,19 @@ public LogicalPlan visitParse(Parse node, CatalystPlanContext context) { } private LogicalPlan visitParseCommand(Parse node, Expression sourceField, ParseMethod parseMethod, Map arguments, String pattern, CatalystPlanContext context) { - ParseUtils.getNamedGroupCandidates(parseMethod, pattern, arguments) + AttributeReference column = new AttributeReference(sourceField.nodeName(), StringType, true, Metadata.empty(), NamedExpression.newExprId(), seq(of())); + List namedGroupCandidates = ParseUtils.getNamedGroupCandidates(parseMethod, pattern, arguments); + namedGroupCandidates .forEach( group -> { - ParseExpression expr = - ParseUtils.createParseExpression( - parseMethod, sourceField, pattern, group); - context.define(new AttributeReference(group, expr.dataType(), true, Metadata.empty(), NamedExpression.newExprId(), seq(emptyList()))new Ngroup, expr.type()); - context.getNamedParseExpressions().add(new NamedExpression(group, expr)); + ParseUtils.ParseExpression expr = + ParseUtils.createParseExpression(parseMethod, pattern, group); + context.define(new AttributeReference(group, StringType, true, Metadata.empty(), NamedExpression.newExprId(), seq(emptyList()))); + //todo add exp details to the regExpExtract + RegExpExtract regExpExtract = new RegExpExtract(column, new 
org.apache.spark.sql.catalyst.expressions.Literal(pattern, StringType), + new org.apache.spark.sql.catalyst.expressions.Literal(group, StringType)); }); + return context.getPlan(); } @Override diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/ParseUtils.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/ParseUtils.java index b907050a4..58700b06a 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/ParseUtils.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/ParseUtils.java @@ -5,39 +5,40 @@ package org.opensearch.sql.ppl.utils; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; import org.opensearch.sql.ast.expression.Literal; import org.opensearch.sql.ast.expression.ParseMethod; -import org.opensearch.sql.expression.Expression; -import org.opensearch.sql.expression.parse.GrokExpression; -import org.opensearch.sql.expression.parse.ParseExpression; -import org.opensearch.sql.expression.parse.PatternsExpression; -import org.opensearch.sql.expression.parse.RegexExpression; +import org.opensearch.sql.common.grok.Grok; +import org.opensearch.sql.common.grok.GrokCompiler; +import org.opensearch.sql.common.grok.Match; import java.util.List; import java.util.Map; +import java.util.Objects; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import java.util.stream.Collectors; -/** Utils for {@link ParseExpression}. */ public class ParseUtils { private static final String NEW_FIELD_KEY = "new_field"; - private static final Map FACTORY_MAP = - ImmutableMap.of( - ParseMethod.REGEX, RegexExpression::new, - ParseMethod.GROK, GrokExpression::new, - ParseMethod.PATTERNS, PatternsExpression::new); /** * Construct corresponding ParseExpression by {@link ParseMethod}. * * @param parseMethod method used to parse - * @param sourceField source text field * @param pattern pattern used for parsing * @param identifier derived field * @return {@link ParseExpression} */ public static ParseExpression createParseExpression( - ParseMethod parseMethod, Expression sourceField, Expression pattern, Expression identifier) { - return FACTORY_MAP.get(parseMethod).initialize(sourceField, pattern, identifier); + ParseMethod parseMethod, String pattern, String identifier) { + switch (parseMethod) { + case GROK: return new GrokExpression(pattern, identifier); + case PATTERNS: return new PatternsExpression(pattern, identifier); + default: return new RegexExpression(pattern, identifier); + } } /** @@ -47,7 +48,7 @@ public static ParseExpression createParseExpression( * @return list of names of the derived fields */ public static List getNamedGroupCandidates( - ParseMethod parseMethod, String pattern, Map arguments) { + ParseMethod parseMethod, String pattern, Map arguments) { switch (parseMethod) { case REGEX: return RegexExpression.getNamedGroupCandidates(pattern); @@ -55,14 +56,138 @@ public static List getNamedGroupCandidates( return GrokExpression.getNamedGroupCandidates(pattern); default: return PatternsExpression.getNamedGroupCandidates( - arguments.containsKey(NEW_FIELD_KEY) - ? (String) arguments.get(NEW_FIELD_KEY).getValue() - : null); + arguments.containsKey(NEW_FIELD_KEY) + ? 
(String) arguments.get(NEW_FIELD_KEY).getValue() + : null); + } + } + + public static abstract class ParseExpression { + abstract String parseValue(String value); + } + + public static class RegexExpression extends ParseExpression{ + private static final Pattern GROUP_PATTERN = Pattern.compile("\\(\\?<([a-zA-Z][a-zA-Z0-9]*)>"); + private final Pattern regexPattern; + protected final String identifier; + + public RegexExpression(String patterns, String identifier) { + this.regexPattern = Pattern.compile(patterns); + this.identifier = identifier; + } + + /** + * Get list of derived fields based on parse pattern. + * + * @param pattern pattern used for parsing + * @return list of names of the derived fields + */ + public static List getNamedGroupCandidates(String pattern) { + ImmutableList.Builder namedGroups = ImmutableList.builder(); + Matcher m = GROUP_PATTERN.matcher(pattern); + while (m.find()) { + namedGroups.add(m.group(1)); + } + return namedGroups.build(); + } + + @Override + public String parseValue(String value) { + Matcher matcher = regexPattern.matcher(value); + if (matcher.matches()) { + return matcher.group(identifier); + } + return ""; + } + } + + public static class GrokExpression extends ParseExpression{ + private static final GrokCompiler grokCompiler = GrokCompiler.newInstance(); + private final Grok grok; + private final String identifier; + + public GrokExpression(String pattern, String identifier) { + this.grok = grokCompiler.compile(pattern); + this.identifier = identifier; + } + + @Override + public String parseValue(String value) { + Match grokMatch = grok.match(value); + Map capture = grokMatch.capture(); + Object match = capture.get(identifier); + if (match != null) { + return match.toString(); + } + return ""; } + + /** + * Get list of derived fields based on parse pattern. + * + * @param pattern pattern used for parsing + * @return list of names of the derived fields + */ + public static List getNamedGroupCandidates(String pattern) { + Grok grok = grokCompiler.compile(pattern); + return grok.namedGroups.stream() + .map(grok::getNamedRegexCollectionById) + .filter(group -> !group.equals("UNWANTED")) + .collect(Collectors.toUnmodifiableList()); + } + } + + public static class PatternsExpression extends ParseExpression{ + public static final String DEFAULT_NEW_FIELD = "patterns_field"; + + private static final ImmutableSet DEFAULT_IGNORED_CHARS = + ImmutableSet.copyOf( + "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" + .chars() + .mapToObj(c -> (char) c) + .toArray(Character[]::new)); + private final boolean useCustomPattern; + private Pattern pattern; - private interface ParseExpressionFactory { - ParseExpression initialize( - Expression sourceField, Expression expression, Expression identifier); + /** + * PatternsExpression. + * + * @param pattern pattern used for parsing + * @param identifier derived field + */ + public PatternsExpression(String pattern, String identifier) { + useCustomPattern = !pattern.isEmpty(); + if (useCustomPattern) { + this.pattern = Pattern.compile(pattern); + } + } + + @Override + public String parseValue(String value) { + if (useCustomPattern) { + return pattern.matcher(value).replaceAll(""); + } + + char[] chars = value.toCharArray(); + int pos = 0; + for (int i = 0; i < chars.length; i++) { + if (!DEFAULT_IGNORED_CHARS.contains(chars[i])) { + chars[pos++] = chars[i]; + } + } + return new String(chars, 0, pos); + } + + /** + * Get list of derived fields. 
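A minimal sketch of the reworked string-based expressions above (class names as introduced in this diff; assumes ParseUtils is imported):

    import java.util.List;

    class ParseUtilsDemo {
        public static void main(String[] args) {
            // REGEX mode: derived fields come from the named groups in the pattern
            List<String> fields =
                ParseUtils.RegexExpression.getNamedGroupCandidates(".+@(?<host>.+)"); // ["host"]
            ParseUtils.RegexExpression host =
                new ParseUtils.RegexExpression(".+@(?<host>.+)", "host");
            System.out.println(host.parseValue("bob@test.org")); // test.org
            System.out.println(host.parseValue("no-at-sign"));   // "" when nothing matches

            // PATTERNS mode with an empty pattern strips letters and digits,
            // keeping only punctuation and whitespace
            ParseUtils.PatternsExpression punct =
                new ParseUtils.PatternsExpression("", "patterns_field");
            System.out.println(punct.parseValue("707 Ash St, Seattle")); // "  , "
        }
    }
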
+ * + * @param identifier identifier used to generate the field name + * @return list of names of the derived fields + */ + public static List getNamedGroupCandidates(String identifier) { + return ImmutableList.of(Objects.requireNonNullElse(identifier, DEFAULT_NEW_FIELD)); + } } + } From 5158ab0057e3f5b6b4b16cef1c46ecd16bab3fda Mon Sep 17 00:00:00 2001 From: YANGDB Date: Thu, 22 Aug 2024 11:14:55 -0700 Subject: [PATCH 4/9] update tests Signed-off-by: YANGDB Signed-off-by: YANGDB --- .../src/main/antlr4/OpenSearchPPLParser.g4 | 3 + .../sql/ppl/CatalystQueryPlanVisitor.java | 51 ++++----- .../opensearch/sql/ppl/utils/ParseUtils.java | 45 ++++++++ ...LLogicalPlanParseTranslatorTestSuite.scala | 107 ++++++++++++++---- 4 files changed, 160 insertions(+), 46 deletions(-) diff --git a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 index f4065be6d..3ce7cef7c 100644 --- a/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 +++ b/ppl-spark-integration/src/main/antlr4/OpenSearchPPLParser.g4 @@ -41,6 +41,9 @@ commands | topCommand | rareCommand | evalCommand + | grokCommand + | parseCommand + | patternsCommand ; searchCommand diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java index f4a96bc27..0f470185c 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java @@ -11,6 +11,7 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedStar$; import org.apache.spark.sql.catalyst.expressions.AttributeReference; import org.apache.spark.sql.catalyst.expressions.Ascending$; +import org.apache.spark.sql.catalyst.expressions.Coalesce; import org.apache.spark.sql.catalyst.expressions.Descending$; import org.apache.spark.sql.catalyst.expressions.Expression; import org.apache.spark.sql.catalyst.expressions.NamedExpression; @@ -90,6 +91,7 @@ import static java.util.Collections.emptyList; import static java.util.List.of; +import static org.apache.spark.sql.types.DataTypes.IntegerType; import static org.apache.spark.sql.types.DataTypes.StringType; import static org.opensearch.sql.ppl.CatalystPlanContext.findRelation; import static org.opensearch.sql.ppl.utils.DataTypeTransformer.seq; @@ -291,36 +293,33 @@ public LogicalPlan visitParse(Parse node, CatalystPlanContext context) { ParseMethod parseMethod = node.getParseMethod(); java.util.Map arguments = node.getArguments(); String pattern = (String) node.getPattern().getValue(); - - List aliases = new ArrayList<>(); - switch (node.getParseMethod()) { - case GROK: - throw new IllegalStateException("Not Supported operation : GROK"); - case PATTERNS: - throw new IllegalStateException("Not Supported operation : PATTERNS"); - case REGEX: - return visitParseCommand(node, sourceField, parseMethod, arguments, pattern, context); - default: - throw new IllegalArgumentException("Invalid parse command name: " + node.getParseMethod() - + " Syntax: [parse ] "); - - } + return visitParseCommand(node, sourceField, parseMethod, arguments, pattern, context); } private LogicalPlan visitParseCommand(Parse node, Expression sourceField, ParseMethod parseMethod, Map arguments, String pattern, CatalystPlanContext context) { - AttributeReference column = new AttributeReference(sourceField.nodeName(), StringType, true, 
Metadata.empty(), NamedExpression.newExprId(), seq(of())); List namedGroupCandidates = ParseUtils.getNamedGroupCandidates(parseMethod, pattern, arguments); - namedGroupCandidates - .forEach( - group -> { - ParseUtils.ParseExpression expr = - ParseUtils.createParseExpression(parseMethod, pattern, group); - context.define(new AttributeReference(group, StringType, true, Metadata.empty(), NamedExpression.newExprId(), seq(emptyList()))); - //todo add exp details to the regExpExtract - RegExpExtract regExpExtract = new RegExpExtract(column, new org.apache.spark.sql.catalyst.expressions.Literal(pattern, StringType), - new org.apache.spark.sql.catalyst.expressions.Literal(group, StringType)); - }); - return context.getPlan(); + String cleanedPattern = ParseUtils.extractPatterns(parseMethod, pattern, namedGroupCandidates); + for (int i = 0; i < namedGroupCandidates.size(); i++) { + String group = namedGroupCandidates.get(i); + //first create the regExp + RegExpExtract regExpExtract = new RegExpExtract(sourceField, + org.apache.spark.sql.catalyst.expressions.Literal.create(cleanedPattern, StringType), + org.apache.spark.sql.catalyst.expressions.Literal.create(i+1, IntegerType)); + //next create Coalesce to handle potential null values + Coalesce coalesce = new Coalesce(seq(regExpExtract)); + //next Alias the extracted fields + context.getNamedParseExpressions().push( + org.apache.spark.sql.catalyst.expressions.Alias$.MODULE$.apply(coalesce, + group, + NamedExpression.newExprId(), + seq(new java.util.ArrayList()), + Option.empty(), + seq(new java.util.ArrayList()))); + } + Seq projectExpressions = context.retainAllNamedParseExpressions(p -> (NamedExpression) p); + // build the plan with the projection step + LogicalPlan child = context.apply(p -> new org.apache.spark.sql.catalyst.plans.logical.Project(projectExpressions, p)); + return child; } @Override diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/ParseUtils.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/ParseUtils.java index 58700b06a..54b43db0e 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/ParseUtils.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/ParseUtils.java @@ -62,6 +62,25 @@ public static List getNamedGroupCandidates( } } + /** + * extract the cleaner pattern without the additional fields + * @param parseMethod + * @param pattern + * @param columns + * @return + */ + public static String extractPatterns( + ParseMethod parseMethod, String pattern, List columns) { + switch (parseMethod) { + case REGEX: + return RegexExpression.extractPattern(pattern, columns); + case GROK: + return GrokExpression.extractPattern(pattern, columns); + default: + return PatternsExpression.extractPattern(pattern, columns); + } + } + public static abstract class ParseExpression { abstract String parseValue(String value); } @@ -99,6 +118,23 @@ public String parseValue(String value) { } return ""; } + + public static String extractPattern(String patterns, List columns) { + StringBuilder result = new StringBuilder(); + Matcher matcher = GROUP_PATTERN.matcher(patterns); + + int lastEnd = 0; + while (matcher.find()) { + String groupName = matcher.group(1); + if (columns.contains(groupName)) { + result.append(patterns, lastEnd, matcher.start()); + result.append("("); + lastEnd = matcher.end(); + } + } + result.append(patterns.substring(lastEnd)); + return result.toString(); + } } public static class GrokExpression extends ParseExpression{ @@ -136,6 +172,10 
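The loop above effectively builds, for each named group, the DataFrame equivalent of coalesce(regexp_extract(...)); extractPatterns first rewrites '.+@(?<host>.+)' into '.+@(.+)' so Spark sees plain positional groups. A sketch of the intended semantics (for intuition only, not generated code):

    import static org.apache.spark.sql.functions.coalesce;
    import static org.apache.spark.sql.functions.col;
    import static org.apache.spark.sql.functions.regexp_extract;

    import org.apache.spark.sql.Dataset;
    import org.apache.spark.sql.Row;
    import org.apache.spark.sql.SparkSession;

    class ParseTranslationDemo {
        // source = t | parse email '.+@(?<host>.+)' | fields email, host
        static Dataset<Row> parseHost(SparkSession spark) {
            return spark.table("t").select(
                col("email"),
                coalesce(regexp_extract(col("email"), ".+@(.+)", 1)).alias("host"));
        }
    }
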
@@ public static List getNamedGroupCandidates(String pattern) { .collect(Collectors.toUnmodifiableList()); } + public static String extractPattern(String patterns, List columns) { + //todo implement + return patterns; + } } public static class PatternsExpression extends ParseExpression{ @@ -188,6 +228,11 @@ public String parseValue(String value) { public static List getNamedGroupCandidates(String identifier) { return ImmutableList.of(Objects.requireNonNullElse(identifier, DEFAULT_NEW_FIELD)); } + + public static String extractPattern(String patterns, List columns) { + //todo implement + return patterns; + } } } diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanParseTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanParseTranslatorTestSuite.scala index 469841f24..aa37ea484 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanParseTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanParseTranslatorTestSuite.scala @@ -6,10 +6,11 @@ package org.opensearch.flint.spark.ppl import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.ScalaReflection.universe.Star import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation, UnresolvedStar} -import org.apache.spark.sql.catalyst.expressions.{Alias, Literal, NamedExpression} +import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Coalesce, GreaterThan, Literal, NamedExpression, NullsFirst, RegExpExtract, SortOrder} import org.apache.spark.sql.catalyst.plans.PlanTest -import org.apache.spark.sql.catalyst.plans.logical.Project +import org.apache.spark.sql.catalyst.plans.logical.{Filter, Project, Sort} import org.opensearch.flint.spark.ppl.PlaneUtils.plan import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor} @@ -28,14 +29,21 @@ class PPLLogicalPlanParseTranslatorTestSuite val context = new CatalystPlanContext val logPlan = planTransformer.visit( - plan(pplParser, "source=t | parse email '.+@(?.+)' | fields email, host", false), + plan(pplParser, "source=t | parse email '.+@(?.+)' | fields email, host", isExplain = false), context) - val evalProjectList: Seq[NamedExpression] = - Seq(UnresolvedStar(None), Alias(Literal(1), "a")(), Alias(Literal(1), "b")()) + + val emailAttribute = UnresolvedAttribute("email") + val hostAttribute = UnresolvedAttribute("host") + val hostExpression = Alias( + Coalesce( + Seq(RegExpExtract(emailAttribute, Literal(".+@(.+)"), Literal("1")))), "host")() val expectedPlan = Project( - seq(UnresolvedAttribute("c")), - Project(evalProjectList, UnresolvedRelation(Seq("t")))) - comparePlans(expectedPlan, logPlan, checkAnalysis = false) + Seq(emailAttribute, hostAttribute), + Project( + Seq(emailAttribute, hostExpression), + UnresolvedRelation(Seq("t")) + )) + assert(compareByString(expectedPlan) === compareByString(logPlan)) } test("test parse email expression") { @@ -44,25 +52,84 @@ class PPLLogicalPlanParseTranslatorTestSuite planTransformer.visit( plan(pplParser, "source=t | parse email '.+@(?.+)' | fields email", false), context) - val evalProjectList: Seq[NamedExpression] = - Seq(UnresolvedStar(None), Alias(Literal(1), "a")(), Alias(Literal(1), "b")()) + + val emailAttribute = UnresolvedAttribute("email") + val hostExpression = Alias( + Coalesce( + Seq(RegExpExtract(emailAttribute, 
Literal(".+@(.+)"), Literal("1")))), "email")() val expectedPlan = Project( - seq(UnresolvedAttribute("c")), - Project(evalProjectList, UnresolvedRelation(Seq("t")))) - comparePlans(expectedPlan, logPlan, checkAnalysis = false) + Seq(emailAttribute), + Project( + Seq(emailAttribute, hostExpression), + UnresolvedRelation(Seq("t")) + )) + assert(compareByString(expectedPlan) === compareByString(logPlan)) } - + test("test parse email expression, generate new host field and eval result") { val context = new CatalystPlanContext val logPlan = planTransformer.visit( plan(pplParser, "source=t | parse email '.+@(?.+)' | eval eval_result=1 | fields host, eval_result", false), context) - val evalProjectList: Seq[NamedExpression] = - Seq(UnresolvedStar(None), Alias(Literal(1), "a")(), Alias(Literal(1), "b")()) + + val emailAttribute = UnresolvedAttribute("email") + val hostAttribute = UnresolvedAttribute("host") + val evalResultAttribute = UnresolvedAttribute("eval_result") + + val hostExpression = Alias( + Coalesce(Seq(RegExpExtract(emailAttribute, Literal(".+@(.+)"), Literal("1")))), + "host")() + + val evalResultExpression = Alias(Literal(1), "eval_result")() + val expectedPlan = Project( - seq(UnresolvedAttribute("c")), - Project(evalProjectList, UnresolvedRelation(Seq("t")))) - comparePlans(expectedPlan, logPlan, checkAnalysis = false) + Seq(hostAttribute, evalResultAttribute), + Project( + Seq(UnresolvedStar(None), evalResultExpression), + Project( + Seq(emailAttribute, hostExpression), + UnresolvedRelation(Seq("t")) + ) + ) + ) + assert(compareByString(expectedPlan) === compareByString(logPlan)) } -} + + test("test parse email & host expressions including cast and sort commands") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan(pplParser, "source=t | parse address '(?\\d+) (?.+)' | where streetNumber > 500 | sort num(streetNumber) | fields streetNumber, street", false), + context) + + val addressAttribute = UnresolvedAttribute("address") + val streetNumberAttribute = UnresolvedAttribute("streetNumber") + val streetAttribute = UnresolvedAttribute("street") + + val streetNumberExpression = Alias( + Coalesce(Seq(RegExpExtract(addressAttribute, Literal("(\\d+) (.+)"), Literal("1")))), + "streetNumber" + )() + + val streetExpression = Alias( + Coalesce(Seq(RegExpExtract(addressAttribute, Literal("(\\d+) (.+)"), Literal("2")))), + "street")() + + val expectedPlan = Project( + Seq(streetNumberAttribute, streetAttribute), + Sort( + Seq(SortOrder(streetNumberAttribute, Ascending, NullsFirst, Seq.empty)), + global = true, + Filter( + GreaterThan(streetNumberAttribute, Literal(500)), + Project( + Seq(addressAttribute, streetNumberExpression, streetExpression), + UnresolvedRelation(Seq("t")) + ) + ) + ) + ) + + assert(compareByString(expectedPlan) === compareByString(logPlan)) + }} From 954e793e31ddbaf411988a9b6fccc30a6e540fe6 Mon Sep 17 00:00:00 2001 From: YANGDB Date: Thu, 22 Aug 2024 12:05:12 -0700 Subject: [PATCH 5/9] update tests with more complex tests Signed-off-by: YANGDB Signed-off-by: YANGDB --- .../flint/spark/FlintSparkSuite.scala | 25 +++--- .../spark/ppl/FlintSparkPPLParseITSuite.scala | 85 ++++++++++++++++--- ...LLogicalPlanParseTranslatorTestSuite.scala | 33 ++++++- 3 files changed, 118 insertions(+), 25 deletions(-) diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala index 0f35f6aed..bb4d53a59 100644 --- 
a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala @@ -106,6 +106,7 @@ trait FlintSparkSuite extends QueryTest with FlintSuite with OpenSearchSuite wit | CREATE TABLE $testTable | ( | name STRING, + | age INT, | email STRING, | street_address STRING | ) @@ -117,24 +118,24 @@ trait FlintSparkSuite extends QueryTest with FlintSuite with OpenSearchSuite wit |""".stripMargin) val data = Seq( - ("Alice", "alice@example.com", "123 Main St, Seattle", 2023, 4), - ("Bob", "bob@test.org", "456 Elm St, Portland", 2023, 5), - ("Charlie", "charlie@domain.net", "789 Pine St, San Francisco", 2023, 4), - ("David", "david@anotherdomain.com", "101 Maple St, New York", 2023, 5), - ("Eve", "eve@examples.com", "202 Oak St, Boston", 2023, 4), - ("Frank", "frank@sample.org", "303 Cedar St, Austin", 2023, 5), - ("Grace", "grace@demo.net", "404 Birch St, Chicago", 2023, 4), - ("Hank", "hank@demonstration.com", "505 Spruce St, Miami", 2023, 5), - ("Ivy", "ivy@examples.org", "606 Fir St, Denver", 2023, 4), - ("Jack", "jack@sample.net", "707 Ash St, Seattle", 2023, 5) + ("Alice", 30, "alice@example.com", "123 Main St, Seattle", 2023, 4), + ("Bob", 55, "bob@test.org", "456 Elm St, Portland", 2023, 5), + ("Charlie", 65, "charlie@domain.net", "789 Pine St, San Francisco", 2023, 4), + ("David", 19, "david@anotherdomain.com", "101 Maple St, New York", 2023, 5), + ("Eve", 21, "eve@examples.com", "202 Oak St, Boston", 2023, 4), + ("Frank", 76, "frank@sample.org", "303 Cedar St, Austin", 2023, 5), + ("Grace", 41, "grace@demo.net", "404 Birch St, Chicago", 2023, 4), + ("Hank", 32, "hank@demonstration.com", "505 Spruce St, Miami", 2023, 5), + ("Ivy", 9, "ivy@examples.org", "606 Fir St, Denver", 2023, 4), + ("Jack", 12, "jack@sample.net", "707 Ash St, Seattle", 2023, 5) ) - data.foreach { case (name, email, street_address, year, month) => + data.foreach { case (name, age, email, street_address, year, month) => spark.sql( s""" | INSERT INTO $testTable | PARTITION (year=$year, month=$month) - | VALUES ('$name', '$email', '$street_address') + | VALUES ('$name', $age, '$email', '$street_address') | """.stripMargin) } } diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLParseITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLParseITSuite.scala index d39bdc522..f0578cfab 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLParseITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLParseITSuite.scala @@ -6,8 +6,8 @@ package org.opensearch.flint.spark.ppl import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} -import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Descending, Literal, SortOrder} -import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project, Sort} +import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Coalesce, Descending, GreaterThan, Literal, NullsLast, RegExpExtract, SortOrder} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LogicalPlan, Project, Sort} import org.apache.spark.sql.streaming.StreamTest import org.apache.spark.sql.{AnalysisException, QueryTest, Row} import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq @@ -38,15 +38,28 @@ class FlintSparkPPLParseITSuite } test("test parse email 
expressions parsing") { - val frame = sql(s""" - | source = $testTable | parse email '.+@(?.+)' | fields email, host ; + val frame = sql( + s""" + | source = $testTable| parse email '.+@(?.+)' | fields email, host ; | """.stripMargin) // Retrieve the results val results: Array[Row] = frame.collect() // Define the expected results - val expectedResults: Array[Row] = - Array(Row("Jake", 70), Row("Hello", 30), Row("John", 25), Row("Jane", 20)) + // Define the expected results + val expectedResults: Array[Row] = Array( + Row("charlie@domain.net", "domain.net"), + Row("david@anotherdomain.com", "anotherdomain.com"), + Row("hank@demonstration.com", "demonstration.com"), + Row("alice@example.com", "example.com"), + Row("frank@sample.org", "sample.org"), + Row("grace@demo.net", "demo.net"), + Row("jack@sample.net", "sample.net"), + Row("eve@examples.com", "examples.com"), + Row("ivy@examples.org", "examples.org"), + Row("bob@test.org", "test.org") + ) + // Compare the results implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) assert(results.sorted.sameElements(expectedResults.sorted)) @@ -54,12 +67,60 @@ class FlintSparkPPLParseITSuite // Retrieve the logical plan val logicalPlan: LogicalPlan = frame.queryExecution.logical // Define the expected logical plan - val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) - val fieldsProjectList = Seq(UnresolvedAttribute("name"), UnresolvedAttribute("age")) - val evalProjectList = Seq(UnresolvedStar(None), Alias(Literal(1), "col")()) - val expectedPlan = Project(fieldsProjectList, Project(evalProjectList, table)) - // Compare the two plans - comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + val emailAttribute = UnresolvedAttribute("email") + val hostAttribute = UnresolvedAttribute("host") + val hostExpression = Alias( + Coalesce( + Seq(RegExpExtract(emailAttribute, Literal(".+@(.+)"), Literal("1")))), "host")() + val expectedPlan = Project( + Seq(emailAttribute, hostAttribute), + Project( + Seq(emailAttribute, hostExpression), + UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + )) + assert(compareByString(expectedPlan) === compareByString(logicalPlan)) } + test("test parse email expressions parsing filter & sort by age") { + val frame = sql( + s""" + | source = $testTable| parse email '.+@(?.+)' | where age > 45 | sort - age | fields age, email, host ; + | """.stripMargin) + + // Retrieve the results + val results: Array[Row] = frame.collect() + // Define the expected results + val expectedResults: Array[Row] = Array( + Row(76, "frank@sample.org", "sample.org"), + Row(65, "charlie@domain.net", "domain.net"), + Row(55, "bob@test.org", "test.org") + ) + + // Compare the results + assert(results.sameElements(expectedResults)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + // Define the expected logical plan + val emailAttribute = UnresolvedAttribute("email") + val ageAttribute = UnresolvedAttribute("age") + val hostExpression = Alias(Coalesce(Seq(RegExpExtract(emailAttribute, Literal(".+@(.+)"), Literal(1)))), "host")() + + // Define the corrected expected plan + val expectedPlan = Project( + Seq(ageAttribute, emailAttribute, UnresolvedAttribute("host")), + Sort( + Seq(SortOrder(ageAttribute, Descending, NullsLast, Seq.empty)), + global = true, + Filter( + GreaterThan(ageAttribute, Literal(45)), + Project( + Seq(emailAttribute, hostExpression), + UnresolvedRelation(Seq("spark_catalog", "default", 
"flint_ppl_test")) + ) + ) + ) + ) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) + } } diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanParseTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanParseTranslatorTestSuite.scala index aa37ea484..6bdb4bf98 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanParseTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanParseTranslatorTestSuite.scala @@ -8,7 +8,7 @@ package org.opensearch.flint.spark.ppl import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.ScalaReflection.universe.Star import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation, UnresolvedStar} -import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Coalesce, GreaterThan, Literal, NamedExpression, NullsFirst, RegExpExtract, SortOrder} +import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Coalesce, Descending, GreaterThan, Literal, NamedExpression, NullsFirst, NullsLast, RegExpExtract, SortOrder} import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{Filter, Project, Sort} import org.opensearch.flint.spark.ppl.PlaneUtils.plan @@ -66,6 +66,37 @@ class PPLLogicalPlanParseTranslatorTestSuite assert(compareByString(expectedPlan) === compareByString(logPlan)) } + test("test parse email expression with filter by age and sort by age field") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan(pplParser, + "source = t | parse email '.+@(?.+)' | where age > 45 | sort - age | fields age, email, host", isExplain = false), + context) + + // Define the expected logical plan + val emailAttribute = UnresolvedAttribute("email") + val ageAttribute = UnresolvedAttribute("age") + val hostExpression = Alias(Coalesce(Seq(RegExpExtract(emailAttribute, Literal(".+@(.+)"), Literal(1)))), "host")() + + // Define the corrected expected plan + val expectedPlan = Project( + Seq(ageAttribute, emailAttribute, UnresolvedAttribute("host")), + Sort( + Seq(SortOrder(ageAttribute, Descending, NullsLast, Seq.empty)), + global = true, + Filter( + GreaterThan(ageAttribute, Literal(45)), + Project( + Seq(ageAttribute, emailAttribute, hostExpression), + UnresolvedRelation(Seq("t")) + ) + ) + ) + ) + assert(compareByString(expectedPlan) === compareByString(logPlan)) + } + test("test parse email expression, generate new host field and eval result") { val context = new CatalystPlanContext val logPlan = From 507ead2a99c3ed4e7f592210cb943f2b013ca8ef Mon Sep 17 00:00:00 2001 From: YANGDB Date: Thu, 22 Aug 2024 12:42:29 -0700 Subject: [PATCH 6/9] scalafmtAll fixes Signed-off-by: YANGDB Signed-off-by: YANGDB --- .../flint/spark/FlintSparkSuite.scala | 9 +- .../spark/ppl/FlintSparkPPLParseITSuite.scala | 38 ++++----- .../sql/ppl/CatalystPlanContext.java | 18 ++++ .../sql/ppl/CatalystQueryPlanVisitor.java | 5 +- ...LLogicalPlanParseTranslatorTestSuite.scala | 83 ++++++++++--------- 5 files changed, 83 insertions(+), 70 deletions(-) diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala index bb4d53a59..e0f42cfa7 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala +++ 
b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala @@ -101,8 +101,7 @@ trait FlintSparkSuite extends QueryTest with FlintSuite with OpenSearchSuite wit } protected def createPartitionedGrokEmailTable(testTable: String): Unit = { - spark.sql( - s""" + spark.sql(s""" | CREATE TABLE $testTable | ( | name STRING, @@ -127,12 +126,10 @@ trait FlintSparkSuite extends QueryTest with FlintSuite with OpenSearchSuite wit ("Grace", 41, "grace@demo.net", "404 Birch St, Chicago", 2023, 4), ("Hank", 32, "hank@demonstration.com", "505 Spruce St, Miami", 2023, 5), ("Ivy", 9, "ivy@examples.org", "606 Fir St, Denver", 2023, 4), - ("Jack", 12, "jack@sample.net", "707 Ash St, Seattle", 2023, 5) - ) + ("Jack", 12, "jack@sample.net", "707 Ash St, Seattle", 2023, 5)) data.foreach { case (name, age, email, street_address, year, month) => - spark.sql( - s""" + spark.sql(s""" | INSERT INTO $testTable | PARTITION (year=$year, month=$month) | VALUES ('$name', $age, '$email', '$street_address') diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLParseITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLParseITSuite.scala index f0578cfab..39dd4e643 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLParseITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLParseITSuite.scala @@ -5,12 +5,13 @@ package org.opensearch.flint.spark.ppl +import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq + +import org.apache.spark.sql.{AnalysisException, QueryTest, Row} import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Coalesce, Descending, GreaterThan, Literal, NullsLast, RegExpExtract, SortOrder} import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LogicalPlan, Project, Sort} import org.apache.spark.sql.streaming.StreamTest -import org.apache.spark.sql.{AnalysisException, QueryTest, Row} -import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq class FlintSparkPPLParseITSuite extends QueryTest @@ -38,8 +39,7 @@ class FlintSparkPPLParseITSuite } test("test parse email expressions parsing") { - val frame = sql( - s""" + val frame = sql(s""" | source = $testTable| parse email '.+@(?.+)' | fields email, host ; | """.stripMargin) @@ -57,8 +57,7 @@ class FlintSparkPPLParseITSuite Row("jack@sample.net", "sample.net"), Row("eve@examples.com", "examples.com"), Row("ivy@examples.org", "examples.org"), - Row("bob@test.org", "test.org") - ) + Row("bob@test.org", "test.org")) // Compare the results implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](_.getAs[String](0)) @@ -70,20 +69,18 @@ class FlintSparkPPLParseITSuite val emailAttribute = UnresolvedAttribute("email") val hostAttribute = UnresolvedAttribute("host") val hostExpression = Alias( - Coalesce( - Seq(RegExpExtract(emailAttribute, Literal(".+@(.+)"), Literal("1")))), "host")() + Coalesce(Seq(RegExpExtract(emailAttribute, Literal(".+@(.+)"), Literal("1")))), + "host")() val expectedPlan = Project( Seq(emailAttribute, hostAttribute), Project( - Seq(emailAttribute, hostExpression), - UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) - )) + Seq(emailAttribute, hostExpression, UnresolvedStar(None)), + UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")))) 
assert(compareByString(expectedPlan) === compareByString(logicalPlan)) } test("test parse email expressions parsing filter & sort by age") { - val frame = sql( - s""" + val frame = sql(s""" | source = $testTable| parse email '.+@(?.+)' | where age > 45 | sort - age | fields age, email, host ; | """.stripMargin) @@ -93,8 +90,7 @@ class FlintSparkPPLParseITSuite val expectedResults: Array[Row] = Array( Row(76, "frank@sample.org", "sample.org"), Row(65, "charlie@domain.net", "domain.net"), - Row(55, "bob@test.org", "test.org") - ) + Row(55, "bob@test.org", "test.org")) // Compare the results assert(results.sameElements(expectedResults)) @@ -104,7 +100,9 @@ class FlintSparkPPLParseITSuite // Define the expected logical plan val emailAttribute = UnresolvedAttribute("email") val ageAttribute = UnresolvedAttribute("age") - val hostExpression = Alias(Coalesce(Seq(RegExpExtract(emailAttribute, Literal(".+@(.+)"), Literal(1)))), "host")() + val hostExpression = Alias( + Coalesce(Seq(RegExpExtract(emailAttribute, Literal(".+@(.+)"), Literal(1)))), + "host")() // Define the corrected expected plan val expectedPlan = Project( @@ -115,12 +113,8 @@ class FlintSparkPPLParseITSuite Filter( GreaterThan(ageAttribute, Literal(45)), Project( - Seq(emailAttribute, hostExpression), - UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) - ) - ) - ) - ) + Seq(emailAttribute, hostExpression, UnresolvedStar(None)), + UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")))))) comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) } } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystPlanContext.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystPlanContext.java index b63dceff2..e262acbde 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystPlanContext.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystPlanContext.java @@ -34,6 +34,10 @@ * The context used for Catalyst logical plan. 
*/ public class CatalystPlanContext { + /** + * Catalyst relations list + **/ + private List projectedFields = new ArrayList<>(); /** * Catalyst relations list **/ @@ -65,6 +69,10 @@ public List getRelations() { return relations; } + public List getProjectedFields() { + return projectedFields; + } + public LogicalPlan getPlan() { if (this.planBranches.size() == 1) { return planBranches.peek(); @@ -113,6 +121,16 @@ public LogicalPlan withRelation(UnresolvedRelation relation) { this.relations.add(relation); return with(relation); } + /** + * append projected fields + * + * @param projectedFields + * @return + */ + public LogicalPlan withProjectedFields(List projectedFields) { + this.projectedFields.addAll(projectedFields); + return getPlan(); + } /** * append plan with evolving plans branches diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java index 0f470185c..cf20ef8f5 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java @@ -240,7 +240,7 @@ public LogicalPlan visitAlias(Alias node, CatalystPlanContext context) { @Override public LogicalPlan visitProject(Project node, CatalystPlanContext context) { LogicalPlan child = node.getChild().get(0).accept(this, context); - List expressionList = visitExpressionList(node.getProjectList(), context); + context.withProjectedFields(visitExpressionList(node.getProjectList(), context)); // Create a projection list from the existing expressions Seq projectList = seq(context.getNamedParseExpressions()); @@ -316,6 +316,9 @@ private LogicalPlan visitParseCommand(Parse node, Expression sourceField, ParseM Option.empty(), seq(new java.util.ArrayList()))); } + // Create an UnresolvedStar for all-fields projection (possible external wrapping projection that may include additional fields) + context.getNamedParseExpressions().push(UnresolvedStar$.MODULE$.apply(Option.>empty())); + // extract all fields to project with Seq projectExpressions = context.retainAllNamedParseExpressions(p -> (NamedExpression) p); // build the plan with the projection step LogicalPlan child = context.apply(p -> new org.apache.spark.sql.catalyst.plans.logical.Project(projectExpressions, p)); diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanParseTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanParseTranslatorTestSuite.scala index 6bdb4bf98..9a90cb423 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanParseTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanParseTranslatorTestSuite.scala @@ -5,16 +5,17 @@ package org.opensearch.flint.spark.ppl +import org.opensearch.flint.spark.ppl.PlaneUtils.plan +import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor} +import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq +import org.scalatest.matchers.should.Matchers + import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.ScalaReflection.universe.Star import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation, UnresolvedStar} import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Coalesce, Descending, GreaterThan, Literal, 
diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanParseTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanParseTranslatorTestSuite.scala
index 6bdb4bf98..9a90cb423 100644
--- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanParseTranslatorTestSuite.scala
+++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanParseTranslatorTestSuite.scala
@@ -5,16 +5,17 @@

 package org.opensearch.flint.spark.ppl

+import org.opensearch.flint.spark.ppl.PlaneUtils.plan
+import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor}
+import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq
+import org.scalatest.matchers.should.Matchers
+
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.catalyst.ScalaReflection.universe.Star
 import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation, UnresolvedStar}
 import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Coalesce, Descending, GreaterThan, Literal, NamedExpression, NullsFirst, NullsLast, RegExpExtract, SortOrder}
 import org.apache.spark.sql.catalyst.plans.PlanTest
 import org.apache.spark.sql.catalyst.plans.logical.{Filter, Project, Sort}
-import org.opensearch.flint.spark.ppl.PlaneUtils.plan
-import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq
-import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor}
-import org.scalatest.matchers.should.Matchers

 class PPLLogicalPlanParseTranslatorTestSuite
   extends SparkFunSuite
@@ -29,20 +30,22 @@ class PPLLogicalPlanParseTranslatorTestSuite
     val context = new CatalystPlanContext
     val logPlan =
       planTransformer.visit(
-        plan(pplParser, "source=t | parse email '.+@(?<host>.+)' | fields email, host", isExplain = false),
+        plan(
+          pplParser,
+          "source=t | parse email '.+@(?<host>.+)' | fields email, host",
+          isExplain = false),
         context)

     val emailAttribute = UnresolvedAttribute("email")
     val hostAttribute = UnresolvedAttribute("host")
     val hostExpression = Alias(
-      Coalesce(
-        Seq(RegExpExtract(emailAttribute, Literal(".+@(.+)"), Literal("1")))), "host")()
+      Coalesce(Seq(RegExpExtract(emailAttribute, Literal(".+@(.+)"), Literal("1")))),
+      "host")()
     val expectedPlan = Project(
       Seq(emailAttribute, hostAttribute),
       Project(
-        Seq(emailAttribute, hostExpression),
-        UnresolvedRelation(Seq("t"))
-      ))
+        Seq(emailAttribute, hostExpression, UnresolvedStar(None)),
+        UnresolvedRelation(Seq("t"))))
     assert(compareByString(expectedPlan) === compareByString(logPlan))
   }

@@ -52,17 +55,16 @@ class PPLLogicalPlanParseTranslatorTestSuite
       planTransformer.visit(
         plan(pplParser, "source=t | parse email '.+@(?<host>.+)' | fields email", false),
         context)
-
+
     val emailAttribute = UnresolvedAttribute("email")
     val hostExpression = Alias(
-      Coalesce(
-        Seq(RegExpExtract(emailAttribute, Literal(".+@(.+)"), Literal("1")))), "email")()
+      Coalesce(Seq(RegExpExtract(emailAttribute, Literal(".+@(.+)"), Literal("1")))),
+      "email")()
     val expectedPlan = Project(
       Seq(emailAttribute),
       Project(
-        Seq(emailAttribute, hostExpression),
-        UnresolvedRelation(Seq("t"))
-      ))
+        Seq(emailAttribute, hostExpression, UnresolvedStar(None)),
+        UnresolvedRelation(Seq("t"))))
     assert(compareByString(expectedPlan) === compareByString(logPlan))
   }

@@ -70,14 +72,18 @@ class PPLLogicalPlanParseTranslatorTestSuite
     val context = new CatalystPlanContext
     val logPlan =
       planTransformer.visit(
-        plan(pplParser,
-          "source = t | parse email '.+@(?<host>.+)' | where age > 45 | sort - age | fields age, email, host", isExplain = false),
+        plan(
+          pplParser,
+          "source = t | parse email '.+@(?<host>.+)' | where age > 45 | sort - age | fields age, email, host",
+          isExplain = false),
         context)

     // Define the expected logical plan
     val emailAttribute = UnresolvedAttribute("email")
     val ageAttribute = UnresolvedAttribute("age")
-    val hostExpression = Alias(Coalesce(Seq(RegExpExtract(emailAttribute, Literal(".+@(.+)"), Literal(1)))), "host")()
+    val hostExpression = Alias(
+      Coalesce(Seq(RegExpExtract(emailAttribute, Literal(".+@(.+)"), Literal(1)))),
+      "host")()

     // Define the corrected expected plan
     val expectedPlan = Project(
@@ -88,12 +94,8 @@ class PPLLogicalPlanParseTranslatorTestSuite
         Filter(
           GreaterThan(ageAttribute, Literal(45)),
           Project(
-            Seq(ageAttribute, emailAttribute, hostExpression),
-            UnresolvedRelation(Seq("t"))
-          )
-        )
-      )
-    )
+            Seq(emailAttribute, hostExpression, UnresolvedStar(None)),
+            UnresolvedRelation(Seq("t"))))))
     assert(compareByString(expectedPlan) === compareByString(logPlan))
   }

@@ -101,7 +103,10 @@ class PPLLogicalPlanParseTranslatorTestSuite
     val context = new CatalystPlanContext
     val logPlan =
       planTransformer.visit(
-        plan(pplParser, "source=t | parse email '.+@(?<host>.+)' | eval eval_result=1 | fields host, eval_result", false),
+        plan(
+          pplParser,
+          "source=t | parse email '.+@(?<host>.+)' | eval eval_result=1 | fields host, eval_result",
+          false),
         context)

     val emailAttribute = UnresolvedAttribute("email")
@@ -119,11 +124,8 @@ class PPLLogicalPlanParseTranslatorTestSuite
       Project(
         Seq(UnresolvedStar(None), evalResultExpression),
         Project(
-          Seq(emailAttribute, hostExpression),
-          UnresolvedRelation(Seq("t"))
-        )
-      )
-    )
+          Seq(emailAttribute, hostExpression, UnresolvedStar(None)),
+          UnresolvedRelation(Seq("t")))))
     assert(compareByString(expectedPlan) === compareByString(logPlan))
   }

@@ -131,7 +133,10 @@ class PPLLogicalPlanParseTranslatorTestSuite
     val context = new CatalystPlanContext
     val logPlan =
       planTransformer.visit(
-        plan(pplParser, "source=t | parse address '(?<streetNumber>\\d+) (?<street>.+)' | where streetNumber > 500 | sort num(streetNumber) | fields streetNumber, street", false),
+        plan(
+          pplParser,
+          "source=t | parse address '(?<streetNumber>\\d+) (?<street>.+)' | where streetNumber > 500 | sort num(streetNumber) | fields streetNumber, street",
+          false),
         context)

     val addressAttribute = UnresolvedAttribute("address")
@@ -140,8 +145,7 @@ class PPLLogicalPlanParseTranslatorTestSuite

     val streetNumberExpression = Alias(
       Coalesce(Seq(RegExpExtract(addressAttribute, Literal("(\\d+) (.+)"), Literal("1")))),
-      "streetNumber"
-    )()
+      "streetNumber")()

     val streetExpression = Alias(
       Coalesce(Seq(RegExpExtract(addressAttribute, Literal("(\\d+) (.+)"), Literal("2")))),
@@ -155,12 +159,9 @@ class PPLLogicalPlanParseTranslatorTestSuite
         Filter(
           GreaterThan(streetNumberAttribute, Literal(500)),
           Project(
-            Seq(addressAttribute, streetNumberExpression, streetExpression),
-            UnresolvedRelation(Seq("t"))
-          )
-        )
-      )
-    )
+            Seq(addressAttribute, streetNumberExpression, streetExpression, UnresolvedStar(None)),
+            UnresolvedRelation(Seq("t"))))))
     assert(compareByString(expectedPlan) === compareByString(logPlan))
-  }}
+  }
+}
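For orientation, the plans asserted above correspond, up to the `Coalesce` wrapper around the extract, to ordinary `regexp_extract` SQL. A self-contained approximation (the table and data are made up for illustration; this is not the translator's own code path):

```scala
import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[1]").appName("ppl-parse-sketch").getOrCreate()
import spark.implicits._

// A tiny stand-in for the `t` table used in the translator tests.
Seq("alice@example.com", "bob@test.org").toDF("email").createOrReplaceTempView("t")

// `source=t | parse email '.+@(?<host>.+)' | fields email, host` reads roughly as:
// (the translator additionally wraps the extract in Coalesce)
spark.sql("SELECT email, regexp_extract(email, '.+@(.+)', 1) AS host FROM t").show()
```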
From 84d247d0bc19c92ea5b1bfea9f5b51b91ba237ad Mon Sep 17 00:00:00 2001
From: YANGDB
Date: Thu, 22 Aug 2024 20:37:15 -0700
Subject: [PATCH 7/9] fix depended top/rare issues

Signed-off-by: YANGDB
Signed-off-by: YANGDB
---
 .../flint/spark/FlintSparkSuite.scala         |  10 +-
 .../spark/ppl/FlintSparkPPLParseITSuite.scala | 104 +++++++++++++++++-
 .../ppl/FlintSparkPPLTopAndRareITSuite.scala  | 104 ++++++++++++++----
 .../sql/ppl/CatalystQueryPlanVisitor.java     |   2 +-
 .../opensearch/sql/ppl/parser/AstBuilder.java |  18 +--
 ...LLogicalPlanParseTranslatorTestSuite.scala |  76 ++++++++++++-
 ...TopAndRareQueriesTranslatorTestSuite.scala |  92 +++++++++++++---
 7 files changed, 350 insertions(+), 56 deletions(-)

diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala
index e0f42cfa7..a9bbac710 100644
--- a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala
+++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala
@@ -125,7 +125,7 @@ trait FlintSparkSuite extends QueryTest with FlintSuite with OpenSearchSuite with StreamTest {
       ("Frank", 76, "frank@sample.org", "303 Cedar St, Austin", 2023, 5),
       ("Grace", 41, "grace@demo.net", "404 Birch St, Chicago", 2023, 4),
       ("Hank", 32, "hank@demonstration.com", "505 Spruce St, Miami", 2023, 5),
-      ("Ivy", 9, "ivy@examples.org", "606 Fir St, Denver", 2023, 4),
+      ("Ivy", 9, "ivy@examples.com", "606 Fir St, Denver", 2023, 4),
       ("Jack", 12, "jack@sample.net", "707 Ash St, Seattle", 2023, 5))

     data.foreach { case (name, age, email, street_address, year, month) =>
@@ -271,9 +271,13 @@ trait FlintSparkSuite extends QueryTest with FlintSuite with OpenSearchSuite with StreamTest {
          | VALUES ('Jake', 'Engineer', 'England' , 100000),
          |        ('Hello', 'Artist', 'USA', 70000),
          |        ('John', 'Doctor', 'Canada', 120000),
-         |        ('David', 'Doctor', 'USA', 120000),
+         |        ('Rachel', 'Doctor', 'Canada', 220000),
+         |        ('Henry', 'Doctor', 'Canada', 220000),
+         |        ('David', 'Engineer', 'USA', 320000),
+         |        ('Barty', 'Engineer', 'USA', 120000),
          |        ('David', 'Unemployed', 'Canada', 0),
-         |        ('Jane', 'Scientist', 'Canada', 90000)
+         |        ('Jane', 'Scientist', 'Canada', 90000),
+         |        ('Philip', 'Scientist', 'Canada', 190000)
          | """.stripMargin)
   }

diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLParseITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLParseITSuite.scala
index 39dd4e643..388de3d31 100644
--- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLParseITSuite.scala
+++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLParseITSuite.scala
@@ -5,12 +5,14 @@

 package org.opensearch.flint.spark.ppl

+import scala.reflect.internal.Reporter.Count
+
 import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq

 import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
 import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
 import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Coalesce, Descending, GreaterThan, Literal, NullsLast, RegExpExtract, SortOrder}
-import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, LogicalPlan, Project, Sort}
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Filter, GlobalLimit, LocalLimit, LogicalPlan, Project, Sort}
 import org.apache.spark.sql.streaming.StreamTest

 class FlintSparkPPLParseITSuite
@@ -56,7 +58,7 @@ class FlintSparkPPLParseITSuite
       Row("grace@demo.net", "demo.net"),
       Row("jack@sample.net", "sample.net"),
       Row("eve@examples.com", "examples.com"),
-      Row("ivy@examples.org", "examples.org"),
+      Row("ivy@examples.com", "examples.com"),
       Row("bob@test.org", "test.org"))

     // Compare the results
@@ -117,4 +119,102 @@ class FlintSparkPPLParseITSuite
           UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))))))
     comparePlans(expectedPlan, logicalPlan, checkAnalysis = false)
   }
+
+  test("test parse email expressions and group by count host ") {
+    val frame = sql(s"""
+         | source = $testTable| parse email '.+@(?<host>.+)' | stats count() by host
+         | """.stripMargin)
+
+    // Retrieve the results
+    val results: Array[Row] = frame.collect()
+    // Define the expected results
+    val expectedResults: Array[Row] = Array(
+      Row(1L, "demonstration.com"),
+      Row(1L, "example.com"),
+      Row(1L, "domain.net"),
+      Row(1L, "anotherdomain.com"),
+      Row(1L, "sample.org"),
+      Row(1L, "demo.net"),
+      Row(1L, "sample.net"),
+      Row(2L, "examples.com"),
+      Row(1L, "test.org"))
+
+    // Sort both the results and the expected results
+    implicit val rowOrdering: Ordering[Row] = Ordering.by(r => (r.getLong(0), r.getString(1)))
+    assert(results.sorted.sameElements(expectedResults.sorted))
+
+    // Retrieve the logical plan
+    val logicalPlan: LogicalPlan = frame.queryExecution.logical
+    val emailAttribute = UnresolvedAttribute("email")
+    val hostAttribute = UnresolvedAttribute("host")
+    val hostExpression = Alias(
+      Coalesce(Seq(RegExpExtract(emailAttribute, Literal(".+@(.+)"), Literal(1)))),
+      "host")()
+
+    // Define the corrected expected plan
+    val expectedPlan = Project(
+      Seq(UnresolvedStar(None)), // Matches the '*' in the Project
+      Aggregate(
+        Seq(Alias(hostAttribute, "host")()), // Group by 'host'
+        Seq(
+          Alias(
+            UnresolvedFunction(Seq("COUNT"), Seq(UnresolvedStar(None)), isDistinct = false),
+            "count()")(),
+          Alias(hostAttribute, "host")()),
+        Project(
+          Seq(emailAttribute, hostExpression, UnresolvedStar(None)),
+          UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")))))
+    // Compare the logical plans
+    comparePlans(expectedPlan, logicalPlan, checkAnalysis = false)
+  }
+
+  test("test parse email expressions and top count_host ") {
+    val frame = sql(s"""
+         | source = $testTable| parse email '.+@(?<host>.+)' | top 1 host
+         | """.stripMargin)
+
+    // Retrieve the results
+    val results: Array[Row] = frame.collect()
+    // Define the expected results
+    val expectedResults: Array[Row] = Array(Row(2L, "examples.com"))
+
+    // Sort both the results and the expected results
+    implicit val rowOrdering: Ordering[Row] = Ordering.by(r => (r.getLong(0), r.getString(1)))
+    assert(results.sorted.sameElements(expectedResults.sorted))
+    // Retrieve the logical plan
+    val logicalPlan: LogicalPlan = frame.queryExecution.logical
+
+    val emailAttribute = UnresolvedAttribute("email")
+    val hostAttribute = UnresolvedAttribute("host")
+    val hostExpression = Alias(
+      Coalesce(Seq(RegExpExtract(emailAttribute, Literal(".+@(.+)"), Literal(1)))),
+      "host")()
+
+    val sortedPlan = Sort(
+      Seq(
+        SortOrder(
+          Alias(
+            UnresolvedFunction(Seq("COUNT"), Seq(hostAttribute), isDistinct = false),
+            "count_host")(),
+          Descending,
+          NullsLast,
+          Seq.empty)),
+      global = true,
+      Aggregate(
+        Seq(hostAttribute),
+        Seq(
+          Alias(
+            UnresolvedFunction(Seq("COUNT"), Seq(hostAttribute), isDistinct = false),
+            "count_host")(),
+          hostAttribute),
+        Project(
+          Seq(emailAttribute, hostExpression, UnresolvedStar(None)),
+          UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")))))
+    // Define the corrected expected plan
+    val expectedPlan = Project(
+      Seq(UnresolvedStar(None)), // Matches the '*' in the Project
+      GlobalLimit(Literal(1), LocalLimit(Literal(1), sortedPlan)))
+    // Compare the logical plans
+    comparePlans(expectedPlan, logicalPlan, checkAnalysis = false)
+  }
 }
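The two new tests mirror what an equivalent DataFrame pipeline would do: derive `host` with a regex extract, then either group-and-count or additionally order by the count and keep the first row. A rough sketch (the column and table names follow the tests; the translator itself emits unresolved Catalyst plans rather than DataFrame calls):

```scala
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions.regexp_extract

val spark = SparkSession.builder().master("local[1]").appName("ppl-top-sketch").getOrCreate()
import spark.implicits._

val df = Seq("ivy@examples.com", "eve@examples.com", "bob@test.org")
  .toDF("email")
  .withColumn("host", regexp_extract($"email", ".+@(.+)", 1))

// `stats count() by host` reads roughly as:
df.groupBy("host").count().show()

// `top 1 host` additionally orders by the count descending and limits the output:
df.groupBy("host").count().orderBy($"count".desc).limit(1).show()
```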
diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTopAndRareITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTopAndRareITSuite.scala
index 09307aa44..0e50b9845 100644
--- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTopAndRareITSuite.scala
+++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTopAndRareITSuite.scala
@@ -21,11 +21,13 @@ class FlintSparkPPLTopAndRareITSuite

   /** Test table and index name */
   private val testTable = "spark_catalog.default.flint_ppl_test"
+  private val newTestTable = "spark_catalog.default.new_flint_ppl_test"

   override def beforeAll(): Unit = {
     super.beforeAll()

-    // Create test table
+    // Create test tables
+    createOccupationTable(newTestTable)
     createPartitionedMultiRowAddressTable(testTable)
   }

@@ -61,7 +63,7 @@ class FlintSparkPPLTopAndRareITSuite
     val aggregateExpressions = Seq(
       Alias(
         UnresolvedFunction(Seq("COUNT"), Seq(addressField), isDistinct = false),
-        "count(address)")(),
+        "count_address")(),
       addressField)
     val aggregatePlan =
       Aggregate(
@@ -70,11 +72,16 @@ class FlintSparkPPLTopAndRareITSuite
         UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")))
     val sortedPlan: LogicalPlan =
       Sort(
-        Seq(SortOrder(UnresolvedAttribute("address"), Descending)),
+        Seq(
+          SortOrder(
+            Alias(
+              UnresolvedFunction(Seq("COUNT"), Seq(addressField), isDistinct = false),
+              "count_address")(),
+            Ascending)),
         global = true,
         aggregatePlan)
     val expectedPlan = Project(projectList, sortedPlan)
-    comparePlans(expectedPlan, logicalPlan, false)
+    comparePlans(expectedPlan, logicalPlan, checkAnalysis = false)
   }

   test("create ppl rare address by age field query test") {
@@ -101,7 +108,7 @@ class FlintSparkPPLTopAndRareITSuite
     val countExpr = Alias(
       UnresolvedFunction(Seq("COUNT"), Seq(addressField), isDistinct = false),
-      "count(address)")()
+      "count_address")()
     val aggregateExpressions = Seq(countExpr, addressField, ageAlias)
     val aggregatePlan =
@@ -112,7 +119,12 @@ class FlintSparkPPLTopAndRareITSuite
     val sortedPlan: LogicalPlan =
       Sort(
-        Seq(SortOrder(UnresolvedAttribute("address"), Descending)),
+        Seq(
+          SortOrder(
+            Alias(
+              UnresolvedFunction(Seq("COUNT"), Seq(addressField), isDistinct = false),
+              "count_address")(),
+            Ascending)),
         global = true,
         aggregatePlan)

@@ -146,7 +158,7 @@ class FlintSparkPPLTopAndRareITSuite
     val aggregateExpressions = Seq(
       Alias(
         UnresolvedFunction(Seq("COUNT"), Seq(addressField), isDistinct = false),
-        "count(address)")(),
+        "count_address")(),
       addressField)
     val aggregatePlan =
       Aggregate(
@@ -155,17 +167,66 @@ class FlintSparkPPLTopAndRareITSuite
         UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")))
     val sortedPlan: LogicalPlan =
       Sort(
-        Seq(SortOrder(UnresolvedAttribute("address"), Ascending)),
+        Seq(
+          SortOrder(
+            Alias(
+              UnresolvedFunction(Seq("COUNT"), Seq(addressField), isDistinct = false),
+              "count_address")(),
+            Descending)),
         global = true,
         aggregatePlan)
     val expectedPlan = Project(projectList, sortedPlan)
-    comparePlans(expectedPlan, logicalPlan, false)
+    comparePlans(expectedPlan, logicalPlan, checkAnalysis = false)
   }

-  test("create ppl top 3 countries by occupation field query test") {
-    val newTestTable = "spark_catalog.default.new_flint_ppl_test"
-    createOccupationTable(newTestTable)
+  test("create ppl top 3 countries query test") {
+    val frame = sql(s"""
+         | source = $newTestTable| top 3 country
+         | """.stripMargin)
+
+    // Retrieve the results
+    val results: Array[Row] = frame.collect()
+    assert(results.length == 3)
+
+    val expectedRows = Set(Row(6, "Canada"), Row(3, "USA"), Row(1, "England"))
+    val actualRows = results.take(3).toSet
+
+    // Compare the sets
+    assert(
+      actualRows == expectedRows,
+      s"The first three results do not match the expected rows. Expected: $expectedRows, Actual: $actualRows")
+
+    // Retrieve the logical plan
+    val logicalPlan: LogicalPlan = frame.queryExecution.logical
+    val countryField = UnresolvedAttribute("country")
+    val countExpr = Alias(
+      UnresolvedFunction(Seq("COUNT"), Seq(countryField), isDistinct = false),
+      "count_country")()
+    val aggregateExpressions = Seq(countExpr, countryField)
+    val aggregatePlan =
+      Aggregate(
+        Seq(countryField),
+        aggregateExpressions,
+        UnresolvedRelation(Seq("spark_catalog", "default", "new_flint_ppl_test")))
+
+    val sortedPlan: LogicalPlan =
+      Sort(
+        Seq(
+          SortOrder(
+            Alias(
+              UnresolvedFunction(Seq("COUNT"), Seq(countryField), isDistinct = false),
+              "count_country")(),
+            Descending)),
+        global = true,
+        aggregatePlan)
+
+    val planWithLimit =
+      GlobalLimit(Literal(3), LocalLimit(Literal(3), sortedPlan))
+    val expectedPlan = Project(Seq(UnresolvedStar(None)), planWithLimit)
+    comparePlans(expectedPlan, logicalPlan, checkAnalysis = false)
+  }
+
+  test("create ppl top 3 countries by occupation field query test") {
     val frame = sql(s"""
          | source = $newTestTable| top 3 country by occupation
          | """.stripMargin)
@@ -174,10 +235,8 @@ class FlintSparkPPLTopAndRareITSuite
     val results: Array[Row] = frame.collect()
     assert(results.length == 3)

-    val expectedRows = Set(
-      Row(1, "Canada", "Doctor"),
-      Row(1, "Canada", "Scientist"),
-      Row(1, "Canada", "Unemployed"))
+    val expectedRows =
+      Set(Row(3, "Canada", "Doctor"), Row(2, "Canada", "Scientist"), Row(2, "USA", "Engineer"))
     val actualRows = results.take(3).toSet

     // Compare the sets
@@ -187,14 +246,13 @@ class FlintSparkPPLTopAndRareITSuite
     // Retrieve the logical plan
     val logicalPlan: LogicalPlan = frame.queryExecution.logical
-
     val countryField = UnresolvedAttribute("country")
     val occupationField = UnresolvedAttribute("occupation")
     val occupationFieldAlias = Alias(occupationField, "occupation")()

     val countExpr = Alias(
       UnresolvedFunction(Seq("COUNT"), Seq(countryField), isDistinct = false),
-      "count(country)")()
+      "count_country")()
     val aggregateExpressions = Seq(countExpr, countryField, occupationFieldAlias)
     val aggregatePlan =
       Aggregate(
@@ -204,13 +262,19 @@ class FlintSparkPPLTopAndRareITSuite
     val sortedPlan: LogicalPlan =
       Sort(
-        Seq(SortOrder(UnresolvedAttribute("country"), Ascending)),
+        Seq(
+          SortOrder(
+            Alias(
+              UnresolvedFunction(Seq("COUNT"), Seq(countryField), isDistinct = false),
+              "count_country")(),
+            Descending)),
         global = true,
         aggregatePlan)

     val planWithLimit =
       GlobalLimit(Literal(3), LocalLimit(Literal(3), sortedPlan))
     val expectedPlan = Project(Seq(UnresolvedStar(None)), planWithLimit)
-    comparePlans(expectedPlan, logicalPlan, false)
+    comparePlans(expectedPlan, logicalPlan, checkAnalysis = false)
+  }
 }
diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java
index cf20ef8f5..6caaec839 100644
--- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java
+++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java
@@ -206,7 +206,7 @@ public LogicalPlan visitAggregation(Aggregation node, CatalystPlanContext context) {

         // set sort direction according to command type (`rare` is Asc, `top` is Desc, default to Asc)
         List<SortDirection> sortDirections = new ArrayList<>();
-        sortDirections.add(node instanceof RareAggregation ? Descending$.MODULE$ : Ascending$.MODULE$);
+        sortDirections.add(node instanceof RareAggregation ? Ascending$.MODULE$ : Descending$.MODULE$);

         if (!node.getSortExprList().isEmpty()) {
             visitExpressionList(node.getSortExprList(), context);
diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java
index 7d91bbb7a..fdb11c342 100644
--- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java
+++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java
@@ -278,12 +278,11 @@ public UnresolvedPlan visitPatternsCommand(OpenSearchPPLParser.PatternsCommandContext ctx) {
   public UnresolvedPlan visitTopCommand(OpenSearchPPLParser.TopCommandContext ctx) {
     ImmutableList.Builder<UnresolvedExpression> aggListBuilder = new ImmutableList.Builder<>();
     ImmutableList.Builder<UnresolvedExpression> groupListBuilder = new ImmutableList.Builder<>();
-    ImmutableList.Builder<UnresolvedExpression> sortListBuilder = new ImmutableList.Builder<>();
     ctx.fieldList().fieldExpression().forEach(field -> {
       UnresolvedExpression aggExpression = new AggregateFunction("count", internalVisitExpression(field),
           Collections.singletonList(new Argument("countParam", new Literal(1, DataType.INTEGER))));
       String name = field.qualifiedName().getText();
-      Alias alias = new Alias("count("+name+")", aggExpression);
+      Alias alias = new Alias("count_"+name, aggExpression);
       aggListBuilder.add(alias);
       // group by the `field-list` as the mandatory groupBy fields
       groupListBuilder.add(internalVisitExpression(field));
@@ -305,16 +304,12 @@ public UnresolvedPlan visitTopCommand(OpenSearchPPLParser.TopCommandContext ctx) {
             .collect(Collectors.toList()))
             .orElse(emptyList())
     );
-    //build the sort fields
-    ctx.fieldList().fieldExpression().forEach(field -> {
-      sortListBuilder.add(internalVisitExpression(field));
-    });
     UnresolvedExpression unresolvedPlan = (ctx.number != null ? internalVisitExpression(ctx.number) : null);
     TopAggregation aggregation =
         new TopAggregation(
             Optional.ofNullable((Literal) unresolvedPlan),
             aggListBuilder.build(),
-            sortListBuilder.build(),
+            aggListBuilder.build(),
             groupListBuilder.build());
     return aggregation;
   }
@@ -324,12 +319,11 @@ public UnresolvedPlan visitTopCommand(OpenSearchPPLParser.TopCommandContext ctx) {
   public UnresolvedPlan visitRareCommand(OpenSearchPPLParser.RareCommandContext ctx) {
     ImmutableList.Builder<UnresolvedExpression> aggListBuilder = new ImmutableList.Builder<>();
     ImmutableList.Builder<UnresolvedExpression> groupListBuilder = new ImmutableList.Builder<>();
-    ImmutableList.Builder<UnresolvedExpression> sortListBuilder = new ImmutableList.Builder<>();
     ctx.fieldList().fieldExpression().forEach(field -> {
       UnresolvedExpression aggExpression = new AggregateFunction("count", internalVisitExpression(field),
           Collections.singletonList(new Argument("countParam", new Literal(1, DataType.INTEGER))));
       String name = field.qualifiedName().getText();
-      Alias alias = new Alias("count("+name+")", aggExpression);
+      Alias alias = new Alias("count_"+name, aggExpression);
       aggListBuilder.add(alias);
       // group by the `field-list` as the mandatory groupBy fields
       groupListBuilder.add(internalVisitExpression(field));
@@ -351,14 +345,10 @@ public UnresolvedPlan visitRareCommand(OpenSearchPPLParser.RareCommandContext ctx) {
             .collect(Collectors.toList()))
             .orElse(emptyList())
     );
-    //build the sort fields
-    ctx.fieldList().fieldExpression().forEach(field -> {
-      sortListBuilder.add(internalVisitExpression(field));
-    });
     RareAggregation aggregation =
         new RareAggregation(
             aggListBuilder.build(),
-            sortListBuilder.build(),
+            aggListBuilder.build(),
             groupListBuilder.build());
     return aggregation;
   }
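After this change, `top` and `rare` differ only in sort direction over the same `count_<field>` aggregation alias, since the aggregation expressions are reused as the sort keys. In DataFrame terms the two commands read approximately as follows (a sketch under those assumptions, not the plugin's implementation):

```scala
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.{col, count}

// `top <field>`: most frequent values first (count_<field> descending).
def topByCount(df: DataFrame, field: String): DataFrame =
  df.groupBy(field)
    .agg(count(col(field)).as(s"count_$field"))
    .orderBy(col(s"count_$field").desc)

// `rare <field>`: least frequent values first (same plan, ascending sort).
def rareByCount(df: DataFrame, field: String): DataFrame =
  df.groupBy(field)
    .agg(count(col(field)).as(s"count_$field"))
    .orderBy(col(s"count_$field").asc)
```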
UnresolvedAttribute("host") + val hostExpression = Alias( + Coalesce(Seq(RegExpExtract(emailAttribute, Literal(".+@(.+)"), Literal(1)))), + "host")() + + // Define the corrected expected plan + val expectedPlan = Project( + Seq(UnresolvedStar(None)), // Matches the '*' in the Project + Aggregate( + Seq(Alias(hostAttribute, "host")()), // Group by 'host' + Seq( + Alias( + UnresolvedFunction(Seq("COUNT"), Seq(UnresolvedStar(None)), isDistinct = false), + "count()")(), + Alias(hostAttribute, "host")()), + Project( + Seq(emailAttribute, hostExpression, UnresolvedStar(None)), + UnresolvedRelation(Seq("t"))))) + + // Compare the logical plans + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } + + test("test parse email expressions and top count_host ") { + val context = new CatalystPlanContext + val logPlan = + planTransformer.visit( + plan(pplParser, "source=t | parse email '.+@(?.+)' | top 1 host", false), + context) + + val emailAttribute = UnresolvedAttribute("email") + val hostAttribute = UnresolvedAttribute("host") + val hostExpression = Alias( + Coalesce(Seq(RegExpExtract(emailAttribute, Literal(".+@(.+)"), Literal(1)))), + "host")() + + val sortedPlan = Sort( + Seq( + SortOrder( + Alias( + UnresolvedFunction(Seq("COUNT"), Seq(hostAttribute), isDistinct = false), + "count_host")(), + Descending, + NullsLast, + Seq.empty)), + global = true, + Aggregate( + Seq(hostAttribute), + Seq( + Alias( + UnresolvedFunction(Seq("COUNT"), Seq(hostAttribute), isDistinct = false), + "count_host")(), + hostAttribute), + Project( + Seq(emailAttribute, hostExpression, UnresolvedStar(None)), + UnresolvedRelation(Seq("t"))))) + // Define the corrected expected plan + val expectedPlan = Project( + Seq(UnresolvedStar(None)), // Matches the '*' in the Project + GlobalLimit(Literal(1), LocalLimit(Literal(1), sortedPlan))) + // Compare the logical plans + comparePlans(expectedPlan, logPlan, checkAnalysis = false) + } } diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTopAndRareQueriesTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTopAndRareQueriesTranslatorTestSuite.scala index 5bd5da28c..c6e5a7f38 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTopAndRareQueriesTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanTopAndRareQueriesTranslatorTestSuite.scala @@ -30,7 +30,9 @@ class PPLLogicalPlanTopAndRareQueriesTranslatorTestSuite // if successful build ppl logical plan and translate to catalyst logical plan val context = new CatalystPlanContext val logPlan = - planTransformer.visit(plan(pplParser, "source=accounts | rare address", false), context) + planTransformer.visit( + plan(pplParser, "source=accounts | rare address", isExplain = false), + context) val addressField = UnresolvedAttribute("address") val tableRelation = UnresolvedRelation(Seq("accounts")) @@ -39,7 +41,7 @@ class PPLLogicalPlanTopAndRareQueriesTranslatorTestSuite val aggregateExpressions = Seq( Alias( UnresolvedFunction(Seq("COUNT"), Seq(addressField), isDistinct = false), - "count(address)")(), + "count_address")(), addressField) val aggregatePlan = @@ -47,11 +49,16 @@ class PPLLogicalPlanTopAndRareQueriesTranslatorTestSuite val sortedPlan: LogicalPlan = Sort( - Seq(SortOrder(UnresolvedAttribute("address"), Descending)), + Seq( + SortOrder( + Alias( + UnresolvedFunction(Seq("COUNT"), Seq(addressField), isDistinct = 
false), + "count_address")(), + Ascending)), global = true, aggregatePlan) val expectedPlan = Project(projectList, sortedPlan) - comparePlans(expectedPlan, logPlan, false) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) } test("test simple rare command with a by field test") { @@ -59,7 +66,7 @@ class PPLLogicalPlanTopAndRareQueriesTranslatorTestSuite val context = new CatalystPlanContext val logicalPlan = planTransformer.visit( - plan(pplParser, "source=accounts | rare address by age", false), + plan(pplParser, "source=accounts | rare address by age", isExplain = false), context) // Retrieve the logical plan // Define the expected logical plan @@ -71,7 +78,7 @@ class PPLLogicalPlanTopAndRareQueriesTranslatorTestSuite val countExpr = Alias( UnresolvedFunction(Seq("COUNT"), Seq(addressField), isDistinct = false), - "count(address)")() + "count_address")() val aggregateExpressions = Seq(countExpr, addressField, ageAlias) val aggregatePlan = @@ -82,19 +89,26 @@ class PPLLogicalPlanTopAndRareQueriesTranslatorTestSuite val sortedPlan: LogicalPlan = Sort( - Seq(SortOrder(UnresolvedAttribute("address"), Descending)), + Seq( + SortOrder( + Alias( + UnresolvedFunction(Seq("COUNT"), Seq(addressField), isDistinct = false), + "count_address")(), + Ascending)), global = true, aggregatePlan) val expectedPlan = Project(projectList, sortedPlan) - comparePlans(expectedPlan, logicalPlan, false) + comparePlans(expectedPlan, logicalPlan, checkAnalysis = false) } test("test simple top command with a single field") { // if successful build ppl logical plan and translate to catalyst logical plan val context = new CatalystPlanContext val logPlan = - planTransformer.visit(plan(pplParser, "source=accounts | top address", false), context) + planTransformer.visit( + plan(pplParser, "source=accounts | top address", isExplain = false), + context) val addressField = UnresolvedAttribute("address") val tableRelation = UnresolvedRelation(Seq("accounts")) @@ -103,7 +117,7 @@ class PPLLogicalPlanTopAndRareQueriesTranslatorTestSuite val aggregateExpressions = Seq( Alias( UnresolvedFunction(Seq("COUNT"), Seq(addressField), isDistinct = false), - "count(address)")(), + "count_address")(), addressField) val aggregatePlan = @@ -111,11 +125,16 @@ class PPLLogicalPlanTopAndRareQueriesTranslatorTestSuite val sortedPlan: LogicalPlan = Sort( - Seq(SortOrder(UnresolvedAttribute("address"), Ascending)), + Seq( + SortOrder( + Alias( + UnresolvedFunction(Seq("COUNT"), Seq(addressField), isDistinct = false), + "count_address")(), + Descending)), global = true, aggregatePlan) val expectedPlan = Project(projectList, sortedPlan) - comparePlans(expectedPlan, logPlan, false) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) } test("test simple top 1 command by age field") { @@ -132,7 +151,7 @@ class PPLLogicalPlanTopAndRareQueriesTranslatorTestSuite val countExpr = Alias( UnresolvedFunction(Seq("COUNT"), Seq(addressField), isDistinct = false), - "count(address)")() + "count_address")() val aggregateExpressions = Seq(countExpr, addressField, ageAlias) val aggregatePlan = Aggregate( @@ -142,7 +161,12 @@ class PPLLogicalPlanTopAndRareQueriesTranslatorTestSuite val sortedPlan: LogicalPlan = Sort( - Seq(SortOrder(UnresolvedAttribute("address"), Ascending)), + Seq( + SortOrder( + Alias( + UnresolvedFunction(Seq("COUNT"), Seq(addressField), isDistinct = false), + "count_address")(), + Descending)), global = true, aggregatePlan) @@ -151,4 +175,44 @@ class PPLLogicalPlanTopAndRareQueriesTranslatorTestSuite val expectedPlan = 
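The `(?<host>...)` patterns used throughout are standard Java named capture groups: the group's name supplies the PPL field name, while the extraction itself uses the positional group index, which is why the expected plans contain `RegExpExtract(..., Literal(1))`. A standalone illustration:

```scala
// Named groups are plain Java regex syntax; group 1 and group "host" are the same capture.
val pattern = java.util.regex.Pattern.compile(".+@(?<host>.+)")
val matcher = pattern.matcher("ivy@examples.com")
if (matcher.matches()) {
  println(matcher.group("host")) // examples.com
  println(matcher.group(1))      // examples.com
}
```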
From ef889f6cddf410da4ffe10374e033829354f5b35 Mon Sep 17 00:00:00 2001
From: YANGDB
Date: Thu, 22 Aug 2024 20:54:13 -0700
Subject: [PATCH 8/9] fix depended top/rare issues update readme with command

Signed-off-by: YANGDB
Signed-off-by: YANGDB
---
 .../flint/spark/FlintSparkSuite.scala         | 29 +++++++++++++++++++
 .../ppl/FlintSparkPPLTopAndRareITSuite.scala  |  2 +-
 ppl-spark-integration/README.md               |  9 ++++++
 3 files changed, 39 insertions(+), 1 deletion(-)

diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala
index a9bbac710..3f843dbe4 100644
--- a/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala
+++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/FlintSparkSuite.scala
@@ -264,6 +264,35 @@ trait FlintSparkSuite extends QueryTest with FlintSuite with OpenSearchSuite with StreamTest {
          | )
          |""".stripMargin)

+    // Insert data into the new table
+    sql(s"""
+         | INSERT INTO $testTable
+         | PARTITION (year=2023, month=4)
+         | VALUES ('Jake', 'Engineer', 'England' , 100000),
+         |        ('Hello', 'Artist', 'USA', 70000),
+         |        ('John', 'Doctor', 'Canada', 120000),
+         |        ('David', 'Doctor', 'USA', 120000),
+         |        ('David', 'Unemployed', 'Canada', 0),
+         |        ('Jane', 'Scientist', 'Canada', 90000)
+         | """.stripMargin)
+  }
+
+  protected def createOccupationTopRareTable(testTable: String): Unit = {
+    sql(s"""
+         | CREATE TABLE $testTable
+         | (
+         |   name STRING,
+         |   occupation STRING,
+         |   country STRING,
+         |   salary INT
+         | )
+         | USING $tableType $tableOptions
+         | PARTITIONED BY (
+         |    year INT,
+         |    month INT
+         | )
+         |""".stripMargin)
+
     // Insert data into the new table
     sql(s"""
          | INSERT INTO $testTable
diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTopAndRareITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTopAndRareITSuite.scala
index 0e50b9845..f10b6e2f5 100644
--- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTopAndRareITSuite.scala
+++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLTopAndRareITSuite.scala
@@ -27,7 +27,7 @@ class FlintSparkPPLTopAndRareITSuite
     super.beforeAll()

     // Create test tables
-    createOccupationTable(newTestTable)
+    createOccupationTopRareTable(newTestTable)
     createPartitionedMultiRowAddressTable(testTable)
   }

diff --git a/ppl-spark-integration/README.md b/ppl-spark-integration/README.md
index bc8a96c52..972a1bebe 100644
--- a/ppl-spark-integration/README.md
+++ b/ppl-spark-integration/README.md
@@ -306,6 +306,15 @@ Limitation: Overriding existing field is unsupported, following queries throw exceptions:
 - `source=accounts | top 1 gender`
 - `source=accounts | top 1 age by gender`

+**Parse**
+- `source=accounts | top gender`
+- `source=accounts | parse email '.+@(?<host>.+)' | fields email, host `
+- `source=accounts | parse email '.+@(?<host>.+)' | top 1 host `
+- `source=accounts | parse email '.+@(?<host>.+)' | stats count() by host`
+- `source=accounts | parse email '.+@(?<host>.+)' | eval eval_result=1 | fields host, eval_result`
+- `source=accounts | parse email '.+@(?<host>.+)' | where age > 45 | sort - age | fields age, email, host`
+- `source=accounts | parse address '(?<streetNumber>\d+) (?<street>.+)' | where streetNumber > 500 | sort num(streetNumber) | fields streetNumber, street`
+
 > For additional details on PPL commands - view [PPL Commands Docs](https://github.com/opensearch-project/sql/blob/main/docs/user/ppl/index.rst)

From e7a8ef7ec6b1c4633d2647a077e3102654ddb69f Mon Sep 17 00:00:00 2001
From: YANGDB
Date: Thu, 22 Aug 2024 21:00:29 -0700
Subject: [PATCH 9/9] fix depended top/rare issues update readme with command

Signed-off-by: YANGDB
Signed-off-by: YANGDB
---
 ppl-spark-integration/README.md | 1 -
 1 file changed, 1 deletion(-)

diff --git a/ppl-spark-integration/README.md b/ppl-spark-integration/README.md
index 972a1bebe..24639e444 100644
--- a/ppl-spark-integration/README.md
+++ b/ppl-spark-integration/README.md
@@ -307,7 +307,6 @@ Limitation: Overriding existing field is unsupported, following queries throw exceptions:
 - `source=accounts | top 1 age by gender`

 **Parse**
-- `source=accounts | top gender`
 - `source=accounts | parse email '.+@(?<host>.+)' | fields email, host `
 - `source=accounts | parse email '.+@(?<host>.+)' | top 1 host `
 - `source=accounts | parse email '.+@(?<host>.+)' | stats count() by host`