From 548a0a451671ca8ea0000d8578fde7842549c811 Mon Sep 17 00:00:00 2001
From: Lantao Jin
Date: Fri, 9 Aug 2024 08:57:00 +0800
Subject: [PATCH 1/6] Translate PPL Command Part 2: allowedDuplication>1

Signed-off-by: Lantao Jin
---
 .../spark/ppl/FlintSparkPPLDedupITSuite.scala |   8 +-
 .../sql/ppl/CatalystQueryPlanVisitor.java     |  93 ++-
 .../ppl/CatalystQueryPlanVisitor.java.orig    | 595 ++++++++++++++++++
 .../sql/ppl/utils/WindowSpecTransformer.java  |  35 +-
 ...LLogicalPlanDedupTranslatorTestSuite.scala | 147 ++++-
 5 files changed, 855 insertions(+), 23 deletions(-)
 create mode 100644 ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java.orig

diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLDedupITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLDedupITSuite.scala
index 06c90527d..8ce078176 100644
--- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLDedupITSuite.scala
+++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLDedupITSuite.scala
@@ -187,7 +187,7 @@ class FlintSparkPPLDedupITSuite
     assert(ex.getMessage.contains("Consecutive deduplication is not supported"))
   }
 
-  ignore("test dedupe 2 name") {
+  test("test dedupe 2 name") {
     val frame = sql(s"""
                       | source = $testTable| dedup 2 name | fields name
                       | """.stripMargin)
@@ -200,7 +200,7 @@ class FlintSparkPPLDedupITSuite
     assert(results.sorted.sameElements(expectedResults.sorted))
   }
 
-  ignore("test dedupe 2 name, category") {
+  test("test dedupe 2 name, category") {
     val frame = sql(s"""
                       | source = $testTable| dedup 2 name, category | fields name, category
                       | """.stripMargin)
@@ -225,7 +225,7 @@ class FlintSparkPPLDedupITSuite
     assert(results.sorted.sameElements(expectedResults.sorted))
   }
 
-  ignore("test dedupe 2 name KEEPEMPTY=true") {
+  test("test dedupe 2 name KEEPEMPTY=true") {
     val frame = sql(s"""
                       | source = $testTable| dedup 2 name KEEPEMPTY=true | fields name, category
                       | """.stripMargin)
@@ -259,7 +259,7 @@ class FlintSparkPPLDedupITSuite
       .sameElements(expectedResults.sorted.map(_.getAs[String](0))))
   }
 
-  ignore("test dedupe 2 name, category KEEPEMPTY=true") {
+  test("test dedupe 2 name, category KEEPEMPTY=true") {
     val frame = sql(s"""
                       | source = $testTable| dedup 2 name, category KEEPEMPTY=true | fields name, category
                       | """.stripMargin)
diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java
index 812cbea82..294ac5226 100644
--- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java
+++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java
@@ -10,10 +10,12 @@
 import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation;
 import org.apache.spark.sql.catalyst.analysis.UnresolvedStar$;
 import org.apache.spark.sql.catalyst.expressions.Expression;
+import org.apache.spark.sql.catalyst.expressions.LessThanOrEqual;
 import org.apache.spark.sql.catalyst.expressions.NamedExpression;
 import org.apache.spark.sql.catalyst.expressions.Predicate;
 import org.apache.spark.sql.catalyst.expressions.SortOrder;
 import org.apache.spark.sql.catalyst.plans.logical.Aggregate;
+import org.apache.spark.sql.catalyst.plans.logical.DataFrameDropColumns;
 import org.apache.spark.sql.catalyst.plans.logical.Deduplicate;
 import org.apache.spark.sql.catalyst.plans.logical.DescribeRelation$;
 import org.apache.spark.sql.catalyst.plans.logical.Limit;
@@ -64,6 +66,7 @@
 import org.opensearch.sql.ppl.utils.BuiltinFunctionTranslator;
 import org.opensearch.sql.ppl.utils.ComparatorTransformer;
 import org.opensearch.sql.ppl.utils.SortUtils;
+import org.opensearch.sql.ppl.utils.WindowSpecTransformer;
 import scala.Option;
 import scala.Option$;
 import scala.collection.Seq;
@@ -318,13 +321,14 @@ public LogicalPlan visitDedupe(Dedupe node, CatalystPlanContext context) {
         // adding Aggregate operator could achieve better performance.
         if (allowedDuplication == 1) {
             if (keepEmpty) {
+                // | dedup a, b keepempty=true
                 // Union
                 // :- Deduplicate ['a, 'b]
                 // :  +- Filter (isnotnull('a) AND isnotnull('b))
-                // :     +- Project
+                // :     +- ...
                 // :        +- UnresolvedRelation
                 // +- Filter (isnull('a) OR isnull('b))
-                //    +- Project
+                //    +- ...
                 //       +- UnresolvedRelation
 
                 context.apply(p -> {
@@ -339,9 +343,10 @@ public LogicalPlan visitDedupe(Dedupe node, CatalystPlanContext context) {
                 });
                 return context.getPlan();
             } else {
+                // | dedup a, b keepempty=false
                 // Deduplicate ['a, 'b]
                 // +- Filter (isnotnull('a) AND isnotnull('b))
-                //    +- Project
+                //    +- ...
                 //       +- UnresolvedRelation
 
                 Expression isNotNullExpr = buildIsNotNullFilterExpression(node, context);
@@ -350,8 +355,86 @@ public LogicalPlan visitDedupe(Dedupe node, CatalystPlanContext context) {
             return context.apply(p -> new Deduplicate(dedupFields, p));
         }
     } else {
-        // TODO
-        throw new UnsupportedOperationException("Number of duplicate events greater than 1 is not supported");
+        if (keepEmpty) {
+            // | dedup 2 a, b keepempty=true
+            // Union
+            // :- DataFrameDropColumns('_row_number_)
+            // :  +- Filter ('_row_number_ <= 2)
+            // :     +- Window [row_number() windowspecdefinition('a, 'b, 'a ASC NULLS FIRST, 'b ASC NULLS FIRST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS _row_number_], ['a, 'b], ['a ASC NULLS FIRST, 'b ASC NULLS FIRST]
+            // :        +- Filter (isnotnull('a) AND isnotnull('b))
+            // :           +- ...
+            // :              +- UnresolvedRelation
+            // +- Filter (isnull('a) OR isnull('b))
+            //    +- ...
+            //       +- UnresolvedRelation
+
+            context.apply(p -> {
+                // Build isnull Filter for right
+                Expression isNullExpr = buildIsNullFilterExpression(node, context);
+                LogicalPlan right = new org.apache.spark.sql.catalyst.plans.logical.Filter(isNullExpr, p);
+
+                // Build isnotnull Filter
+                Expression isNotNullExpr = buildIsNotNullFilterExpression(node, context);
+                LogicalPlan isNotNullFilter = new org.apache.spark.sql.catalyst.plans.logical.Filter(isNotNullExpr, p);
+
+                // Build Window
+                visitFieldList(node.getFields(), context);
+                Seq<Expression> partitionSpec = context.retainAllNamedParseExpressions(exp -> exp);
+                visitFieldList(node.getFields(), context);
+                Seq<SortOrder> orderSpec = context.retainAllNamedParseExpressions(exp -> SortUtils.sortOrder(exp, true));
+                NamedExpression rowNumber = WindowSpecTransformer.buildRowNumber(partitionSpec, orderSpec);
+                LogicalPlan window = new org.apache.spark.sql.catalyst.plans.logical.Window(
+                    seq(rowNumber),
+                    partitionSpec,
+                    orderSpec,
+                    isNotNullFilter);
+
+                // Build deduplication Filter ('_row_number_ <= n)
+                Expression filterExpr = new LessThanOrEqual(
+                    rowNumber.toAttribute(),
+                    new org.apache.spark.sql.catalyst.expressions.Literal(allowedDuplication, DataTypes.IntegerType));
+                LogicalPlan deduplicationFilter = new org.apache.spark.sql.catalyst.plans.logical.Filter(filterExpr, window);
+
+                // Build DataFrameDropColumns('_row_number_) for left
+                LogicalPlan left = new DataFrameDropColumns(seq(rowNumber.toAttribute()), deduplicationFilter);
+
+                // Build Union
+                return new Union(seq(left, right), false, false);
+            });
+            return context.getPlan();
+        } else {
+            // | dedup 2 a, b keepempty=false
+            // DataFrameDropColumns('_row_number_)
+            // +- Filter ('_row_number_ <= n)
+            //    +- Window [row_number() windowspecdefinition('a, 'b, 'a ASC NULLS FIRST, 'b ASC NULLS FIRST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS _row_number_], ['a, 'b], ['a ASC NULLS FIRST, 'b ASC NULLS FIRST]
+            //       +- Filter (isnotnull('a) AND isnotnull('b))
+            //          +- ...
+            //             +- UnresolvedRelation
+
+            // Build isnotnull Filter
+            Expression isNotNullExpr = buildIsNotNullFilterExpression(node, context);
+            context.apply(p -> new org.apache.spark.sql.catalyst.plans.logical.Filter(isNotNullExpr, p));
+
+            // Build Window
+            visitFieldList(node.getFields(), context);
+            Seq<Expression> partitionSpec = context.retainAllNamedParseExpressions(exp -> exp);
+            visitFieldList(node.getFields(), context);
+            Seq<SortOrder> orderSpec = context.retainAllNamedParseExpressions(exp -> SortUtils.sortOrder(exp, true));
+            NamedExpression rowNumber = WindowSpecTransformer.buildRowNumber(partitionSpec, orderSpec);
+            context.apply(p -> new org.apache.spark.sql.catalyst.plans.logical.Window(
+                seq(rowNumber),
+                partitionSpec,
+                orderSpec, p));
+
+            // Build deduplication Filter ('_row_number_ <= n)
+            Expression filterExpr = new LessThanOrEqual(
+                rowNumber.toAttribute(),
+                new org.apache.spark.sql.catalyst.expressions.Literal(allowedDuplication, DataTypes.IntegerType));
+            context.apply(p -> new org.apache.spark.sql.catalyst.plans.logical.Filter(filterExpr, p));
+
+            // Build DataFrameDropColumns('_row_number_); requires Spark 3.5.1+
+            return context.apply(p -> new DataFrameDropColumns(seq(rowNumber.toAttribute()), p));
+        }
     }
 }
diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java.orig b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java.orig
new file mode 100644
index 000000000..812cbea82
--- /dev/null
+++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java.orig
@@ -0,0 +1,595 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.ppl;
+
+import org.apache.spark.sql.catalyst.TableIdentifier;
+import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute$;
+import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation;
+import org.apache.spark.sql.catalyst.analysis.UnresolvedStar$;
+import org.apache.spark.sql.catalyst.expressions.Expression;
+import org.apache.spark.sql.catalyst.expressions.NamedExpression;
+import org.apache.spark.sql.catalyst.expressions.Predicate;
+import org.apache.spark.sql.catalyst.expressions.SortOrder;
+import org.apache.spark.sql.catalyst.plans.logical.Aggregate;
+import org.apache.spark.sql.catalyst.plans.logical.Deduplicate;
+import org.apache.spark.sql.catalyst.plans.logical.DescribeRelation$;
+import org.apache.spark.sql.catalyst.plans.logical.Limit;
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan;
+import org.apache.spark.sql.catalyst.plans.logical.Union;
+import org.apache.spark.sql.execution.command.DescribeTableCommand;
+import org.apache.spark.sql.types.DataTypes;
+import org.apache.spark.sql.util.CaseInsensitiveStringMap;
+import org.opensearch.sql.ast.AbstractNodeVisitor;
+import org.opensearch.sql.ast.expression.AggregateFunction;
+import org.opensearch.sql.ast.expression.Alias;
+import org.opensearch.sql.ast.expression.AllFields;
+import org.opensearch.sql.ast.expression.And;
+import org.opensearch.sql.ast.expression.Argument;
+import org.opensearch.sql.ast.expression.BinaryExpression;
+import org.opensearch.sql.ast.expression.Case;
+import org.opensearch.sql.ast.expression.Compare;
+import org.opensearch.sql.ast.expression.Field;
+import org.opensearch.sql.ast.expression.FieldsMapping;
+import org.opensearch.sql.ast.expression.Function;
+import org.opensearch.sql.ast.expression.In;
+import org.opensearch.sql.ast.expression.Interval;
+import 
org.opensearch.sql.ast.expression.Let; +import org.opensearch.sql.ast.expression.Literal; +import org.opensearch.sql.ast.expression.Not; +import org.opensearch.sql.ast.expression.Or; +import org.opensearch.sql.ast.expression.QualifiedName; +import org.opensearch.sql.ast.expression.Span; +import org.opensearch.sql.ast.expression.UnresolvedExpression; +import org.opensearch.sql.ast.expression.WindowFunction; +import org.opensearch.sql.ast.expression.Xor; +import org.opensearch.sql.ast.statement.Explain; +import org.opensearch.sql.ast.statement.Query; +import org.opensearch.sql.ast.statement.Statement; +import org.opensearch.sql.ast.tree.Aggregation; +import org.opensearch.sql.ast.tree.Correlation; +import org.opensearch.sql.ast.tree.Dedupe; +import org.opensearch.sql.ast.tree.DescribeRelation; +import org.opensearch.sql.ast.tree.Eval; +import org.opensearch.sql.ast.tree.Filter; +import org.opensearch.sql.ast.tree.Head; +import org.opensearch.sql.ast.tree.Kmeans; +import org.opensearch.sql.ast.tree.Project; +import org.opensearch.sql.ast.tree.RareTopN; +import org.opensearch.sql.ast.tree.Relation; +import org.opensearch.sql.ast.tree.Sort; +import org.opensearch.sql.ppl.utils.AggregatorTranslator; +import org.opensearch.sql.ppl.utils.BuiltinFunctionTranslator; +import org.opensearch.sql.ppl.utils.ComparatorTransformer; +import org.opensearch.sql.ppl.utils.SortUtils; +import scala.Option; +import scala.Option$; +import scala.collection.Seq; + +import java.util.ArrayList; +import java.util.List; +import java.util.Objects; +import java.util.Optional; +import java.util.function.BiFunction; +import java.util.stream.Collectors; + +import static java.util.Collections.emptyList; +import static java.util.List.of; +import static org.opensearch.sql.ppl.CatalystPlanContext.findRelation; +import static org.opensearch.sql.ppl.utils.DataTypeTransformer.seq; +import static org.opensearch.sql.ppl.utils.DataTypeTransformer.translate; +import static org.opensearch.sql.ppl.utils.JoinSpecTransformer.join; +import static org.opensearch.sql.ppl.utils.RelationUtils.resolveField; +import static org.opensearch.sql.ppl.utils.WindowSpecTransformer.window; + +/** + * Utility class to traverse PPL logical plan and translate it into catalyst logical plan + */ +public class CatalystQueryPlanVisitor extends AbstractNodeVisitor { + + private final ExpressionAnalyzer expressionAnalyzer; + + public CatalystQueryPlanVisitor() { + this.expressionAnalyzer = new ExpressionAnalyzer(); + } + + public LogicalPlan visit(Statement plan, CatalystPlanContext context) { + return plan.accept(this, context); + } + + /** + * Handle Query Statement. 
+ */ + @Override + public LogicalPlan visitQuery(Query node, CatalystPlanContext context) { + return node.getPlan().accept(this, context); + } + + @Override + public LogicalPlan visitExplain(Explain node, CatalystPlanContext context) { + return node.getStatement().accept(this, context); + } + + @Override + public LogicalPlan visitRelation(Relation node, CatalystPlanContext context) { + if (node instanceof DescribeRelation) { + TableIdentifier identifier; + if (node.getTableQualifiedName().getParts().size() == 1) { + identifier = new TableIdentifier(node.getTableQualifiedName().getParts().get(0)); + } else if (node.getTableQualifiedName().getParts().size() == 2) { + identifier = new TableIdentifier( + node.getTableQualifiedName().getParts().get(1), + Option$.MODULE$.apply(node.getTableQualifiedName().getParts().get(0))); + } else { + throw new IllegalArgumentException("Invalid table name: " + node.getTableQualifiedName() + + " Syntax: [ database_name. ] table_name"); + } + return context.with( + new DescribeTableCommand( + identifier, + scala.collection.immutable.Map$.MODULE$.empty(), + false, + DescribeRelation$.MODULE$.getOutputAttrs())); + } + //regular sql algebraic relations + node.getTableName().forEach(t -> + // Resolving the qualifiedName which is composed of a datasource.schema.table + context.with(new UnresolvedRelation(seq(of(t.split("\\."))), CaseInsensitiveStringMap.empty(), false)) + ); + return context.getPlan(); + } + + @Override + public LogicalPlan visitFilter(Filter node, CatalystPlanContext context) { + node.getChild().get(0).accept(this, context); + return context.apply(p -> { + Expression conditionExpression = visitExpression(node.getCondition(), context); + Optional innerConditionExpression = context.popNamedParseExpressions(); + return innerConditionExpression.map(expression -> new org.apache.spark.sql.catalyst.plans.logical.Filter(innerConditionExpression.get(), p)).orElse(null); + }); + } + + @Override + public LogicalPlan visitCorrelation(Correlation node, CatalystPlanContext context) { + node.getChild().get(0).accept(this, context); + context.reduce((left,right) -> { + visitFieldList(node.getFieldsList().stream().map(Field::new).collect(Collectors.toList()), context); + Seq fields = context.retainAllNamedParseExpressions(e -> e); + if(!Objects.isNull(node.getScope())) { + // scope - this is a time base expression that timeframes the join to a specific period : (Time-field-name, value, unit) + expressionAnalyzer.visitSpan(node.getScope(), context); + context.popNamedParseExpressions().get(); + } + expressionAnalyzer.visitCorrelationMapping(node.getMappingListContext(), context); + Seq mapping = context.retainAllNamedParseExpressions(e -> e); + return join(node.getCorrelationType(), fields, mapping, left, right); + }); + return context.getPlan(); + } + + @Override + public LogicalPlan visitAggregation(Aggregation node, CatalystPlanContext context) { + node.getChild().get(0).accept(this, context); + List aggsExpList = visitExpressionList(node.getAggExprList(), context); + List groupExpList = visitExpressionList(node.getGroupExprList(), context); + + if (!groupExpList.isEmpty()) { + //add group by fields to context + context.getGroupingParseExpressions().addAll(groupExpList); + } + + UnresolvedExpression span = node.getSpan(); + if (!Objects.isNull(span)) { + span.accept(this, context); + //add span's group alias field (most recent added expression) + context.getGroupingParseExpressions().add(context.getNamedParseExpressions().peek()); + } + // build the aggregation 
logical step + return extractedAggregation(context); + } + + private static LogicalPlan extractedAggregation(CatalystPlanContext context) { + Seq groupingExpression = context.retainAllGroupingNamedParseExpressions(p -> p); + Seq aggregateExpressions = context.retainAllNamedParseExpressions(p -> (NamedExpression) p); + return context.apply(p -> new Aggregate(groupingExpression, aggregateExpressions, p)); + } + + @Override + public LogicalPlan visitAlias(Alias node, CatalystPlanContext context) { + expressionAnalyzer.visitAlias(node, context); + return context.getPlan(); + } + + @Override + public LogicalPlan visitProject(Project node, CatalystPlanContext context) { + LogicalPlan child = node.getChild().get(0).accept(this, context); + List expressionList = visitExpressionList(node.getProjectList(), context); + + // Create a projection list from the existing expressions + Seq projectList = seq(context.getNamedParseExpressions()); + if (!projectList.isEmpty()) { + Seq projectExpressions = context.retainAllNamedParseExpressions(p -> (NamedExpression) p); + // build the plan with the projection step + child = context.apply(p -> new org.apache.spark.sql.catalyst.plans.logical.Project(projectExpressions, p)); + } + if (node.hasArgument()) { + Argument argument = node.getArgExprList().get(0); + //todo exclude the argument from the projected arguments list + } + return child; + } + + @Override + public LogicalPlan visitSort(Sort node, CatalystPlanContext context) { + node.getChild().get(0).accept(this, context); + visitFieldList(node.getSortList(), context); + Seq sortElements = context.retainAllNamedParseExpressions(exp -> SortUtils.getSortDirection(node, (NamedExpression) exp)); + return context.apply(p -> (LogicalPlan) new org.apache.spark.sql.catalyst.plans.logical.Sort(sortElements, true, p)); + } + + @Override + public LogicalPlan visitHead(Head node, CatalystPlanContext context) { + node.getChild().get(0).accept(this, context); + return context.apply(p -> (LogicalPlan) Limit.apply(new org.apache.spark.sql.catalyst.expressions.Literal( + node.getSize(), DataTypes.IntegerType), p)); + } + + private void visitFieldList(List fieldList, CatalystPlanContext context) { + fieldList.forEach(field -> visitExpression(field, context)); + } + + private List visitExpressionList(List expressionList, CatalystPlanContext context) { + return expressionList.isEmpty() + ? 
emptyList() + : expressionList.stream().map(field -> visitExpression(field, context)) + .collect(Collectors.toList()); + } + + private Expression visitExpression(UnresolvedExpression expression, CatalystPlanContext context) { + return expressionAnalyzer.analyze(expression, context); + } + + @Override + public LogicalPlan visitEval(Eval node, CatalystPlanContext context) { + LogicalPlan child = node.getChild().get(0).accept(this, context); + List aliases = new ArrayList<>(); + List letExpressions = node.getExpressionList(); + for(Let let : letExpressions) { + Alias alias = new Alias(let.getVar().getField().toString(), let.getExpression()); + aliases.add(alias); + } + if (context.getNamedParseExpressions().isEmpty()) { + // Create an UnresolvedStar for all-fields projection + context.getNamedParseExpressions().push(UnresolvedStar$.MODULE$.apply(Option.>empty())); + } + List expressionList = visitExpressionList(aliases, context); + Seq projectExpressions = context.retainAllNamedParseExpressions(p -> (NamedExpression) p); + // build the plan with the projection step + child = context.apply(p -> new org.apache.spark.sql.catalyst.plans.logical.Project(projectExpressions, p)); + return child; + } + + @Override + public LogicalPlan visitKmeans(Kmeans node, CatalystPlanContext context) { + throw new IllegalStateException("Not Supported operation : Kmeans"); + } + + @Override + public LogicalPlan visitIn(In node, CatalystPlanContext context) { + throw new IllegalStateException("Not Supported operation : In"); + } + + @Override + public LogicalPlan visitCase(Case node, CatalystPlanContext context) { + throw new IllegalStateException("Not Supported operation : Case"); + } + + @Override + public LogicalPlan visitRareTopN(RareTopN node, CatalystPlanContext context) { + throw new IllegalStateException("Not Supported operation : RareTopN"); + } + + @Override + public LogicalPlan visitWindowFunction(WindowFunction node, CatalystPlanContext context) { + throw new IllegalStateException("Not Supported operation : WindowFunction"); + } + + @Override + public LogicalPlan visitDedupe(Dedupe node, CatalystPlanContext context) { + node.getChild().get(0).accept(this, context); + List options = node.getOptions(); + Integer allowedDuplication = (Integer) options.get(0).getValue().getValue(); + Boolean keepEmpty = (Boolean) options.get(1).getValue().getValue(); + Boolean consecutive = (Boolean) options.get(2).getValue().getValue(); + if (allowedDuplication <= 0) { + throw new IllegalArgumentException("Number of duplicate events must be greater than 0"); + } + if (consecutive) { + // Spark is not able to remove only consecutive events + throw new UnsupportedOperationException("Consecutive deduplication is not supported"); + } + visitFieldList(node.getFields(), context); + // Columns to deduplicate + Seq dedupFields + = context.retainAllNamedParseExpressions(e -> (org.apache.spark.sql.catalyst.expressions.Attribute) e); + // Although we can also use the Window operator to translate this as allowedDuplication > 1 did, + // adding Aggregate operator could achieve better performance. 
+ if (allowedDuplication == 1) { + if (keepEmpty) { + // Union + // :- Deduplicate ['a, 'b] + // : +- Filter (isnotnull('a) AND isnotnull('b) + // : +- Project + // : +- UnresolvedRelation + // +- Filter (isnull('a) OR isnull('a)) + // +- Project + // +- UnresolvedRelation + + context.apply(p -> { + Expression isNullExpr = buildIsNullFilterExpression(node, context); + LogicalPlan right = new org.apache.spark.sql.catalyst.plans.logical.Filter(isNullExpr, p); + + Expression isNotNullExpr = buildIsNotNullFilterExpression(node, context); + LogicalPlan left = + new Deduplicate(dedupFields, + new org.apache.spark.sql.catalyst.plans.logical.Filter(isNotNullExpr, p)); + return new Union(seq(left, right), false, false); + }); + return context.getPlan(); + } else { + // Deduplicate ['a, 'b] + // +- Filter (isnotnull('a) AND isnotnull('b)) + // +- Project + // +- UnresolvedRelation + + Expression isNotNullExpr = buildIsNotNullFilterExpression(node, context); + context.apply(p -> new org.apache.spark.sql.catalyst.plans.logical.Filter(isNotNullExpr, p)); + // Todo DeduplicateWithinWatermark in streaming dataset? + return context.apply(p -> new Deduplicate(dedupFields, p)); + } + } else { + // TODO + throw new UnsupportedOperationException("Number of duplicate events greater than 1 is not supported"); + } + } + + private Expression buildIsNotNullFilterExpression(Dedupe node, CatalystPlanContext context) { + visitFieldList(node.getFields(), context); + Seq isNotNullExpressions = + context.retainAllNamedParseExpressions( + org.apache.spark.sql.catalyst.expressions.IsNotNull$.MODULE$::apply); + + Expression isNotNullExpr; + if (isNotNullExpressions.size() == 1) { + isNotNullExpr = isNotNullExpressions.apply(0); + } else { + isNotNullExpr = isNotNullExpressions.reduce( + new scala.Function2() { + @Override + public Expression apply(Expression e1, Expression e2) { + return new org.apache.spark.sql.catalyst.expressions.And(e1, e2); + } + } + ); + } + return isNotNullExpr; + } + + private Expression buildIsNullFilterExpression(Dedupe node, CatalystPlanContext context) { + visitFieldList(node.getFields(), context); + Seq isNullExpressions = + context.retainAllNamedParseExpressions( + org.apache.spark.sql.catalyst.expressions.IsNull$.MODULE$::apply); + + Expression isNullExpr; + if (isNullExpressions.size() == 1) { + isNullExpr = isNullExpressions.apply(0); + } else { + isNullExpr = isNullExpressions.reduce( + new scala.Function2() { + @Override + public Expression apply(Expression e1, Expression e2) { + return new org.apache.spark.sql.catalyst.expressions.Or(e1, e2); + } + } + ); + } + return isNullExpr; + } + + /** + * Expression Analyzer. + */ + private static class ExpressionAnalyzer extends AbstractNodeVisitor { + + public Expression analyze(UnresolvedExpression unresolved, CatalystPlanContext context) { + return unresolved.accept(this, context); + } + + @Override + public Expression visitLiteral(Literal node, CatalystPlanContext context) { + return context.getNamedParseExpressions().push(new org.apache.spark.sql.catalyst.expressions.Literal( + translate(node.getValue(), node.getType()), translate(node.getType()))); + } + + /** + * generic binary (And, Or, Xor , ...) 
arithmetic expression resolver + * @param node + * @param transformer + * @param context + * @return + */ + public Expression visitBinaryArithmetic(BinaryExpression node, BiFunction transformer, CatalystPlanContext context) { + node.getLeft().accept(this, context); + Optional left = context.popNamedParseExpressions(); + node.getRight().accept(this, context); + Optional right = context.popNamedParseExpressions(); + if(left.isPresent() && right.isPresent()) { + return transformer.apply(left.get(),right.get()); + } else if(left.isPresent()) { + return context.getNamedParseExpressions().push(left.get()); + } else if(right.isPresent()) { + return context.getNamedParseExpressions().push(right.get()); + } + return null; + + } + + @Override + public Expression visitAnd(And node, CatalystPlanContext context) { + return visitBinaryArithmetic(node, + (left,right)-> context.getNamedParseExpressions().push(new org.apache.spark.sql.catalyst.expressions.And(left, right)), context); + } + + @Override + public Expression visitOr(Or node, CatalystPlanContext context) { + return visitBinaryArithmetic(node, + (left,right)-> context.getNamedParseExpressions().push(new org.apache.spark.sql.catalyst.expressions.Or(left, right)), context); + } + + @Override + public Expression visitXor(Xor node, CatalystPlanContext context) { + return visitBinaryArithmetic(node, + (left,right)-> context.getNamedParseExpressions().push(new org.apache.spark.sql.catalyst.expressions.BitwiseXor(left, right)), context); + } + + @Override + public Expression visitNot(Not node, CatalystPlanContext context) { + node.getExpression().accept(this, context); + Optional arg = context.popNamedParseExpressions(); + return arg.map(expression -> context.getNamedParseExpressions().push(new org.apache.spark.sql.catalyst.expressions.Not(expression))).orElse(null); + } + + @Override + public Expression visitSpan(Span node, CatalystPlanContext context) { + node.getField().accept(this, context); + Expression field = (Expression) context.popNamedParseExpressions().get(); + node.getValue().accept(this, context); + Expression value = (Expression) context.popNamedParseExpressions().get(); + return context.getNamedParseExpressions().push(window(field, value, node.getUnit())); + } + + @Override + public Expression visitAggregateFunction(AggregateFunction node, CatalystPlanContext context) { + node.getField().accept(this, context); + Expression arg = (Expression) context.popNamedParseExpressions().get(); + Expression aggregator = AggregatorTranslator.aggregator(node, arg); + return context.getNamedParseExpressions().push(aggregator); + } + + @Override + public Expression visitCompare(Compare node, CatalystPlanContext context) { + analyze(node.getLeft(), context); + Optional left = context.popNamedParseExpressions(); + analyze(node.getRight(), context); + Optional right = context.popNamedParseExpressions(); + if (left.isPresent() && right.isPresent()) { + Predicate comparator = ComparatorTransformer.comparator(node, left.get(), right.get()); + return context.getNamedParseExpressions().push((org.apache.spark.sql.catalyst.expressions.Expression) comparator); + } + return null; + } + + @Override + public Expression visitQualifiedName(QualifiedName node, CatalystPlanContext context) { + List relation = findRelation(context.traversalContext()); + if (!relation.isEmpty()) { + Optional resolveField = resolveField(relation, node); + return resolveField.map(qualifiedName -> 
context.getNamedParseExpressions().push(UnresolvedAttribute$.MODULE$.apply(seq(qualifiedName.getParts())))) + .orElse(null); + } + return context.getNamedParseExpressions().push(UnresolvedAttribute$.MODULE$.apply(seq(node.getParts()))); + } + + @Override + public Expression visitCorrelationMapping(FieldsMapping node, CatalystPlanContext context) { + return node.getChild().stream().map(expression -> + visitCompare((Compare) expression, context) + ).reduce(org.apache.spark.sql.catalyst.expressions.And::new).orElse(null); + } + + @Override + public Expression visitAllFields(AllFields node, CatalystPlanContext context) { + // Case of aggregation step - no start projection can be added + if (context.getNamedParseExpressions().isEmpty()) { + // Create an UnresolvedStar for all-fields projection + context.getNamedParseExpressions().push(UnresolvedStar$.MODULE$.apply(Option.>empty())); + } + return context.getNamedParseExpressions().peek(); + } + + @Override + public Expression visitAlias(Alias node, CatalystPlanContext context) { + node.getDelegated().accept(this, context); + Expression arg = context.popNamedParseExpressions().get(); + return context.getNamedParseExpressions().push( + org.apache.spark.sql.catalyst.expressions.Alias$.MODULE$.apply(arg, + node.getAlias() != null ? node.getAlias() : node.getName(), + NamedExpression.newExprId(), + seq(new java.util.ArrayList()), + Option.empty(), + seq(new java.util.ArrayList()))); + } + + @Override + public Expression visitEval(Eval node, CatalystPlanContext context) { + throw new IllegalStateException("Not Supported operation : Eval"); + } + + @Override + public Expression visitFunction(Function node, CatalystPlanContext context) { + List arguments = + node.getFuncArgs().stream() + .map( + unresolvedExpression -> { + var ret = analyze(unresolvedExpression, context); + if (ret == null) { + throw new UnsupportedOperationException( + String.format("Invalid use of expression %s", unresolvedExpression)); + } else { + return context.popNamedParseExpressions().get(); + } + }) + .collect(Collectors.toList()); + Expression function = BuiltinFunctionTranslator.builtinFunction(node, arguments); + return context.getNamedParseExpressions().push(function); + } + + @Override + public Expression visitInterval(Interval node, CatalystPlanContext context) { + throw new IllegalStateException("Not Supported operation : Interval"); + } + + @Override + public Expression visitDedupe(Dedupe node, CatalystPlanContext context) { + throw new IllegalStateException("Not Supported operation : Dedupe"); + } + + @Override + public Expression visitIn(In node, CatalystPlanContext context) { + throw new IllegalStateException("Not Supported operation : In"); + } + + @Override + public Expression visitKmeans(Kmeans node, CatalystPlanContext context) { + throw new IllegalStateException("Not Supported operation : Kmeans"); + } + + @Override + public Expression visitCase(Case node, CatalystPlanContext context) { + throw new IllegalStateException("Not Supported operation : Case"); + } + + @Override + public Expression visitRareTopN(RareTopN node, CatalystPlanContext context) { + throw new IllegalStateException("Not Supported operation : RareTopN"); + } + + @Override + public Expression visitWindowFunction(WindowFunction node, CatalystPlanContext context) { + throw new IllegalStateException("Not Supported operation : WindowFunction"); + } + } +} diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/WindowSpecTransformer.java 
b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/WindowSpecTransformer.java
index c215caec5..0e6ba2a1d 100644
--- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/WindowSpecTransformer.java
+++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/WindowSpecTransformer.java
@@ -5,24 +5,37 @@
 package org.opensearch.sql.ppl.utils;
 
+import org.apache.spark.sql.catalyst.expressions.CurrentRow$;
 import org.apache.spark.sql.catalyst.expressions.Divide;
 import org.apache.spark.sql.catalyst.expressions.Expression;
 import org.apache.spark.sql.catalyst.expressions.Floor;
-import org.apache.spark.sql.catalyst.expressions.Literal;
 import org.apache.spark.sql.catalyst.expressions.Multiply;
+import org.apache.spark.sql.catalyst.expressions.NamedExpression;
+import org.apache.spark.sql.catalyst.expressions.RowFrame$;
+import org.apache.spark.sql.catalyst.expressions.RowNumber;
+import org.apache.spark.sql.catalyst.expressions.SortOrder;
+import org.apache.spark.sql.catalyst.expressions.SpecifiedWindowFrame;
 import org.apache.spark.sql.catalyst.expressions.TimeWindow;
-import org.apache.spark.sql.types.DateType$;
-import org.apache.spark.sql.types.StringType$;
+import org.apache.spark.sql.catalyst.expressions.UnboundedPreceding$;
+import org.apache.spark.sql.catalyst.expressions.WindowExpression;
+import org.apache.spark.sql.catalyst.expressions.WindowSpecDefinition;
 import org.opensearch.sql.ast.expression.SpanUnit;
+import scala.Option;
+import scala.collection.Seq;
+
+import java.util.ArrayList;
 
 import static java.lang.String.format;
 import static org.opensearch.sql.ast.expression.DataType.STRING;
 import static org.opensearch.sql.ast.expression.SpanUnit.NONE;
 import static org.opensearch.sql.ast.expression.SpanUnit.UNKNOWN;
+import static org.opensearch.sql.ppl.utils.DataTypeTransformer.seq;
 import static org.opensearch.sql.ppl.utils.DataTypeTransformer.translate;
 
 public interface WindowSpecTransformer {
 
+    String ROW_NUMBER_COLUMN_NAME = "_row_number_";
+
     /**
      * create a static window buckets based on the given value
      *
@@ -50,4 +63,20 @@ static org.apache.spark.sql.catalyst.expressions.Literal timeLiteral( Expression
         return new org.apache.spark.sql.catalyst.expressions.Literal(
             translate(format, STRING), translate(STRING));
     }
+
+    static NamedExpression buildRowNumber(Seq<Expression> partitionSpec, Seq<SortOrder> orderSpec) {
+        WindowExpression rowNumber = new WindowExpression(
+            new RowNumber(),
+            new WindowSpecDefinition(
+                partitionSpec,
+                orderSpec,
+                new SpecifiedWindowFrame(RowFrame$.MODULE$, UnboundedPreceding$.MODULE$, CurrentRow$.MODULE$)));
+        return org.apache.spark.sql.catalyst.expressions.Alias$.MODULE$.apply(
+            rowNumber,
+            ROW_NUMBER_COLUMN_NAME,
+            NamedExpression.newExprId(),
+            seq(new ArrayList<String>()),
+            Option.empty(),
+            seq(new ArrayList<String>()));
+    }
 }
diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanDedupTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanDedupTranslatorTestSuite.scala
index 34cfcbd90..08d7f1847 100644
--- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanDedupTranslatorTestSuite.scala
+++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanDedupTranslatorTestSuite.scala
@@ -7,13 +7,14 @@
 package org.opensearch.flint.spark.ppl
 
 import org.opensearch.flint.spark.ppl.PlaneUtils.plan
 import org.opensearch.sql.ppl.{CatalystPlanContext, CatalystQueryPlanVisitor}
+import 
org.opensearch.sql.ppl.utils.{SortUtils, WindowSpecTransformer} import org.scalatest.matchers.should.Matchers import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation, UnresolvedStar} -import org.apache.spark.sql.catalyst.expressions.{And, IsNotNull, IsNull, NamedExpression, Or} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation} +import org.apache.spark.sql.catalyst.expressions.{Alias, And, Ascending, CurrentRow, IsNotNull, IsNull, LessThanOrEqual, Literal, NamedExpression, Or, RowFrame, RowNumber, SortOrder, SpecifiedWindowFrame, UnboundedPreceding, WindowExpression, WindowSpecDefinition} import org.apache.spark.sql.catalyst.plans.PlanTest -import org.apache.spark.sql.catalyst.plans.logical.{Deduplicate, Filter, Project, Union} +import org.apache.spark.sql.catalyst.plans.logical.{DataFrameDropColumns, Deduplicate, Filter, Project, Union, Window} class PPLLogicalPlanDedupTranslatorTestSuite extends SparkFunSuite @@ -229,40 +230,164 @@ class PPLLogicalPlanDedupTranslatorTestSuite assert(ex.getMessage === "Number of duplicate events must be greater than 0") } - // Todo - ignore("test dedup 2 a") { + test("test dedup 2 a") { val context = new CatalystPlanContext val logPlan = planTransformer.visit( plan(pplParser, "source=table | dedup 2 a | fields a", false), context) + val projectList: Seq[NamedExpression] = Seq(UnresolvedAttribute("a")) + val isNotNullFilter = + Filter(IsNotNull(UnresolvedAttribute("a")), UnresolvedRelation(Seq("table"))) + val windowExpression = WindowExpression( + RowNumber(), + WindowSpecDefinition( + UnresolvedAttribute("a") :: Nil, + SortOrder(UnresolvedAttribute("a"), Ascending) :: Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, CurrentRow))) + val rowNumberAlias = Alias(windowExpression, WindowSpecTransformer.ROW_NUMBER_COLUMN_NAME)() + val partitionSpec = projectList + val orderSpec = Seq(SortUtils.sortOrder(UnresolvedAttribute("a"), true)) + val window = Window(Seq(rowNumberAlias), partitionSpec, orderSpec, isNotNullFilter) + val deduplicateFilter = + Filter( + LessThanOrEqual( + UnresolvedAttribute(WindowSpecTransformer.ROW_NUMBER_COLUMN_NAME), + Literal(2)), + window) + val dropColumns = + DataFrameDropColumns( + Seq(UnresolvedAttribute(WindowSpecTransformer.ROW_NUMBER_COLUMN_NAME)), + deduplicateFilter) + val expectedPlan = Project(projectList, dropColumns) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) } - // Todo - ignore("test dedup 2 a, b, c") { + test("test dedup 2 a, b, c") { val context = new CatalystPlanContext val logPlan = planTransformer.visit( plan(pplParser, "source=table | dedup 2 a, b, c | fields a, b, c", false), context) + val projectList: Seq[NamedExpression] = + Seq(UnresolvedAttribute("a"), UnresolvedAttribute("b"), UnresolvedAttribute("c")) + val isNotNullFilter = Filter( + And( + And(IsNotNull(UnresolvedAttribute("a")), IsNotNull(UnresolvedAttribute("b"))), + IsNotNull(UnresolvedAttribute("c"))), + UnresolvedRelation(Seq("table"))) + val windowExpression = WindowExpression( + RowNumber(), + WindowSpecDefinition( + UnresolvedAttribute("a") :: UnresolvedAttribute("b") :: UnresolvedAttribute("c") :: Nil, + SortOrder(UnresolvedAttribute("a"), Ascending) :: SortOrder( + UnresolvedAttribute("b"), + Ascending) :: SortOrder(UnresolvedAttribute("c"), Ascending) :: Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, CurrentRow))) + val rowNumberAlias = Alias(windowExpression, 
WindowSpecTransformer.ROW_NUMBER_COLUMN_NAME)() + val partitionSpec = projectList + val orderSpec = Seq( + SortUtils.sortOrder(UnresolvedAttribute("a"), true), + SortUtils.sortOrder(UnresolvedAttribute("b"), true), + SortUtils.sortOrder(UnresolvedAttribute("c"), true)) + val window = Window(Seq(rowNumberAlias), partitionSpec, orderSpec, isNotNullFilter) + val deduplicateFilter = + Filter( + LessThanOrEqual( + UnresolvedAttribute(WindowSpecTransformer.ROW_NUMBER_COLUMN_NAME), + Literal(2)), + window) + val dropColumns = + DataFrameDropColumns( + Seq(UnresolvedAttribute(WindowSpecTransformer.ROW_NUMBER_COLUMN_NAME)), + deduplicateFilter) + val expectedPlan = Project(projectList, dropColumns) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) } - // Todo - ignore("test dedup 2 a keepempty=true") { + test("test dedup 2 a keepempty=true") { val context = new CatalystPlanContext val logPlan = planTransformer.visit( plan(pplParser, "source=table | dedup 2 a keepempty=true | fields a", false), context) + val projectList: Seq[NamedExpression] = Seq(UnresolvedAttribute("a")) + val isNotNullFilter = + Filter(IsNotNull(UnresolvedAttribute("a")), UnresolvedRelation(Seq("table"))) + val windowExpression = WindowExpression( + RowNumber(), + WindowSpecDefinition( + UnresolvedAttribute("a") :: Nil, + SortOrder(UnresolvedAttribute("a"), Ascending) :: Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, CurrentRow))) + val rowNumberAlias = Alias(windowExpression, WindowSpecTransformer.ROW_NUMBER_COLUMN_NAME)() + val partitionSpec = projectList + val orderSpec = Seq(SortUtils.sortOrder(UnresolvedAttribute("a"), true)) + val window = Window(Seq(rowNumberAlias), partitionSpec, orderSpec, isNotNullFilter) + val deduplicateFilter = + Filter( + LessThanOrEqual( + UnresolvedAttribute(WindowSpecTransformer.ROW_NUMBER_COLUMN_NAME), + Literal(2)), + window) + val dropColumns = + DataFrameDropColumns( + Seq(UnresolvedAttribute(WindowSpecTransformer.ROW_NUMBER_COLUMN_NAME)), + deduplicateFilter) + + val isNullFilter = Filter(IsNull(UnresolvedAttribute("a")), UnresolvedRelation(Seq("table"))) + val union = Union(dropColumns, isNullFilter) + val expectedPlan = Project(projectList, union) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) } - // Todo - ignore("test dedup 2 a, b, c keepempty=true") { + test("test dedup 2 a, b, c keepempty=true") { val context = new CatalystPlanContext val logPlan = planTransformer.visit( plan(pplParser, "source=table | dedup 2 a, b, c keepempty=true | fields a, b, c", false), context) + val projectList: Seq[NamedExpression] = + Seq(UnresolvedAttribute("a"), UnresolvedAttribute("b"), UnresolvedAttribute("c")) + val isNotNullFilter = Filter( + And( + And(IsNotNull(UnresolvedAttribute("a")), IsNotNull(UnresolvedAttribute("b"))), + IsNotNull(UnresolvedAttribute("c"))), + UnresolvedRelation(Seq("table"))) + val windowExpression = WindowExpression( + RowNumber(), + WindowSpecDefinition( + UnresolvedAttribute("a") :: UnresolvedAttribute("b") :: UnresolvedAttribute("c") :: Nil, + SortOrder(UnresolvedAttribute("a"), Ascending) :: SortOrder( + UnresolvedAttribute("b"), + Ascending) :: SortOrder(UnresolvedAttribute("c"), Ascending) :: Nil, + SpecifiedWindowFrame(RowFrame, UnboundedPreceding, CurrentRow))) + val rowNumberAlias = Alias(windowExpression, WindowSpecTransformer.ROW_NUMBER_COLUMN_NAME)() + val partitionSpec = projectList + val orderSpec = Seq( + SortUtils.sortOrder(UnresolvedAttribute("a"), true), + SortUtils.sortOrder(UnresolvedAttribute("b"), true), + 
SortUtils.sortOrder(UnresolvedAttribute("c"), true)) + val window = Window(Seq(rowNumberAlias), partitionSpec, orderSpec, isNotNullFilter) + val deduplicateFilter = + Filter( + LessThanOrEqual( + UnresolvedAttribute(WindowSpecTransformer.ROW_NUMBER_COLUMN_NAME), + Literal(2)), + window) + val dropColumns = + DataFrameDropColumns( + Seq(UnresolvedAttribute(WindowSpecTransformer.ROW_NUMBER_COLUMN_NAME)), + deduplicateFilter) + + val isNullFilter = Filter( + Or( + Or(IsNull(UnresolvedAttribute("a")), IsNull(UnresolvedAttribute("b"))), + IsNull(UnresolvedAttribute("c"))), + UnresolvedRelation(Seq("table"))) + val union = Union(dropColumns, isNullFilter) + val expectedPlan = Project(projectList, union) + comparePlans(expectedPlan, logPlan, checkAnalysis = false) } test("test dedup 2 a consecutive=true") { From 52da2f2ee636eff4a33a9212c8a9d5a56afd5f9a Mon Sep 17 00:00:00 2001 From: Lantao Jin Date: Fri, 9 Aug 2024 09:15:12 +0800 Subject: [PATCH 2/6] update document Signed-off-by: Lantao Jin --- ppl-spark-integration/README.md | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/ppl-spark-integration/README.md b/ppl-spark-integration/README.md index 6b8f2ac5c..005d3a9ea 100644 --- a/ppl-spark-integration/README.md +++ b/ppl-spark-integration/README.md @@ -282,9 +282,11 @@ Limitation: Overriding existing field is unsupported, following queries throw ex - `source = table | dedup 1 a,b | fields a,b,c` - `source = table | dedup 1 a keepempty=true | fields a,b,c` - `source = table | dedup 1 a,b keepempty=true | fields a,b,c` -- `source = table | dedup 1 a consecutive=true| fields a,b,c` (Unsupported) -- `source = table | dedup 2 a | fields a,b,c` (Unsupported) - +- `source = table | dedup 2 a | fields a,b,c` +- `source = table | dedup 2 a,b | fields a,b,c` +- `source = table | dedup 2 a keepempty=true | fields a,b,c` +- `source = table | dedup 2 a,b keepempty=true | fields a,b,c` +- `source = table | dedup 1 a consecutive=true| fields a,b,c` (Consecutive deduplication is unsupported) For additional details on PPL commands - view [PPL Commands Docs](https://github.com/opensearch-project/sql/blob/main/docs/user/ppl/index.rst) From 0f0bdf225b7cbdb615b67ddbcd26d6a7babd6932 Mon Sep 17 00:00:00 2001 From: Lantao Jin Date: Fri, 9 Aug 2024 10:11:12 +0800 Subject: [PATCH 3/6] remove the CatalystQueryPlanVisitor.java.orig Signed-off-by: Lantao Jin --- .../ppl/CatalystQueryPlanVisitor.java.orig | 595 ------------------ 1 file changed, 595 deletions(-) delete mode 100644 ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java.orig diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java.orig b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java.orig deleted file mode 100644 index 812cbea82..000000000 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java.orig +++ /dev/null @@ -1,595 +0,0 @@ -/* - * Copyright OpenSearch Contributors - * SPDX-License-Identifier: Apache-2.0 - */ - -package org.opensearch.sql.ppl; - -import org.apache.spark.sql.catalyst.TableIdentifier; -import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute$; -import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation; -import org.apache.spark.sql.catalyst.analysis.UnresolvedStar$; -import org.apache.spark.sql.catalyst.expressions.Expression; -import org.apache.spark.sql.catalyst.expressions.NamedExpression; -import 
org.apache.spark.sql.catalyst.expressions.Predicate; -import org.apache.spark.sql.catalyst.expressions.SortOrder; -import org.apache.spark.sql.catalyst.plans.logical.Aggregate; -import org.apache.spark.sql.catalyst.plans.logical.Deduplicate; -import org.apache.spark.sql.catalyst.plans.logical.DescribeRelation$; -import org.apache.spark.sql.catalyst.plans.logical.Limit; -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan; -import org.apache.spark.sql.catalyst.plans.logical.Union; -import org.apache.spark.sql.execution.command.DescribeTableCommand; -import org.apache.spark.sql.types.DataTypes; -import org.apache.spark.sql.util.CaseInsensitiveStringMap; -import org.opensearch.sql.ast.AbstractNodeVisitor; -import org.opensearch.sql.ast.expression.AggregateFunction; -import org.opensearch.sql.ast.expression.Alias; -import org.opensearch.sql.ast.expression.AllFields; -import org.opensearch.sql.ast.expression.And; -import org.opensearch.sql.ast.expression.Argument; -import org.opensearch.sql.ast.expression.BinaryExpression; -import org.opensearch.sql.ast.expression.Case; -import org.opensearch.sql.ast.expression.Compare; -import org.opensearch.sql.ast.expression.Field; -import org.opensearch.sql.ast.expression.FieldsMapping; -import org.opensearch.sql.ast.expression.Function; -import org.opensearch.sql.ast.expression.In; -import org.opensearch.sql.ast.expression.Interval; -import org.opensearch.sql.ast.expression.Let; -import org.opensearch.sql.ast.expression.Literal; -import org.opensearch.sql.ast.expression.Not; -import org.opensearch.sql.ast.expression.Or; -import org.opensearch.sql.ast.expression.QualifiedName; -import org.opensearch.sql.ast.expression.Span; -import org.opensearch.sql.ast.expression.UnresolvedExpression; -import org.opensearch.sql.ast.expression.WindowFunction; -import org.opensearch.sql.ast.expression.Xor; -import org.opensearch.sql.ast.statement.Explain; -import org.opensearch.sql.ast.statement.Query; -import org.opensearch.sql.ast.statement.Statement; -import org.opensearch.sql.ast.tree.Aggregation; -import org.opensearch.sql.ast.tree.Correlation; -import org.opensearch.sql.ast.tree.Dedupe; -import org.opensearch.sql.ast.tree.DescribeRelation; -import org.opensearch.sql.ast.tree.Eval; -import org.opensearch.sql.ast.tree.Filter; -import org.opensearch.sql.ast.tree.Head; -import org.opensearch.sql.ast.tree.Kmeans; -import org.opensearch.sql.ast.tree.Project; -import org.opensearch.sql.ast.tree.RareTopN; -import org.opensearch.sql.ast.tree.Relation; -import org.opensearch.sql.ast.tree.Sort; -import org.opensearch.sql.ppl.utils.AggregatorTranslator; -import org.opensearch.sql.ppl.utils.BuiltinFunctionTranslator; -import org.opensearch.sql.ppl.utils.ComparatorTransformer; -import org.opensearch.sql.ppl.utils.SortUtils; -import scala.Option; -import scala.Option$; -import scala.collection.Seq; - -import java.util.ArrayList; -import java.util.List; -import java.util.Objects; -import java.util.Optional; -import java.util.function.BiFunction; -import java.util.stream.Collectors; - -import static java.util.Collections.emptyList; -import static java.util.List.of; -import static org.opensearch.sql.ppl.CatalystPlanContext.findRelation; -import static org.opensearch.sql.ppl.utils.DataTypeTransformer.seq; -import static org.opensearch.sql.ppl.utils.DataTypeTransformer.translate; -import static org.opensearch.sql.ppl.utils.JoinSpecTransformer.join; -import static org.opensearch.sql.ppl.utils.RelationUtils.resolveField; -import static 
org.opensearch.sql.ppl.utils.WindowSpecTransformer.window; - -/** - * Utility class to traverse PPL logical plan and translate it into catalyst logical plan - */ -public class CatalystQueryPlanVisitor extends AbstractNodeVisitor { - - private final ExpressionAnalyzer expressionAnalyzer; - - public CatalystQueryPlanVisitor() { - this.expressionAnalyzer = new ExpressionAnalyzer(); - } - - public LogicalPlan visit(Statement plan, CatalystPlanContext context) { - return plan.accept(this, context); - } - - /** - * Handle Query Statement. - */ - @Override - public LogicalPlan visitQuery(Query node, CatalystPlanContext context) { - return node.getPlan().accept(this, context); - } - - @Override - public LogicalPlan visitExplain(Explain node, CatalystPlanContext context) { - return node.getStatement().accept(this, context); - } - - @Override - public LogicalPlan visitRelation(Relation node, CatalystPlanContext context) { - if (node instanceof DescribeRelation) { - TableIdentifier identifier; - if (node.getTableQualifiedName().getParts().size() == 1) { - identifier = new TableIdentifier(node.getTableQualifiedName().getParts().get(0)); - } else if (node.getTableQualifiedName().getParts().size() == 2) { - identifier = new TableIdentifier( - node.getTableQualifiedName().getParts().get(1), - Option$.MODULE$.apply(node.getTableQualifiedName().getParts().get(0))); - } else { - throw new IllegalArgumentException("Invalid table name: " + node.getTableQualifiedName() - + " Syntax: [ database_name. ] table_name"); - } - return context.with( - new DescribeTableCommand( - identifier, - scala.collection.immutable.Map$.MODULE$.empty(), - false, - DescribeRelation$.MODULE$.getOutputAttrs())); - } - //regular sql algebraic relations - node.getTableName().forEach(t -> - // Resolving the qualifiedName which is composed of a datasource.schema.table - context.with(new UnresolvedRelation(seq(of(t.split("\\."))), CaseInsensitiveStringMap.empty(), false)) - ); - return context.getPlan(); - } - - @Override - public LogicalPlan visitFilter(Filter node, CatalystPlanContext context) { - node.getChild().get(0).accept(this, context); - return context.apply(p -> { - Expression conditionExpression = visitExpression(node.getCondition(), context); - Optional innerConditionExpression = context.popNamedParseExpressions(); - return innerConditionExpression.map(expression -> new org.apache.spark.sql.catalyst.plans.logical.Filter(innerConditionExpression.get(), p)).orElse(null); - }); - } - - @Override - public LogicalPlan visitCorrelation(Correlation node, CatalystPlanContext context) { - node.getChild().get(0).accept(this, context); - context.reduce((left,right) -> { - visitFieldList(node.getFieldsList().stream().map(Field::new).collect(Collectors.toList()), context); - Seq fields = context.retainAllNamedParseExpressions(e -> e); - if(!Objects.isNull(node.getScope())) { - // scope - this is a time base expression that timeframes the join to a specific period : (Time-field-name, value, unit) - expressionAnalyzer.visitSpan(node.getScope(), context); - context.popNamedParseExpressions().get(); - } - expressionAnalyzer.visitCorrelationMapping(node.getMappingListContext(), context); - Seq mapping = context.retainAllNamedParseExpressions(e -> e); - return join(node.getCorrelationType(), fields, mapping, left, right); - }); - return context.getPlan(); - } - - @Override - public LogicalPlan visitAggregation(Aggregation node, CatalystPlanContext context) { - node.getChild().get(0).accept(this, context); - List aggsExpList = 
visitExpressionList(node.getAggExprList(), context); - List groupExpList = visitExpressionList(node.getGroupExprList(), context); - - if (!groupExpList.isEmpty()) { - //add group by fields to context - context.getGroupingParseExpressions().addAll(groupExpList); - } - - UnresolvedExpression span = node.getSpan(); - if (!Objects.isNull(span)) { - span.accept(this, context); - //add span's group alias field (most recent added expression) - context.getGroupingParseExpressions().add(context.getNamedParseExpressions().peek()); - } - // build the aggregation logical step - return extractedAggregation(context); - } - - private static LogicalPlan extractedAggregation(CatalystPlanContext context) { - Seq groupingExpression = context.retainAllGroupingNamedParseExpressions(p -> p); - Seq aggregateExpressions = context.retainAllNamedParseExpressions(p -> (NamedExpression) p); - return context.apply(p -> new Aggregate(groupingExpression, aggregateExpressions, p)); - } - - @Override - public LogicalPlan visitAlias(Alias node, CatalystPlanContext context) { - expressionAnalyzer.visitAlias(node, context); - return context.getPlan(); - } - - @Override - public LogicalPlan visitProject(Project node, CatalystPlanContext context) { - LogicalPlan child = node.getChild().get(0).accept(this, context); - List expressionList = visitExpressionList(node.getProjectList(), context); - - // Create a projection list from the existing expressions - Seq projectList = seq(context.getNamedParseExpressions()); - if (!projectList.isEmpty()) { - Seq projectExpressions = context.retainAllNamedParseExpressions(p -> (NamedExpression) p); - // build the plan with the projection step - child = context.apply(p -> new org.apache.spark.sql.catalyst.plans.logical.Project(projectExpressions, p)); - } - if (node.hasArgument()) { - Argument argument = node.getArgExprList().get(0); - //todo exclude the argument from the projected arguments list - } - return child; - } - - @Override - public LogicalPlan visitSort(Sort node, CatalystPlanContext context) { - node.getChild().get(0).accept(this, context); - visitFieldList(node.getSortList(), context); - Seq sortElements = context.retainAllNamedParseExpressions(exp -> SortUtils.getSortDirection(node, (NamedExpression) exp)); - return context.apply(p -> (LogicalPlan) new org.apache.spark.sql.catalyst.plans.logical.Sort(sortElements, true, p)); - } - - @Override - public LogicalPlan visitHead(Head node, CatalystPlanContext context) { - node.getChild().get(0).accept(this, context); - return context.apply(p -> (LogicalPlan) Limit.apply(new org.apache.spark.sql.catalyst.expressions.Literal( - node.getSize(), DataTypes.IntegerType), p)); - } - - private void visitFieldList(List fieldList, CatalystPlanContext context) { - fieldList.forEach(field -> visitExpression(field, context)); - } - - private List visitExpressionList(List expressionList, CatalystPlanContext context) { - return expressionList.isEmpty() - ? 
emptyList() - : expressionList.stream().map(field -> visitExpression(field, context)) - .collect(Collectors.toList()); - } - - private Expression visitExpression(UnresolvedExpression expression, CatalystPlanContext context) { - return expressionAnalyzer.analyze(expression, context); - } - - @Override - public LogicalPlan visitEval(Eval node, CatalystPlanContext context) { - LogicalPlan child = node.getChild().get(0).accept(this, context); - List aliases = new ArrayList<>(); - List letExpressions = node.getExpressionList(); - for(Let let : letExpressions) { - Alias alias = new Alias(let.getVar().getField().toString(), let.getExpression()); - aliases.add(alias); - } - if (context.getNamedParseExpressions().isEmpty()) { - // Create an UnresolvedStar for all-fields projection - context.getNamedParseExpressions().push(UnresolvedStar$.MODULE$.apply(Option.>empty())); - } - List expressionList = visitExpressionList(aliases, context); - Seq projectExpressions = context.retainAllNamedParseExpressions(p -> (NamedExpression) p); - // build the plan with the projection step - child = context.apply(p -> new org.apache.spark.sql.catalyst.plans.logical.Project(projectExpressions, p)); - return child; - } - - @Override - public LogicalPlan visitKmeans(Kmeans node, CatalystPlanContext context) { - throw new IllegalStateException("Not Supported operation : Kmeans"); - } - - @Override - public LogicalPlan visitIn(In node, CatalystPlanContext context) { - throw new IllegalStateException("Not Supported operation : In"); - } - - @Override - public LogicalPlan visitCase(Case node, CatalystPlanContext context) { - throw new IllegalStateException("Not Supported operation : Case"); - } - - @Override - public LogicalPlan visitRareTopN(RareTopN node, CatalystPlanContext context) { - throw new IllegalStateException("Not Supported operation : RareTopN"); - } - - @Override - public LogicalPlan visitWindowFunction(WindowFunction node, CatalystPlanContext context) { - throw new IllegalStateException("Not Supported operation : WindowFunction"); - } - - @Override - public LogicalPlan visitDedupe(Dedupe node, CatalystPlanContext context) { - node.getChild().get(0).accept(this, context); - List options = node.getOptions(); - Integer allowedDuplication = (Integer) options.get(0).getValue().getValue(); - Boolean keepEmpty = (Boolean) options.get(1).getValue().getValue(); - Boolean consecutive = (Boolean) options.get(2).getValue().getValue(); - if (allowedDuplication <= 0) { - throw new IllegalArgumentException("Number of duplicate events must be greater than 0"); - } - if (consecutive) { - // Spark is not able to remove only consecutive events - throw new UnsupportedOperationException("Consecutive deduplication is not supported"); - } - visitFieldList(node.getFields(), context); - // Columns to deduplicate - Seq dedupFields - = context.retainAllNamedParseExpressions(e -> (org.apache.spark.sql.catalyst.expressions.Attribute) e); - // Although we can also use the Window operator to translate this as allowedDuplication > 1 did, - // adding Aggregate operator could achieve better performance. 
- if (allowedDuplication == 1) { - if (keepEmpty) { - // Union - // :- Deduplicate ['a, 'b] - // : +- Filter (isnotnull('a) AND isnotnull('b) - // : +- Project - // : +- UnresolvedRelation - // +- Filter (isnull('a) OR isnull('a)) - // +- Project - // +- UnresolvedRelation - - context.apply(p -> { - Expression isNullExpr = buildIsNullFilterExpression(node, context); - LogicalPlan right = new org.apache.spark.sql.catalyst.plans.logical.Filter(isNullExpr, p); - - Expression isNotNullExpr = buildIsNotNullFilterExpression(node, context); - LogicalPlan left = - new Deduplicate(dedupFields, - new org.apache.spark.sql.catalyst.plans.logical.Filter(isNotNullExpr, p)); - return new Union(seq(left, right), false, false); - }); - return context.getPlan(); - } else { - // Deduplicate ['a, 'b] - // +- Filter (isnotnull('a) AND isnotnull('b)) - // +- Project - // +- UnresolvedRelation - - Expression isNotNullExpr = buildIsNotNullFilterExpression(node, context); - context.apply(p -> new org.apache.spark.sql.catalyst.plans.logical.Filter(isNotNullExpr, p)); - // Todo DeduplicateWithinWatermark in streaming dataset? - return context.apply(p -> new Deduplicate(dedupFields, p)); - } - } else { - // TODO - throw new UnsupportedOperationException("Number of duplicate events greater than 1 is not supported"); - } - } - - private Expression buildIsNotNullFilterExpression(Dedupe node, CatalystPlanContext context) { - visitFieldList(node.getFields(), context); - Seq isNotNullExpressions = - context.retainAllNamedParseExpressions( - org.apache.spark.sql.catalyst.expressions.IsNotNull$.MODULE$::apply); - - Expression isNotNullExpr; - if (isNotNullExpressions.size() == 1) { - isNotNullExpr = isNotNullExpressions.apply(0); - } else { - isNotNullExpr = isNotNullExpressions.reduce( - new scala.Function2() { - @Override - public Expression apply(Expression e1, Expression e2) { - return new org.apache.spark.sql.catalyst.expressions.And(e1, e2); - } - } - ); - } - return isNotNullExpr; - } - - private Expression buildIsNullFilterExpression(Dedupe node, CatalystPlanContext context) { - visitFieldList(node.getFields(), context); - Seq isNullExpressions = - context.retainAllNamedParseExpressions( - org.apache.spark.sql.catalyst.expressions.IsNull$.MODULE$::apply); - - Expression isNullExpr; - if (isNullExpressions.size() == 1) { - isNullExpr = isNullExpressions.apply(0); - } else { - isNullExpr = isNullExpressions.reduce( - new scala.Function2() { - @Override - public Expression apply(Expression e1, Expression e2) { - return new org.apache.spark.sql.catalyst.expressions.Or(e1, e2); - } - } - ); - } - return isNullExpr; - } - - /** - * Expression Analyzer. - */ - private static class ExpressionAnalyzer extends AbstractNodeVisitor { - - public Expression analyze(UnresolvedExpression unresolved, CatalystPlanContext context) { - return unresolved.accept(this, context); - } - - @Override - public Expression visitLiteral(Literal node, CatalystPlanContext context) { - return context.getNamedParseExpressions().push(new org.apache.spark.sql.catalyst.expressions.Literal( - translate(node.getValue(), node.getType()), translate(node.getType()))); - } - - /** - * generic binary (And, Or, Xor , ...) 
arithmetic expression resolver - * @param node - * @param transformer - * @param context - * @return - */ - public Expression visitBinaryArithmetic(BinaryExpression node, BiFunction transformer, CatalystPlanContext context) { - node.getLeft().accept(this, context); - Optional left = context.popNamedParseExpressions(); - node.getRight().accept(this, context); - Optional right = context.popNamedParseExpressions(); - if(left.isPresent() && right.isPresent()) { - return transformer.apply(left.get(),right.get()); - } else if(left.isPresent()) { - return context.getNamedParseExpressions().push(left.get()); - } else if(right.isPresent()) { - return context.getNamedParseExpressions().push(right.get()); - } - return null; - - } - - @Override - public Expression visitAnd(And node, CatalystPlanContext context) { - return visitBinaryArithmetic(node, - (left,right)-> context.getNamedParseExpressions().push(new org.apache.spark.sql.catalyst.expressions.And(left, right)), context); - } - - @Override - public Expression visitOr(Or node, CatalystPlanContext context) { - return visitBinaryArithmetic(node, - (left,right)-> context.getNamedParseExpressions().push(new org.apache.spark.sql.catalyst.expressions.Or(left, right)), context); - } - - @Override - public Expression visitXor(Xor node, CatalystPlanContext context) { - return visitBinaryArithmetic(node, - (left,right)-> context.getNamedParseExpressions().push(new org.apache.spark.sql.catalyst.expressions.BitwiseXor(left, right)), context); - } - - @Override - public Expression visitNot(Not node, CatalystPlanContext context) { - node.getExpression().accept(this, context); - Optional arg = context.popNamedParseExpressions(); - return arg.map(expression -> context.getNamedParseExpressions().push(new org.apache.spark.sql.catalyst.expressions.Not(expression))).orElse(null); - } - - @Override - public Expression visitSpan(Span node, CatalystPlanContext context) { - node.getField().accept(this, context); - Expression field = (Expression) context.popNamedParseExpressions().get(); - node.getValue().accept(this, context); - Expression value = (Expression) context.popNamedParseExpressions().get(); - return context.getNamedParseExpressions().push(window(field, value, node.getUnit())); - } - - @Override - public Expression visitAggregateFunction(AggregateFunction node, CatalystPlanContext context) { - node.getField().accept(this, context); - Expression arg = (Expression) context.popNamedParseExpressions().get(); - Expression aggregator = AggregatorTranslator.aggregator(node, arg); - return context.getNamedParseExpressions().push(aggregator); - } - - @Override - public Expression visitCompare(Compare node, CatalystPlanContext context) { - analyze(node.getLeft(), context); - Optional left = context.popNamedParseExpressions(); - analyze(node.getRight(), context); - Optional right = context.popNamedParseExpressions(); - if (left.isPresent() && right.isPresent()) { - Predicate comparator = ComparatorTransformer.comparator(node, left.get(), right.get()); - return context.getNamedParseExpressions().push((org.apache.spark.sql.catalyst.expressions.Expression) comparator); - } - return null; - } - - @Override - public Expression visitQualifiedName(QualifiedName node, CatalystPlanContext context) { - List relation = findRelation(context.traversalContext()); - if (!relation.isEmpty()) { - Optional resolveField = resolveField(relation, node); - return resolveField.map(qualifiedName -> 
context.getNamedParseExpressions().push(UnresolvedAttribute$.MODULE$.apply(seq(qualifiedName.getParts())))) - .orElse(null); - } - return context.getNamedParseExpressions().push(UnresolvedAttribute$.MODULE$.apply(seq(node.getParts()))); - } - - @Override - public Expression visitCorrelationMapping(FieldsMapping node, CatalystPlanContext context) { - return node.getChild().stream().map(expression -> - visitCompare((Compare) expression, context) - ).reduce(org.apache.spark.sql.catalyst.expressions.And::new).orElse(null); - } - - @Override - public Expression visitAllFields(AllFields node, CatalystPlanContext context) { - // Case of aggregation step - no start projection can be added - if (context.getNamedParseExpressions().isEmpty()) { - // Create an UnresolvedStar for all-fields projection - context.getNamedParseExpressions().push(UnresolvedStar$.MODULE$.apply(Option.>empty())); - } - return context.getNamedParseExpressions().peek(); - } - - @Override - public Expression visitAlias(Alias node, CatalystPlanContext context) { - node.getDelegated().accept(this, context); - Expression arg = context.popNamedParseExpressions().get(); - return context.getNamedParseExpressions().push( - org.apache.spark.sql.catalyst.expressions.Alias$.MODULE$.apply(arg, - node.getAlias() != null ? node.getAlias() : node.getName(), - NamedExpression.newExprId(), - seq(new java.util.ArrayList()), - Option.empty(), - seq(new java.util.ArrayList()))); - } - - @Override - public Expression visitEval(Eval node, CatalystPlanContext context) { - throw new IllegalStateException("Not Supported operation : Eval"); - } - - @Override - public Expression visitFunction(Function node, CatalystPlanContext context) { - List arguments = - node.getFuncArgs().stream() - .map( - unresolvedExpression -> { - var ret = analyze(unresolvedExpression, context); - if (ret == null) { - throw new UnsupportedOperationException( - String.format("Invalid use of expression %s", unresolvedExpression)); - } else { - return context.popNamedParseExpressions().get(); - } - }) - .collect(Collectors.toList()); - Expression function = BuiltinFunctionTranslator.builtinFunction(node, arguments); - return context.getNamedParseExpressions().push(function); - } - - @Override - public Expression visitInterval(Interval node, CatalystPlanContext context) { - throw new IllegalStateException("Not Supported operation : Interval"); - } - - @Override - public Expression visitDedupe(Dedupe node, CatalystPlanContext context) { - throw new IllegalStateException("Not Supported operation : Dedupe"); - } - - @Override - public Expression visitIn(In node, CatalystPlanContext context) { - throw new IllegalStateException("Not Supported operation : In"); - } - - @Override - public Expression visitKmeans(Kmeans node, CatalystPlanContext context) { - throw new IllegalStateException("Not Supported operation : Kmeans"); - } - - @Override - public Expression visitCase(Case node, CatalystPlanContext context) { - throw new IllegalStateException("Not Supported operation : Case"); - } - - @Override - public Expression visitRareTopN(RareTopN node, CatalystPlanContext context) { - throw new IllegalStateException("Not Supported operation : RareTopN"); - } - - @Override - public Expression visitWindowFunction(WindowFunction node, CatalystPlanContext context) { - throw new IllegalStateException("Not Supported operation : WindowFunction"); - } - } -} From bfa6e0148ae933eae578b2593210864af0335225 Mon Sep 17 00:00:00 2001 From: Lantao Jin Date: Tue, 27 Aug 2024 14:37:18 +0800 Subject: [PATCH 4/6] 
refactor Signed-off-by: Lantao Jin --- ...scala => FlintSparkPPLDedupeITSuite.scala} | 2 +- .../sql/ppl/CatalystQueryPlanVisitor.java | 169 +------------- .../sql/ppl/utils/DedupeTransformer.java | 211 ++++++++++++++++++ ...ogicalPlanDedupeTranslatorTestSuite.scala} | 2 +- 4 files changed, 223 insertions(+), 161 deletions(-) rename integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/{FlintSparkPPLDedupITSuite.scala => FlintSparkPPLDedupeITSuite.scala} (99%) create mode 100644 ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/DedupeTransformer.java rename ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/{PPLLogicalPlanDedupTranslatorTestSuite.scala => PPLLogicalPlanDedupeTranslatorTestSuite.scala} (99%) diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLDedupITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLDedupeITSuite.scala similarity index 99% rename from integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLDedupITSuite.scala rename to integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLDedupeITSuite.scala index 8ce078176..3270a7fd5 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLDedupITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLDedupeITSuite.scala @@ -11,7 +11,7 @@ import org.apache.spark.sql.catalyst.expressions.{And, IsNotNull, IsNull, Or} import org.apache.spark.sql.catalyst.plans.logical.{Deduplicate, Filter, LogicalPlan, Project, Union} import org.apache.spark.sql.streaming.StreamTest -class FlintSparkPPLDedupITSuite +class FlintSparkPPLDedupeITSuite extends QueryTest with LogicalPlanTestUtils with FlintPPLSuite diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java index 294ac5226..05f46ab42 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java @@ -10,17 +10,13 @@ import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation; import org.apache.spark.sql.catalyst.analysis.UnresolvedStar$; import org.apache.spark.sql.catalyst.expressions.Expression; -import org.apache.spark.sql.catalyst.expressions.LessThanOrEqual; import org.apache.spark.sql.catalyst.expressions.NamedExpression; import org.apache.spark.sql.catalyst.expressions.Predicate; import org.apache.spark.sql.catalyst.expressions.SortOrder; import org.apache.spark.sql.catalyst.plans.logical.Aggregate; -import org.apache.spark.sql.catalyst.plans.logical.DataFrameDropColumns; -import org.apache.spark.sql.catalyst.plans.logical.Deduplicate; import org.apache.spark.sql.catalyst.plans.logical.DescribeRelation$; import org.apache.spark.sql.catalyst.plans.logical.Limit; import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan; -import org.apache.spark.sql.catalyst.plans.logical.Union; import org.apache.spark.sql.execution.command.DescribeTableCommand; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.util.CaseInsensitiveStringMap; @@ -66,7 +62,6 @@ import org.opensearch.sql.ppl.utils.BuiltinFunctionTranslator; import org.opensearch.sql.ppl.utils.ComparatorTransformer; import org.opensearch.sql.ppl.utils.SortUtils; -import 
org.opensearch.sql.ppl.utils.WindowSpecTransformer; import scala.Option; import scala.Option$; import scala.collection.Seq; @@ -83,6 +78,10 @@ import static org.opensearch.sql.ppl.CatalystPlanContext.findRelation; import static org.opensearch.sql.ppl.utils.DataTypeTransformer.seq; import static org.opensearch.sql.ppl.utils.DataTypeTransformer.translate; +import static org.opensearch.sql.ppl.utils.DedupeTransformer.retainMultipleDuplicateEvents; +import static org.opensearch.sql.ppl.utils.DedupeTransformer.retainMultipleDuplicateEventsAndKeepEmpty; +import static org.opensearch.sql.ppl.utils.DedupeTransformer.retainOneDuplicateEvent; +import static org.opensearch.sql.ppl.utils.DedupeTransformer.retainOneDuplicateEventAndKeepEmpty; import static org.opensearch.sql.ppl.utils.JoinSpecTransformer.join; import static org.opensearch.sql.ppl.utils.RelationUtils.resolveField; import static org.opensearch.sql.ppl.utils.WindowSpecTransformer.window; @@ -315,177 +314,29 @@ public LogicalPlan visitDedupe(Dedupe node, CatalystPlanContext context) { } visitFieldList(node.getFields(), context); // Columns to deduplicate - Seq dedupFields + Seq dedupeFields = context.retainAllNamedParseExpressions(e -> (org.apache.spark.sql.catalyst.expressions.Attribute) e); // Although we can also use the Window operator to translate this as allowedDuplication > 1 did, // adding Aggregate operator could achieve better performance. if (allowedDuplication == 1) { if (keepEmpty) { - // | dedup a, b keepempty=true - // Union - // :- Deduplicate ['a, 'b] - // : +- Filter (isnotnull('a) AND isnotnull('b) - // : +- ... - // : +- UnresolvedRelation - // +- Filter (isnull('a) OR isnull('a)) - // +- ... - // +- UnresolvedRelation - - context.apply(p -> { - Expression isNullExpr = buildIsNullFilterExpression(node, context); - LogicalPlan right = new org.apache.spark.sql.catalyst.plans.logical.Filter(isNullExpr, p); - - Expression isNotNullExpr = buildIsNotNullFilterExpression(node, context); - LogicalPlan left = - new Deduplicate(dedupFields, - new org.apache.spark.sql.catalyst.plans.logical.Filter(isNotNullExpr, p)); - return new Union(seq(left, right), false, false); - }); - return context.getPlan(); + return retainOneDuplicateEventAndKeepEmpty(node, dedupeFields, expressionAnalyzer, context); } else { - // | dedup a, b keepempty=false - // Deduplicate ['a, 'b] - // +- Filter (isnotnull('a) AND isnotnull('b)) - // +- ... - // +- UnresolvedRelation - - Expression isNotNullExpr = buildIsNotNullFilterExpression(node, context); - context.apply(p -> new org.apache.spark.sql.catalyst.plans.logical.Filter(isNotNullExpr, p)); - // Todo DeduplicateWithinWatermark in streaming dataset? - return context.apply(p -> new Deduplicate(dedupFields, p)); + return retainOneDuplicateEvent(node, dedupeFields, expressionAnalyzer, context); } } else { if (keepEmpty) { - // | dedup 2 a, b keepempty=true - // Union - //:- DataFrameDropColumns('_row_number_) - //: +- Filter ('_row_number_ <= 2) - //: +- Window [row_number() windowspecdefinition('a, 'b, 'a ASC NULLS FIRST, 'b ASC NULLS FIRST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS _row_number_], ['a, 'b], ['a ASC NULLS FIRST, 'b ASC NULLS FIRST] - //: +- Filter (isnotnull('a) AND isnotnull('b)) - //: +- ... - //: +- UnresolvedRelation - //+- Filter (isnull('a) OR isnull('b)) - // +- ... 
- // +- UnresolvedRelation - - context.apply(p -> { - // Build isnull Filter for right - Expression isNullExpr = buildIsNullFilterExpression(node, context); - LogicalPlan right = new org.apache.spark.sql.catalyst.plans.logical.Filter(isNullExpr, p); - - // Build isnotnull Filter - Expression isNotNullExpr = buildIsNotNullFilterExpression(node, context); - LogicalPlan isNotNullFilter = new org.apache.spark.sql.catalyst.plans.logical.Filter(isNotNullExpr, p); - - // Build Window - visitFieldList(node.getFields(), context); - Seq partitionSpec = context.retainAllNamedParseExpressions(exp -> exp); - visitFieldList(node.getFields(), context); - Seq orderSpec = context.retainAllNamedParseExpressions(exp -> SortUtils.sortOrder(exp, true)); - NamedExpression rowNumber = WindowSpecTransformer.buildRowNumber(partitionSpec, orderSpec); - LogicalPlan window = new org.apache.spark.sql.catalyst.plans.logical.Window( - seq(rowNumber), - partitionSpec, - orderSpec, - isNotNullFilter); - - // Build deduplication Filter ('_row_number_ <= n) - Expression filterExpr = new LessThanOrEqual( - rowNumber.toAttribute(), - new org.apache.spark.sql.catalyst.expressions.Literal(allowedDuplication, DataTypes.IntegerType)); - LogicalPlan deduplicationFilter = new org.apache.spark.sql.catalyst.plans.logical.Filter(filterExpr, window); - - // Build DataFrameDropColumns('_row_number_) for left - LogicalPlan left = new DataFrameDropColumns(seq(rowNumber.toAttribute()), deduplicationFilter); - - // Build Union - return new Union(seq(left, right), false, false); - }); - return context.getPlan(); + return retainMultipleDuplicateEventsAndKeepEmpty(node, allowedDuplication, expressionAnalyzer, context); } else { - // | dedup 2 a, b keepempty=false - // DataFrameDropColumns('row_number_col) - // +- Filter ('_row_number_ <= n) - // +- Window [row_number() windowspecdefinition('a, 'b, 'a ASC NULLS FIRST, 'b ASC NULLS FIRST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS _row_number_], ['a, 'b], ['a ASC NULLS FIRST, 'b ASC NULLS FIRST] - // +- Filter (isnotnull('a) AND isnotnull('b)) - // +- ... 
- // +- UnresolvedRelation - - // Build isnotnull Filter - Expression isNotNullExpr = buildIsNotNullFilterExpression(node, context); - context.apply(p -> new org.apache.spark.sql.catalyst.plans.logical.Filter(isNotNullExpr, p)); - - // Build Window - visitFieldList(node.getFields(), context); - Seq partitionSpec = context.retainAllNamedParseExpressions(exp -> exp); - visitFieldList(node.getFields(), context); - Seq orderSpec = context.retainAllNamedParseExpressions(exp -> SortUtils.sortOrder(exp, true)); - NamedExpression rowNumber = WindowSpecTransformer.buildRowNumber(partitionSpec, orderSpec); - context.apply(p -> new org.apache.spark.sql.catalyst.plans.logical.Window( - seq(rowNumber), - partitionSpec, - orderSpec, p)); - - // Build deduplication Filter ('_row_number_ <= n) - Expression filterExpr = new LessThanOrEqual( - rowNumber.toAttribute(), - new org.apache.spark.sql.catalyst.expressions.Literal(allowedDuplication, DataTypes.IntegerType)); - context.apply(p -> new org.apache.spark.sql.catalyst.plans.logical.Filter(filterExpr, p)); - - // Build DataFrameDropColumns('_row_number_) Spark 3.5.1+ required - return context.apply(p -> new DataFrameDropColumns(seq(rowNumber.toAttribute()), p)); + return retainMultipleDuplicateEvents(node, allowedDuplication, expressionAnalyzer, context); } } } - private Expression buildIsNotNullFilterExpression(Dedupe node, CatalystPlanContext context) { - visitFieldList(node.getFields(), context); - Seq isNotNullExpressions = - context.retainAllNamedParseExpressions( - org.apache.spark.sql.catalyst.expressions.IsNotNull$.MODULE$::apply); - - Expression isNotNullExpr; - if (isNotNullExpressions.size() == 1) { - isNotNullExpr = isNotNullExpressions.apply(0); - } else { - isNotNullExpr = isNotNullExpressions.reduce( - new scala.Function2() { - @Override - public Expression apply(Expression e1, Expression e2) { - return new org.apache.spark.sql.catalyst.expressions.And(e1, e2); - } - } - ); - } - return isNotNullExpr; - } - - private Expression buildIsNullFilterExpression(Dedupe node, CatalystPlanContext context) { - visitFieldList(node.getFields(), context); - Seq isNullExpressions = - context.retainAllNamedParseExpressions( - org.apache.spark.sql.catalyst.expressions.IsNull$.MODULE$::apply); - - Expression isNullExpr; - if (isNullExpressions.size() == 1) { - isNullExpr = isNullExpressions.apply(0); - } else { - isNullExpr = isNullExpressions.reduce( - new scala.Function2() { - @Override - public Expression apply(Expression e1, Expression e2) { - return new org.apache.spark.sql.catalyst.expressions.Or(e1, e2); - } - } - ); - } - return isNullExpr; - } - /** * Expression Analyzer. 
*/ - private static class ExpressionAnalyzer extends AbstractNodeVisitor { + public static class ExpressionAnalyzer extends AbstractNodeVisitor { public Expression analyze(UnresolvedExpression unresolved, CatalystPlanContext context) { return unresolved.accept(this, context); diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/DedupeTransformer.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/DedupeTransformer.java new file mode 100644 index 000000000..bda78c548 --- /dev/null +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/DedupeTransformer.java @@ -0,0 +1,211 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.ppl.utils; + +import org.apache.spark.sql.catalyst.expressions.Attribute; +import org.apache.spark.sql.catalyst.expressions.Expression; +import org.apache.spark.sql.catalyst.expressions.LessThanOrEqual; +import org.apache.spark.sql.catalyst.expressions.NamedExpression; +import org.apache.spark.sql.catalyst.expressions.SortOrder; +import org.apache.spark.sql.catalyst.plans.logical.DataFrameDropColumns; +import org.apache.spark.sql.catalyst.plans.logical.Deduplicate; +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan; +import org.apache.spark.sql.catalyst.plans.logical.Union; +import org.apache.spark.sql.types.DataTypes; +import org.opensearch.sql.ast.expression.Field; +import org.opensearch.sql.ast.expression.UnresolvedExpression; +import org.opensearch.sql.ast.tree.Dedupe; +import org.opensearch.sql.ppl.CatalystPlanContext; +import org.opensearch.sql.ppl.CatalystQueryPlanVisitor.ExpressionAnalyzer; +import scala.collection.Seq; + +import java.util.List; + +import static org.opensearch.sql.ppl.utils.DataTypeTransformer.seq; + +public interface DedupeTransformer { + + /** + * | dedup a, b keepempty=true + * Union + * :- Deduplicate ['a, 'b] + * : +- Filter (isnotnull('a) AND isnotnull('b)) + * : +- ... + * : +- UnresolvedRelation + * +- Filter (isnull('a) OR isnull('b)) + * +- ... + * +- UnresolvedRelation + */ + static LogicalPlan retainOneDuplicateEventAndKeepEmpty( + Dedupe node, + Seq dedupeFields, + ExpressionAnalyzer expressionAnalyzer, + CatalystPlanContext context) { + context.apply(p -> { + Expression isNullExpr = buildIsNullFilterExpression(node, expressionAnalyzer, context); + LogicalPlan right = new org.apache.spark.sql.catalyst.plans.logical.Filter(isNullExpr, p); + + Expression isNotNullExpr = buildIsNotNullFilterExpression(node, expressionAnalyzer, context); + LogicalPlan left = + new Deduplicate(dedupeFields, + new org.apache.spark.sql.catalyst.plans.logical.Filter(isNotNullExpr, p)); + return new Union(seq(left, right), false, false); + }); + return context.getPlan(); + } + + /** + * | dedup a, b keepempty=false + * Deduplicate ['a, 'b] + * +- Filter (isnotnull('a) AND isnotnull('b)) + * +- ... + * +- UnresolvedRelation + */ + static LogicalPlan retainOneDuplicateEvent( + Dedupe node, + Seq dedupeFields, + ExpressionAnalyzer expressionAnalyzer, + CatalystPlanContext context) { + Expression isNotNullExpr = buildIsNotNullFilterExpression(node, expressionAnalyzer, context); + context.apply(p -> new org.apache.spark.sql.catalyst.plans.logical.Filter(isNotNullExpr, p)); + // Todo DeduplicateWithinWatermark in streaming dataset?
+ return context.apply(p -> new Deduplicate(dedupeFields, p)); + } + + /** + * | dedup 2 a, b keepempty=true + * Union + * :- DataFrameDropColumns('_row_number_) + * : +- Filter ('_row_number_ <= 2) + * : +- Window [row_number() windowspecdefinition('a, 'b, 'a ASC NULLS FIRST, 'b ASC NULLS FIRST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS _row_number_], ['a, 'b], ['a ASC NULLS FIRST, 'b ASC NULLS FIRST] + * : +- Filter (isnotnull('a) AND isnotnull('b)) + * : +- ... + * : +- UnresolvedRelation + * +- Filter (isnull('a) OR isnull('b)) + * +- ... + * +- UnresolvedRelation + */ + static LogicalPlan retainMultipleDuplicateEventsAndKeepEmpty( + Dedupe node, + Integer allowedDuplication, + ExpressionAnalyzer expressionAnalyzer, + CatalystPlanContext context) { + context.apply(p -> { + // Build isnull Filter for right + Expression isNullExpr = buildIsNullFilterExpression(node, expressionAnalyzer, context); + LogicalPlan right = new org.apache.spark.sql.catalyst.plans.logical.Filter(isNullExpr, p); + + // Build isnotnull Filter + Expression isNotNullExpr = buildIsNotNullFilterExpression(node, expressionAnalyzer, context); + LogicalPlan isNotNullFilter = new org.apache.spark.sql.catalyst.plans.logical.Filter(isNotNullExpr, p); + + // Build Window + visitFieldList(node.getFields(), expressionAnalyzer, context); + Seq partitionSpec = context.retainAllNamedParseExpressions(exp -> exp); + visitFieldList(node.getFields(), expressionAnalyzer, context); + Seq orderSpec = context.retainAllNamedParseExpressions(exp -> SortUtils.sortOrder(exp, true)); + NamedExpression rowNumber = WindowSpecTransformer.buildRowNumber(partitionSpec, orderSpec); + LogicalPlan window = new org.apache.spark.sql.catalyst.plans.logical.Window( + seq(rowNumber), + partitionSpec, + orderSpec, + isNotNullFilter); + + // Build deduplication Filter ('_row_number_ <= n) + Expression filterExpr = new LessThanOrEqual( + rowNumber.toAttribute(), + new org.apache.spark.sql.catalyst.expressions.Literal(allowedDuplication, DataTypes.IntegerType)); + LogicalPlan deduplicationFilter = new org.apache.spark.sql.catalyst.plans.logical.Filter(filterExpr, window); + + // Build DataFrameDropColumns('_row_number_) for left + LogicalPlan left = new DataFrameDropColumns(seq(rowNumber.toAttribute()), deduplicationFilter); + + // Build Union + return new Union(seq(left, right), false, false); + }); + return context.getPlan(); + } + + /** + * | dedup 2 a, b keepempty=false + * DataFrameDropColumns('_row_number_) + * +- Filter ('_row_number_ <= n) + * +- Window [row_number() windowspecdefinition('a, 'b, 'a ASC NULLS FIRST, 'b ASC NULLS FIRST, specifiedwindowframe(RowFrame, unboundedpreceding$(), currentrow$())) AS _row_number_], ['a, 'b], ['a ASC NULLS FIRST, 'b ASC NULLS FIRST] + * +- Filter (isnotnull('a) AND isnotnull('b)) + * +- ...
+ * +- UnresolvedRelation + */ + static LogicalPlan retainMultipleDuplicateEvents( + Dedupe node, + Integer allowedDuplication, + ExpressionAnalyzer expressionAnalyzer, + CatalystPlanContext context) { + // Build isnotnull Filter + Expression isNotNullExpr = buildIsNotNullFilterExpression(node, expressionAnalyzer, context); + context.apply(p -> new org.apache.spark.sql.catalyst.plans.logical.Filter(isNotNullExpr, p)); + + // Build Window + visitFieldList(node.getFields(), expressionAnalyzer, context); + Seq partitionSpec = context.retainAllNamedParseExpressions(exp -> exp); + visitFieldList(node.getFields(), expressionAnalyzer ,context); + Seq orderSpec = context.retainAllNamedParseExpressions(exp -> SortUtils.sortOrder(exp, true)); + NamedExpression rowNumber = WindowSpecTransformer.buildRowNumber(partitionSpec, orderSpec); + context.apply(p -> new org.apache.spark.sql.catalyst.plans.logical.Window( + seq(rowNumber), + partitionSpec, + orderSpec, p)); + + // Build deduplication Filter ('_row_number_ <= n) + Expression filterExpr = new LessThanOrEqual( + rowNumber.toAttribute(), + new org.apache.spark.sql.catalyst.expressions.Literal(allowedDuplication, DataTypes.IntegerType)); + context.apply(p -> new org.apache.spark.sql.catalyst.plans.logical.Filter(filterExpr, p)); + + return context.apply(p -> new DataFrameDropColumns(seq(rowNumber.toAttribute()), p)); + } + + static Expression buildIsNotNullFilterExpression(Dedupe node, ExpressionAnalyzer expressionAnalyzer, CatalystPlanContext context) { + visitFieldList(node.getFields(), expressionAnalyzer, context); + Seq isNotNullExpressions = + context.retainAllNamedParseExpressions( + org.apache.spark.sql.catalyst.expressions.IsNotNull$.MODULE$::apply); + + Expression isNotNullExpr; + if (isNotNullExpressions.size() == 1) { + isNotNullExpr = isNotNullExpressions.apply(0); + } else { + isNotNullExpr = isNotNullExpressions.reduce( + (e1, e2) -> new org.apache.spark.sql.catalyst.expressions.And(e1, e2) + ); + } + return isNotNullExpr; + } + + private static Expression buildIsNullFilterExpression(Dedupe node, ExpressionAnalyzer expressionAnalyzer, CatalystPlanContext context) { + visitFieldList(node.getFields(), expressionAnalyzer, context); + Seq isNullExpressions = + context.retainAllNamedParseExpressions( + org.apache.spark.sql.catalyst.expressions.IsNull$.MODULE$::apply); + + Expression isNullExpr; + if (isNullExpressions.size() == 1) { + isNullExpr = isNullExpressions.apply(0); + } else { + isNullExpr = isNullExpressions.reduce( + (e1, e2) -> new org.apache.spark.sql.catalyst.expressions.Or(e1, e2) + ); + } + return isNullExpr; + } + + static void visitFieldList(List fieldList, ExpressionAnalyzer expressionAnalyzer, CatalystPlanContext context) { + fieldList.forEach(field -> visitExpression(field, expressionAnalyzer, context)); + } + + static Expression visitExpression(UnresolvedExpression expression, ExpressionAnalyzer expressionAnalyzer, CatalystPlanContext context) { + return expressionAnalyzer.analyze(expression, context); + } +} diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanDedupTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanDedupeTranslatorTestSuite.scala similarity index 99% rename from ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanDedupTranslatorTestSuite.scala rename to ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanDedupeTranslatorTestSuite.scala index 
08d7f1847..23222c2e3 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanDedupTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanDedupeTranslatorTestSuite.scala @@ -16,7 +16,7 @@ import org.apache.spark.sql.catalyst.expressions.{Alias, And, Ascending, Current import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{DataFrameDropColumns, Deduplicate, Filter, Project, Union, Window} -class PPLLogicalPlanDedupTranslatorTestSuite +class PPLLogicalPlanDedupeTranslatorTestSuite extends SparkFunSuite with PlanTest with LogicalPlanTestUtils From ec0efe0dc27c0717d3fa47801b5720d88b4e460c Mon Sep 17 00:00:00 2001 From: Lantao Jin Date: Tue, 27 Aug 2024 14:59:11 +0800 Subject: [PATCH 5/6] add IT for reordering field list Signed-off-by: Lantao Jin --- .../ppl/FlintSparkPPLDedupeITSuite.scala | 36 +++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLDedupeITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLDedupeITSuite.scala index 3270a7fd5..96b459769 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLDedupeITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLDedupeITSuite.scala @@ -307,4 +307,40 @@ class FlintSparkPPLDedupeITSuite | """.stripMargin)) assert(ex.getMessage.contains("Consecutive deduplication is not supported")) } + + test("test dedupe 1 category, name - reorder field list won't impact output order") { + val frame1 = sql(s""" + | source = $testTable | dedup 1 name, category + | """.stripMargin) + + val results1: Array[Row] = frame1.drop("id").collect() + + val frame2 = sql(s""" + | source = $testTable | dedup 1 category, name + | """.stripMargin) + + val results2: Array[Row] = frame2.drop("id").collect() + implicit val twoColsRowOrdering: Ordering[Row] = + Ordering.by[Row, (String, String)](row => (row.getAs(0), row.getAs(1))) + assert(results1.sorted.sameElements(results2.sorted)) + } + + test("test dedupe 2 category, name - reorder field list won't impact output order") { + val frame1 = sql(s""" + | source = $testTable | dedup 2 name, category + | """.stripMargin) + + val results1: Array[Row] = frame1.drop("id").collect() + + val frame2 = sql(s""" + | source = $testTable | dedup 2 category, name + | """.stripMargin) + + val results2: Array[Row] = frame2.drop("id").collect() + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](row => { + val value = row.getAs[String](0) + if (value == null) String.valueOf(Int.MaxValue) else value + }) + assert(results1.sorted.sameElements(results2.sorted)) + } } From 9ef0b3cfcf74dd817ef258dbd51b1330ef004ef2 Mon Sep 17 00:00:00 2001 From: Lantao Jin Date: Tue, 27 Aug 2024 15:25:04 +0800 Subject: [PATCH 6/6] remove useless code Signed-off-by: Lantao Jin --- .../ppl/FlintSparkPPLDedupeITSuite.scala | 116 ++++++++++++++++-- .../sql/ppl/CatalystQueryPlanVisitor.java | 2 - .../sql/ppl/utils/DedupeTransformer.java | 26 ++-- 3 files changed, 115 insertions(+), 29 deletions(-) diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLDedupeITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLDedupeITSuite.scala index 96b459769..2f59b6fba 100644 --- 
a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLDedupeITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLDedupeITSuite.scala @@ -309,38 +309,138 @@ class FlintSparkPPLDedupeITSuite } test("test dedupe 1 category, name - reorder field list won't impact output order") { + val expectedResults: Array[Row] = Array( + Row("A", "X"), + Row("A", "Y"), + Row("B", "Z"), + Row("C", "X"), + Row("D", "Z"), + Row("B", "Y")) + implicit val twoColsRowOrdering: Ordering[Row] = + Ordering.by[Row, (String, String)](row => (row.getAs(0), row.getAs(1))) + val frame1 = sql(s""" | source = $testTable | dedup 1 name, category | """.stripMargin) - val results1: Array[Row] = frame1.drop("id").collect() val frame2 = sql(s""" | source = $testTable | dedup 1 category, name | """.stripMargin) + val results2: Array[Row] = frame2.drop("id").collect() + + assert(results1.sorted.sameElements(results2.sorted)) + assert(results1.sorted.sameElements(expectedResults.sorted)) + } + test( + "test dedupe 1 category, name KEEPEMPTY=true - reorder field list won't impact output order") { + val expectedResults: Array[Row] = Array( + Row("A", "X"), + Row("A", "Y"), + Row("B", "Z"), + Row("C", "X"), + Row("D", "Z"), + Row("B", "Y"), + Row(null, "Y"), + Row("E", null), + Row(null, "X"), + Row("B", null), + Row(null, "Z"), + Row(null, null)) + implicit val nullableTwoColsRowOrdering: Ordering[Row] = + Ordering.by[Row, (String, String)](row => { + val value0 = row.getAs[String](0) + val value1 = row.getAs[String](1) + ( + if (value0 == null) String.valueOf(Int.MaxValue) else value0, + if (value1 == null) String.valueOf(Int.MaxValue) else value1) + }) + + val frame1 = sql(s""" + | source = $testTable | dedup 1 name, category KEEPEMPTY=true + | """.stripMargin) + val results1: Array[Row] = frame1.drop("id").collect() + + val frame2 = sql(s""" + | source = $testTable | dedup 1 category, name KEEPEMPTY=true + | """.stripMargin) val results2: Array[Row] = frame2.drop("id").collect() - implicit val twoColsRowOrdering: Ordering[Row] = - Ordering.by[Row, (String, String)](row => (row.getAs(0), row.getAs(1))) + assert(results1.sorted.sameElements(results2.sorted)) + assert(results1.sorted.sameElements(expectedResults.sorted)) } test("test dedupe 2 category, name - reorder field list won't impact output order") { + val expectedResults: Array[Row] = Array( + Row("A", "X"), + Row("A", "X"), + Row("A", "Y"), + Row("A", "Y"), + Row("B", "Y"), + Row("B", "Z"), + Row("B", "Z"), + Row("C", "X"), + Row("C", "X"), + Row("D", "Z")) + implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](row => { + val value = row.getAs[String](0) + if (value == null) String.valueOf(Int.MaxValue) else value + }) + val frame1 = sql(s""" | source = $testTable | dedup 2 name, category | """.stripMargin) - val results1: Array[Row] = frame1.drop("id").collect() val frame2 = sql(s""" | source = $testTable | dedup 2 category, name | """.stripMargin) + val results2: Array[Row] = frame2.drop("id").collect() + assert(results1.sorted.sameElements(results2.sorted)) + assert(results1.sorted.sameElements(expectedResults.sorted)) + } + + test( + "test dedupe 2 category, name KEEPEMPTY=true - reorder field list won't impact output order") { + val expectedResults: Array[Row] = Array( + Row("A", "X"), + Row("A", "X"), + Row("A", "Y"), + Row("A", "Y"), + Row("B", "Y"), + Row("B", "Z"), + Row("B", "Z"), + Row("C", "X"), + Row("C", "X"), + Row("D", "Z"), + Row(null, "Y"), + Row("E", null), + Row(null, "X"), 
+ Row("B", null), + Row(null, "Z"), + Row(null, null)) + implicit val nullableTwoColsRowOrdering: Ordering[Row] = + Ordering.by[Row, (String, String)](row => { + val value0 = row.getAs[String](0) + val value1 = row.getAs[String](1) + ( + if (value0 == null) String.valueOf(Int.MaxValue) else value0, + if (value1 == null) String.valueOf(Int.MaxValue) else value1) + }) + + val frame1 = sql(s""" + | source = $testTable | dedup 2 name, category KEEPEMPTY=true + | """.stripMargin) + val results1: Array[Row] = frame1.drop("id").collect() + + val frame2 = sql(s""" + | source = $testTable | dedup 2 category, name KEEPEMPTY=true + | """.stripMargin) val results2: Array[Row] = frame2.drop("id").collect() - implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, String](row => { - val value = row.getAs[String](0) - if (value == null) String.valueOf(Int.MaxValue) else value - }) + assert(results1.sorted.sameElements(results2.sorted)) + assert(results1.sorted.sameElements(expectedResults.sorted)) } } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java index 7ca726018..73c8677e8 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/CatalystQueryPlanVisitor.java @@ -22,11 +22,9 @@ import org.apache.spark.sql.catalyst.expressions.StringRegexExpression; import org.apache.spark.sql.catalyst.plans.logical.Aggregate; import org.apache.spark.sql.catalyst.plans.logical.DescribeRelation$; -import org.apache.spark.sql.catalyst.plans.logical.Deduplicate; import org.apache.spark.sql.catalyst.plans.logical.Limit; import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan; import org.apache.spark.sql.execution.command.DescribeTableCommand; -import org.apache.spark.sql.catalyst.plans.logical.Union; import org.apache.spark.sql.types.DataTypes; import org.apache.spark.sql.types.Metadata; import org.apache.spark.sql.util.CaseInsensitiveStringMap; diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/DedupeTransformer.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/DedupeTransformer.java index bda78c548..0866ca7e9 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/DedupeTransformer.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/utils/DedupeTransformer.java @@ -15,15 +15,11 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan; import org.apache.spark.sql.catalyst.plans.logical.Union; import org.apache.spark.sql.types.DataTypes; -import org.opensearch.sql.ast.expression.Field; -import org.opensearch.sql.ast.expression.UnresolvedExpression; import org.opensearch.sql.ast.tree.Dedupe; import org.opensearch.sql.ppl.CatalystPlanContext; import org.opensearch.sql.ppl.CatalystQueryPlanVisitor.ExpressionAnalyzer; import scala.collection.Seq; -import java.util.List; - import static org.opensearch.sql.ppl.utils.DataTypeTransformer.seq; public interface DedupeTransformer { @@ -103,9 +99,9 @@ static LogicalPlan retainMultipleDuplicateEventsAndKeepEmpty( LogicalPlan isNotNullFilter = new org.apache.spark.sql.catalyst.plans.logical.Filter(isNotNullExpr, p); // Build Window - visitFieldList(node.getFields(), expressionAnalyzer, context); + node.getFields().forEach(field -> expressionAnalyzer.analyze(field, context)); Seq partitionSpec = 
context.retainAllNamedParseExpressions(exp -> exp); - visitFieldList(node.getFields(), expressionAnalyzer, context); + node.getFields().forEach(field -> expressionAnalyzer.analyze(field, context)); Seq orderSpec = context.retainAllNamedParseExpressions(exp -> SortUtils.sortOrder(exp, true)); NamedExpression rowNumber = WindowSpecTransformer.buildRowNumber(partitionSpec, orderSpec); LogicalPlan window = new org.apache.spark.sql.catalyst.plans.logical.Window( @@ -148,9 +144,9 @@ static LogicalPlan retainMultipleDuplicateEvents( context.apply(p -> new org.apache.spark.sql.catalyst.plans.logical.Filter(isNotNullExpr, p)); // Build Window - visitFieldList(node.getFields(), expressionAnalyzer, context); + node.getFields().forEach(field -> expressionAnalyzer.analyze(field, context)); Seq partitionSpec = context.retainAllNamedParseExpressions(exp -> exp); - visitFieldList(node.getFields(), expressionAnalyzer ,context); + node.getFields().forEach(field -> expressionAnalyzer.analyze(field, context)); Seq orderSpec = context.retainAllNamedParseExpressions(exp -> SortUtils.sortOrder(exp, true)); NamedExpression rowNumber = WindowSpecTransformer.buildRowNumber(partitionSpec, orderSpec); context.apply(p -> new org.apache.spark.sql.catalyst.plans.logical.Window( @@ -167,8 +163,8 @@ static LogicalPlan retainMultipleDuplicateEvents( return context.apply(p -> new DataFrameDropColumns(seq(rowNumber.toAttribute()), p)); } - static Expression buildIsNotNullFilterExpression(Dedupe node, ExpressionAnalyzer expressionAnalyzer, CatalystPlanContext context) { - visitFieldList(node.getFields(), expressionAnalyzer, context); + private static Expression buildIsNotNullFilterExpression(Dedupe node, ExpressionAnalyzer expressionAnalyzer, CatalystPlanContext context) { + node.getFields().forEach(field -> expressionAnalyzer.analyze(field, context)); Seq isNotNullExpressions = context.retainAllNamedParseExpressions( org.apache.spark.sql.catalyst.expressions.IsNotNull$.MODULE$::apply); @@ -185,7 +181,7 @@ static Expression buildIsNotNullFilterExpression(Dedupe node, ExpressionAnalyzer } private static Expression buildIsNullFilterExpression(Dedupe node, ExpressionAnalyzer expressionAnalyzer, CatalystPlanContext context) { - visitFieldList(node.getFields(), expressionAnalyzer, context); + node.getFields().forEach(field -> expressionAnalyzer.analyze(field, context)); Seq isNullExpressions = context.retainAllNamedParseExpressions( org.apache.spark.sql.catalyst.expressions.IsNull$.MODULE$::apply); @@ -200,12 +196,4 @@ private static Expression buildIsNullFilterExpression(Dedupe node, ExpressionAna } return isNullExpr; } - - static void visitFieldList(List fieldList, ExpressionAnalyzer expressionAnalyzer, CatalystPlanContext context) { - fieldList.forEach(field -> visitExpression(field, expressionAnalyzer, context)); - } - - static Expression visitExpression(UnresolvedExpression expression, ExpressionAnalyzer expressionAnalyzer, CatalystPlanContext context) { - return expressionAnalyzer.analyze(expression, context); - } }
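
For reference, the Window-based translation that PATCH 4 extracts into DedupeTransformer (and PATCH 6 trims) corresponds to the following DataFrame-level operations. This is a minimal Scala sketch for reviewers only, assuming a hypothetical DataFrame df with nullable name and category columns; the dedupN helper and column names are illustrative and not part of these patches, which build unresolved Catalyst plans directly rather than going through the DataFrame API:

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions.{col, row_number}

// Sketch of "| dedup n name, category KEEPEMPTY=..." for n > 1 (names illustrative).
def dedupN(df: DataFrame, n: Int, keepEmpty: Boolean): DataFrame = {
  // Filter (isnotnull('name) AND isnotnull('category))
  val nonNull = df.filter(col("name").isNotNull && col("category").isNotNull)
  // Window [row_number() ... AS _row_number_] partitioned and ordered by the dedup fields
  val spec = Window.partitionBy("name", "category").orderBy("name", "category")
  // Filter ('_row_number_ <= n), then DataFrameDropColumns('_row_number_)
  val deduped = nonNull
    .withColumn("_row_number_", row_number().over(spec))
    .filter(col("_row_number_") <= n)
    .drop("_row_number_")
  if (keepEmpty) {
    // Union with rows where any dedup field is null: Filter (isnull('name) OR isnull('category))
    deduped.union(df.filter(col("name").isNull || col("category").isNull))
  } else {
    deduped
  }
}

As the comment in visitDedupe notes, the allowedDuplication == 1 case deliberately avoids this Window approach in favor of Deduplicate, since the Aggregate-backed operator could achieve better performance.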