diff --git a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFillnullITSuite.scala b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFillnullITSuite.scala index 34f70a8e0..63e9149bb 100644 --- a/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFillnullITSuite.scala +++ b/integ-test/src/integration/scala/org/opensearch/flint/spark/ppl/FlintSparkPPLFillnullITSuite.scala @@ -4,7 +4,12 @@ */ package org.opensearch.flint.spark.ppl -import org.apache.spark.sql.{QueryTest, Row} +import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq + +import org.apache.spark.sql.{AnalysisException, QueryTest, Row} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar} +import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Literal, NamedExpression, SortOrder} +import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, DataFrameDropColumns, LogicalPlan, Project, Sort, UnaryNode} import org.apache.spark.sql.streaming.StreamTest class FlintSparkPPLFillnullITSuite @@ -36,6 +41,7 @@ class FlintSparkPPLFillnullITSuite | source = $testTable | fillnull value = 0 status_code | """.stripMargin) + assert(frame.columns.sameElements(Array("id", "request_path", "timestamp", "status_code"))) val results: Array[Row] = frame.collect() val expectedResults: Array[Row] = Array( @@ -48,6 +54,11 @@ class FlintSparkPPLFillnullITSuite // Compare the results implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Int](_.getAs[Int](0)) assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val expectedPlan = fillNullExpectedPlan(Seq(("status_code", 0))) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) } test("test fillnull with various null replacement values and one column") { @@ -55,6 +66,7 @@ class FlintSparkPPLFillnullITSuite | source = $testTable | fillnull fields status_code=101 | """.stripMargin) + assert(frame.columns.sameElements(Array("id", "request_path", "timestamp", "status_code"))) val results: Array[Row] = frame.collect() val expectedResults: Array[Row] = Array( @@ -67,6 +79,11 @@ class FlintSparkPPLFillnullITSuite // Compare the results implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Int](_.getAs[Int](0)) assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val expectedPlan = fillNullExpectedPlan(Seq(("status_code", 101))) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) } test("test fillnull with one null replacement value and two columns") { @@ -74,6 +91,7 @@ class FlintSparkPPLFillnullITSuite | source = $testTable | fillnull value = '???' request_path, timestamp | fields id, request_path, timestamp | """.stripMargin) + assert(frame.columns.sameElements(Array("id", "request_path", "timestamp"))) val results: Array[Row] = frame.collect() val expectedResults: Array[Row] = Array( @@ -86,6 +104,19 @@ class FlintSparkPPLFillnullITSuite // Compare the results implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Int](_.getAs[Int](0)) assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val fillNullPlan = fillNullExpectedPlan( + Seq(("request_path", "???"), ("timestamp", "???")), + addDefaultProject = false) + val expectedPlan = Project( + Seq( + UnresolvedAttribute("id"), + UnresolvedAttribute("request_path"), + UnresolvedAttribute("timestamp")), + fillNullPlan) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) } test("test fillnull with various null replacement values and two columns") { @@ -93,6 +124,7 @@ class FlintSparkPPLFillnullITSuite | source = $testTable | fillnull fields request_path='/not_found', timestamp='*' | fields id, request_path, timestamp | """.stripMargin) + assert(frame.columns.sameElements(Array("id", "request_path", "timestamp"))) val results: Array[Row] = frame.collect() val expectedResults: Array[Row] = Array( @@ -105,5 +137,156 @@ class FlintSparkPPLFillnullITSuite // Compare the results implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Int](_.getAs[Int](0)) assert(results.sorted.sameElements(expectedResults.sorted)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val fillNullPlan = fillNullExpectedPlan( + Seq(("request_path", "/not_found"), ("timestamp", "*")), + addDefaultProject = false) + val expectedPlan = Project( + Seq( + UnresolvedAttribute("id"), + UnresolvedAttribute("request_path"), + UnresolvedAttribute("timestamp")), + fillNullPlan) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test fillnull with one null replacement value and stats and sort command") { + val frame = sql(s""" + | source = $testTable | fillnull value = 500 status_code + | | stats count(status_code) by status_code, request_path + | | sort request_path, status_code + | """.stripMargin) + + assert(frame.columns.sameElements(Array("count(status_code)", "status_code", "request_path"))) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array( + Row(1, 200, null), + Row(1, 301, null), + Row(1, 500, "/about"), + Row(1, 500, "/contact"), + Row(1, 200, "/home"), + Row(1, 403, "/home")) + // Compare the results + assert(results.sameElements(expectedResults)) + + // Retrieve the logical plan + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val fillNullPlan = fillNullExpectedPlan(Seq(("status_code", 500)), addDefaultProject = false) + val aggregateExpressions = + Seq( + Alias( + UnresolvedFunction( + Seq("COUNT"), + Seq(UnresolvedAttribute("status_code")), + isDistinct = false), + "count(status_code)")(), + Alias(UnresolvedAttribute("status_code"), "status_code")(), + Alias(UnresolvedAttribute("request_path"), "request_path")()) + val aggregatePlan = Aggregate( + Seq( + Alias(UnresolvedAttribute("status_code"), "status_code")(), + Alias(UnresolvedAttribute("request_path"), "request_path")()), + aggregateExpressions, + fillNullPlan) + val sortPlan = Sort( + Seq( + SortOrder(UnresolvedAttribute("request_path"), Ascending), + SortOrder(UnresolvedAttribute("status_code"), Ascending)), + global = true, + aggregatePlan) + val expectedPlan = Project(seq(UnresolvedStar(None)), sortPlan) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test fillnull with various null replacement value and stats and sort command") { + val frame = sql(s""" + | source = $testTable | fillnull fields status_code = 500, request_path = '/home' + | | stats count(status_code) by status_code, request_path + | | sort request_path, status_code + | """.stripMargin) + + assert(frame.columns.sameElements(Array("count(status_code)", "status_code", "request_path"))) + val results: Array[Row] = frame.collect() + val expectedResults: Array[Row] = + Array( + Row(1, 500, "/about"), + Row(1, 500, "/contact"), + Row(2, 200, "/home"), + Row(1, 301, "/home"), + Row(1, 403, "/home")) + // Compare the results + assert(results.sameElements(expectedResults)) + + val logicalPlan: LogicalPlan = frame.queryExecution.logical + val fillNullPlan = fillNullExpectedPlan( + Seq(("status_code", 500), ("request_path", "/home")), + addDefaultProject = false) + val aggregateExpressions = + Seq( + Alias( + UnresolvedFunction( + Seq("COUNT"), + Seq(UnresolvedAttribute("status_code")), + isDistinct = false), + "count(status_code)")(), + Alias(UnresolvedAttribute("status_code"), "status_code")(), + Alias(UnresolvedAttribute("request_path"), "request_path")()) + val aggregatePlan = Aggregate( + Seq( + Alias(UnresolvedAttribute("status_code"), "status_code")(), + Alias(UnresolvedAttribute("request_path"), "request_path")()), + aggregateExpressions, + fillNullPlan) + val sortPlan = Sort( + Seq( + SortOrder(UnresolvedAttribute("request_path"), Ascending), + SortOrder(UnresolvedAttribute("status_code"), Ascending)), + global = true, + aggregatePlan) + val expectedPlan = Project(seq(UnresolvedStar(None)), sortPlan) + comparePlans(logicalPlan, expectedPlan, checkAnalysis = false) + } + + test("test fillnull with one null replacement value and missing columns") { + val ex = intercept[AnalysisException](sql(s""" + | source = $testTable | fillnull value = '!!!' + | """.stripMargin)) + + assert(ex.getMessage().contains("Syntax error ")) + } + + test("test fillnull with various null replacement values and missing columns") { + val ex = intercept[AnalysisException](sql(s""" + | source = $testTable | fillnull fields + | """.stripMargin)) + + assert(ex.getMessage().contains("Syntax error ")) + } + + private def fillNullExpectedPlan( + nullReplacements: Seq[(String, Any)], + addDefaultProject: Boolean = true): LogicalPlan = { + val table = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test")) + val renameProjectList = UnresolvedStar(None) +: nullReplacements.map { + case (nullableColumn, nullReplacement) => + Alias( + UnresolvedFunction( + "coalesce", + Seq(UnresolvedAttribute(nullableColumn), Literal(nullReplacement)), + isDistinct = false), + nullableColumn)() + } + val renameProject = Project(renameProjectList, table) + val droppedColumns = + nullReplacements.map(_._1).map(columnName => UnresolvedAttribute(columnName)) + val dropSourceColumn = DataFrameDropColumns(droppedColumns, renameProject) + if (addDefaultProject) { + Project(seq(UnresolvedStar(None)), dropSourceColumn) + } else { + dropSourceColumn + } } } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/FillNull.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/FillNull.java index 16d15894c..d1bb9df66 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/FillNull.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ast/tree/FillNull.java @@ -1,5 +1,8 @@ package org.opensearch.sql.ast.tree; +import lombok.Getter; +import lombok.NonNull; +import lombok.RequiredArgsConstructor; import org.opensearch.sql.ast.AbstractNodeVisitor; import org.opensearch.sql.ast.Node; import org.opensearch.sql.ast.expression.Field; @@ -8,76 +11,58 @@ import java.util.List; import java.util.Objects; import java.util.stream.Collectors; -import java.util.stream.Stream; -public class FillNull extends UnresolvedPlan { ; +@RequiredArgsConstructor +public class FillNull extends UnresolvedPlan { + @Getter + @RequiredArgsConstructor public static class NullableFieldFill { + @NonNull private final Field nullableFieldReference; + @NonNull private final Literal replaceNullWithMe; + } - public NullableFieldFill(Field nullableFieldReference, Literal replaceNullWithMe) { - this.nullableFieldReference = Objects.requireNonNull(nullableFieldReference, "Field to replace is required"); - this.replaceNullWithMe = Objects.requireNonNull(replaceNullWithMe, "Null replacement is required"); - } + public interface ContainNullableFieldFill { + List getNullFieldFill(); - public Field getNullableFieldReference() { - return nullableFieldReference; + static ContainNullableFieldFill ofVariousValue(List replacements) { + return new VariousValueNullFill(replacements); } - public Literal getReplaceNullWithMe() { - return replaceNullWithMe; + static ContainNullableFieldFill ofSameValue(Literal replaceNullWithMe, List nullableFieldReferences) { + return new SameValueNullFill(replaceNullWithMe, nullableFieldReferences); } } - private interface ContainNullableFieldFill { - Stream getNullFieldFill(); - } - - public static class SameValueNullFill implements ContainNullableFieldFill { - private final List replacements; + private static class SameValueNullFill implements ContainNullableFieldFill { + @Getter(onMethod_ = @Override) + private final List nullFieldFill; public SameValueNullFill(Literal replaceNullWithMe, List nullableFieldReferences) { Objects.requireNonNull(replaceNullWithMe, "Null replacement is required"); - this.replacements = Objects.requireNonNull(nullableFieldReferences, "Nullable field reference is required") + this.nullFieldFill = Objects.requireNonNull(nullableFieldReferences, "Nullable field reference is required") .stream() .map(nullableReference -> new NullableFieldFill(nullableReference, replaceNullWithMe)) .collect(Collectors.toList()); } - - @Override - public Stream getNullFieldFill() { - return replacements.stream(); - } } - public static class VariousValueNullFill implements ContainNullableFieldFill { - private final List replacements; - - public VariousValueNullFill(List replacements) { - this.replacements = replacements; - } - - @Override - public Stream getNullFieldFill() { - return replacements.stream(); - } + @RequiredArgsConstructor + private static class VariousValueNullFill implements ContainNullableFieldFill { + @NonNull + @Getter(onMethod_ = @Override) + private final List nullFieldFill; } private UnresolvedPlan child; - private final SameValueNullFill sameValueNullFill; - private final VariousValueNullFill variousValueNullFill; - public FillNull(SameValueNullFill sameValueNullFill, VariousValueNullFill variousValueNullFill) { - this.sameValueNullFill = sameValueNullFill; - this.variousValueNullFill = variousValueNullFill; - } + @NonNull + private final ContainNullableFieldFill containNullableFieldFill; public List getNullableFieldFills() { - return Stream.of(sameValueNullFill, variousValueNullFill) - .filter(Objects::nonNull) - .flatMap(ContainNullableFieldFill::getNullFieldFill) - .collect(Collectors.toList()); + return containNullableFieldFill.getNullFieldFill(); } @Override @@ -88,7 +73,6 @@ public UnresolvedPlan attach(UnresolvedPlan child) { @Override public List getChild() { - return child == null ? List.of() : List.of(child); } diff --git a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java index 67632f5e1..0a1218324 100644 --- a/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java +++ b/ppl-spark-integration/src/main/java/org/opensearch/sql/ppl/parser/AstBuilder.java @@ -32,8 +32,6 @@ import org.opensearch.sql.ast.expression.UnresolvedExpression; import org.opensearch.sql.ast.tree.*; import org.opensearch.sql.ast.tree.FillNull.NullableFieldFill; -import org.opensearch.sql.ast.tree.FillNull.SameValueNullFill; -import org.opensearch.sql.ast.tree.FillNull.VariousValueNullFill; import org.opensearch.sql.common.antlr.SyntaxCheckException; import org.opensearch.sql.ast.tree.Aggregation; import org.opensearch.sql.ast.tree.Correlation; @@ -68,6 +66,8 @@ import static java.util.Collections.emptyList; import static java.util.Collections.emptyMap; +import static org.opensearch.sql.ast.tree.FillNull.ContainNullableFieldFill.ofSameValue; +import static org.opensearch.sql.ast.tree.FillNull.ContainNullableFieldFill.ofVariousValue; /** Class of building the AST. Refines the visit path and build the AST nodes */ @@ -519,8 +519,7 @@ public UnresolvedPlan visitFillnullCommand(OpenSearchPPLParser.FillnullCommandCo .map(this::internalVisitExpression) .map(Field.class::cast) .collect(Collectors.toList()); - SameValueNullFill sameValueNullFill = new SameValueNullFill(replaceNullWithMe, fieldsToReplace); - return new FillNull(sameValueNullFill, null); + return new FillNull(ofSameValue(replaceNullWithMe, fieldsToReplace)); } else if (variousValuesContext != null) { List nullableFieldFills = IntStream.range(0, variousValuesContext.nullableField().size()) .mapToObj(index -> { @@ -530,7 +529,7 @@ public UnresolvedPlan visitFillnullCommand(OpenSearchPPLParser.FillnullCommandCo return new NullableFieldFill(nullableFieldReference, replaceNullWithMe); }) .collect(Collectors.toList()); - return new FillNull(null, new VariousValueNullFill(nullableFieldFills)); + return new FillNull(ofVariousValue(nullableFieldFills)); } else { throw new SyntaxCheckException("Invalid fillnull command"); } diff --git a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanRenameCommandTranslatorTestSuite.scala b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFillnullCommandTranslatorTestSuite.scala similarity index 99% rename from ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanRenameCommandTranslatorTestSuite.scala rename to ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFillnullCommandTranslatorTestSuite.scala index 9e94581e8..ead17fda2 100644 --- a/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanRenameCommandTranslatorTestSuite.scala +++ b/ppl-spark-integration/src/test/scala/org/opensearch/flint/spark/ppl/PPLLogicalPlanFillnullCommandTranslatorTestSuite.scala @@ -16,7 +16,7 @@ import org.apache.spark.sql.catalyst.expressions.{Alias, Literal, NamedExpressio import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.catalyst.plans.logical.{DataFrameDropColumns, Project} -class PPLLogicalPlanRenameCommandTranslatorTestSuite +class PPLLogicalPlanFillnullCommandTranslatorTestSuite extends SparkFunSuite with PlanTest with LogicalPlanTestUtils