Skip to content

Commit

Permalink
Introduced more tests for the Fillnull command, and code preparation …
Browse files Browse the repository at this point in the history
…for the review.

Signed-off-by: Lukasz Soszynski <[email protected]>
  • Loading branch information
lukasz-soszynski-eliatra committed Oct 7, 2024
1 parent 735eb61 commit 8151931
Show file tree
Hide file tree
Showing 4 changed files with 217 additions and 51 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,12 @@
*/
package org.opensearch.flint.spark.ppl

import org.apache.spark.sql.{QueryTest, Row}
import org.opensearch.sql.ppl.utils.DataTypeTransformer.seq

import org.apache.spark.sql.{AnalysisException, QueryTest, Row}
import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedFunction, UnresolvedRelation, UnresolvedStar}
import org.apache.spark.sql.catalyst.expressions.{Alias, Ascending, Literal, NamedExpression, SortOrder}
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, DataFrameDropColumns, LogicalPlan, Project, Sort, UnaryNode}
import org.apache.spark.sql.streaming.StreamTest

class FlintSparkPPLFillnullITSuite
Expand Down Expand Up @@ -36,6 +41,7 @@ class FlintSparkPPLFillnullITSuite
| source = $testTable | fillnull value = 0 status_code
| """.stripMargin)

assert(frame.columns.sameElements(Array("id", "request_path", "timestamp", "status_code")))
val results: Array[Row] = frame.collect()
val expectedResults: Array[Row] =
Array(
Expand All @@ -48,13 +54,19 @@ class FlintSparkPPLFillnullITSuite
// Compare the results
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Int](_.getAs[Int](0))
assert(results.sorted.sameElements(expectedResults.sorted))

// Retrieve the logical plan
val logicalPlan: LogicalPlan = frame.queryExecution.logical
val expectedPlan = fillNullExpectedPlan(Seq(("status_code", 0)))
comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
}

test("test fillnull with various null replacement values and one column") {
val frame = sql(s"""
| source = $testTable | fillnull fields status_code=101
| """.stripMargin)

assert(frame.columns.sameElements(Array("id", "request_path", "timestamp", "status_code")))
val results: Array[Row] = frame.collect()
val expectedResults: Array[Row] =
Array(
Expand All @@ -67,13 +79,19 @@ class FlintSparkPPLFillnullITSuite
// Compare the results
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Int](_.getAs[Int](0))
assert(results.sorted.sameElements(expectedResults.sorted))

// Retrieve the logical plan
val logicalPlan: LogicalPlan = frame.queryExecution.logical
val expectedPlan = fillNullExpectedPlan(Seq(("status_code", 101)))
comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
}

test("test fillnull with one null replacement value and two columns") {
val frame = sql(s"""
| source = $testTable | fillnull value = '???' request_path, timestamp | fields id, request_path, timestamp
| """.stripMargin)

assert(frame.columns.sameElements(Array("id", "request_path", "timestamp")))
val results: Array[Row] = frame.collect()
val expectedResults: Array[Row] =
Array(
Expand All @@ -86,13 +104,27 @@ class FlintSparkPPLFillnullITSuite
// Compare the results
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Int](_.getAs[Int](0))
assert(results.sorted.sameElements(expectedResults.sorted))

// Retrieve the logical plan
val logicalPlan: LogicalPlan = frame.queryExecution.logical
val fillNullPlan = fillNullExpectedPlan(
Seq(("request_path", "???"), ("timestamp", "???")),
addDefaultProject = false)
val expectedPlan = Project(
Seq(
UnresolvedAttribute("id"),
UnresolvedAttribute("request_path"),
UnresolvedAttribute("timestamp")),
fillNullPlan)
comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
}

test("test fillnull with various null replacement values and two columns") {
val frame = sql(s"""
| source = $testTable | fillnull fields request_path='/not_found', timestamp='*' | fields id, request_path, timestamp
| """.stripMargin)

assert(frame.columns.sameElements(Array("id", "request_path", "timestamp")))
val results: Array[Row] = frame.collect()
val expectedResults: Array[Row] =
Array(
Expand All @@ -105,5 +137,156 @@ class FlintSparkPPLFillnullITSuite
// Compare the results
implicit val rowOrdering: Ordering[Row] = Ordering.by[Row, Int](_.getAs[Int](0))
assert(results.sorted.sameElements(expectedResults.sorted))

// Retrieve the logical plan
val logicalPlan: LogicalPlan = frame.queryExecution.logical
val fillNullPlan = fillNullExpectedPlan(
Seq(("request_path", "/not_found"), ("timestamp", "*")),
addDefaultProject = false)
val expectedPlan = Project(
Seq(
UnresolvedAttribute("id"),
UnresolvedAttribute("request_path"),
UnresolvedAttribute("timestamp")),
fillNullPlan)
comparePlans(logicalPlan, expectedPlan, checkAnalysis = false)
}

test("test fillnull with one null replacement value and stats and sort command") {
  // A single replacement value (500) is applied to status_code before the rows are
  // grouped; the trailing sort is why the rows are compared positionally below.
  val ppl =
    s"""
       | source = $testTable | fillnull value = 500 status_code
       | | stats count(status_code) by status_code, request_path
       | | sort request_path, status_code
       | """.stripMargin
  val frame = sql(ppl)

  val expectedColumns = Array("count(status_code)", "status_code", "request_path")
  assert(frame.columns.sameElements(expectedColumns))

  val actualRows: Array[Row] = frame.collect()
  val expectedRows: Array[Row] = Array(
    Row(1, 200, null),
    Row(1, 301, null),
    Row(1, 500, "/about"),
    Row(1, 500, "/contact"),
    Row(1, 200, "/home"),
    Row(1, 403, "/home"))
  // Order-sensitive comparison: no re-sorting of the collected rows.
  assert(actualRows.sameElements(expectedRows))

  // Small builders for the repeated pieces of the expected (unresolved) plan.
  def selfAlias(column: String): Alias = Alias(UnresolvedAttribute(column), column)()
  def ascending(column: String): SortOrder = SortOrder(UnresolvedAttribute(column), Ascending)

  val countStatusCode = Alias(
    UnresolvedFunction(
      Seq("COUNT"),
      Seq(UnresolvedAttribute("status_code")),
      isDistinct = false),
    "count(status_code)")()

  // Expected shape, outermost first: Project(*) <- Sort <- Aggregate <- fillnull plan.
  val expectedPlan = Project(
    seq(UnresolvedStar(None)),
    Sort(
      Seq(ascending("request_path"), ascending("status_code")),
      global = true,
      Aggregate(
        Seq(selfAlias("status_code"), selfAlias("request_path")),
        Seq(countStatusCode, selfAlias("status_code"), selfAlias("request_path")),
        fillNullExpectedPlan(Seq(("status_code", 500)), addDefaultProject = false))))
  comparePlans(frame.queryExecution.logical, expectedPlan, checkAnalysis = false)
}

test("test fillnull with various null replacement value and stats and sort command") {
  // Per-column replacements: status_code falls back to 500 and request_path to '/home'.
  // The query sorts its output, so the rows are compared positionally below.
  val ppl =
    s"""
       | source = $testTable | fillnull fields status_code = 500, request_path = '/home'
       | | stats count(status_code) by status_code, request_path
       | | sort request_path, status_code
       | """.stripMargin
  val frame = sql(ppl)

  val expectedColumns = Array("count(status_code)", "status_code", "request_path")
  assert(frame.columns.sameElements(expectedColumns))

  val actualRows: Array[Row] = frame.collect()
  val expectedRows: Array[Row] = Array(
    Row(1, 500, "/about"),
    Row(1, 500, "/contact"),
    Row(2, 200, "/home"),
    Row(1, 301, "/home"),
    Row(1, 403, "/home"))
  // Order-sensitive comparison: no re-sorting of the collected rows.
  assert(actualRows.sameElements(expectedRows))

  // Small builders for the repeated pieces of the expected (unresolved) plan.
  def selfAlias(column: String): Alias = Alias(UnresolvedAttribute(column), column)()
  def ascending(column: String): SortOrder = SortOrder(UnresolvedAttribute(column), Ascending)

  val countStatusCode = Alias(
    UnresolvedFunction(
      Seq("COUNT"),
      Seq(UnresolvedAttribute("status_code")),
      isDistinct = false),
    "count(status_code)")()

  // Expected shape, outermost first: Project(*) <- Sort <- Aggregate <- fillnull plan.
  val expectedPlan = Project(
    seq(UnresolvedStar(None)),
    Sort(
      Seq(ascending("request_path"), ascending("status_code")),
      global = true,
      Aggregate(
        Seq(selfAlias("status_code"), selfAlias("request_path")),
        Seq(countStatusCode, selfAlias("status_code"), selfAlias("request_path")),
        fillNullExpectedPlan(
          Seq(("status_code", 500), ("request_path", "/home")),
          addDefaultProject = false))))
  comparePlans(frame.queryExecution.logical, expectedPlan, checkAnalysis = false)
}

test("test fillnull with one null replacement value and missing columns") {
  // The replacement value is given but no column list follows it; the query is
  // expected to fail with an AnalysisException reporting a syntax error.
  val thrown = intercept[AnalysisException] {
    sql(s"""
       | source = $testTable | fillnull value = '!!!'
       | """.stripMargin)
  }

  val message = thrown.getMessage()
  assert(message.contains("Syntax error "))
}

test("test fillnull with various null replacement values and missing columns") {
  // The `fields` keyword is present but no column=value pairs follow it; the query is
  // expected to fail with an AnalysisException reporting a syntax error.
  val thrown = intercept[AnalysisException] {
    sql(s"""
       | source = $testTable | fillnull fields
       | """.stripMargin)
  }

  val message = thrown.getMessage()
  assert(message.contains("Syntax error "))
}

/**
 * Builds the unresolved logical plan the tests expect for a `fillnull` command.
 *
 * The plan is layered, innermost first: the source relation; a Project of `*` plus one
 * `coalesce(column, replacement)` alias per entry; a DataFrameDropColumns over the
 * replaced column names; and, optionally, a final Project of `*`.
 *
 * @param nullReplacements pairs of (column name, replacement value), in command order
 * @param addDefaultProject when true, wraps the result in a `Project(*)`; callers that
 *                          stack further plan nodes on top pass false
 * @return the expected logical plan
 */
private def fillNullExpectedPlan(
    nullReplacements: Seq[(String, Any)],
    addDefaultProject: Boolean = true): LogicalPlan = {
  // NOTE(review): the relation name is hard-coded here; presumably it matches the suite's
  // testTable -- confirm against the table definition.
  val sourceRelation = UnresolvedRelation(Seq("spark_catalog", "default", "flint_ppl_test"))

  // One alias per replaced column, reusing the column's own name.
  val coalescedColumns = nullReplacements.map { case (column, replacement) =>
    val coalesceCall = UnresolvedFunction(
      "coalesce",
      Seq(UnresolvedAttribute(column), Literal(replacement)),
      isDistinct = false)
    Alias(coalesceCall, column)()
  }
  val projectionList = UnresolvedStar(None) +: coalescedColumns
  val withCoalesced = Project(projectionList, sourceRelation)

  val replacedColumns = nullReplacements.map { case (column, _) => UnresolvedAttribute(column) }
  val withoutOriginals = DataFrameDropColumns(replacedColumns, withCoalesced)

  if (addDefaultProject) Project(seq(UnresolvedStar(None)), withoutOriginals)
  else withoutOriginals
}
}
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
package org.opensearch.sql.ast.tree;

import lombok.Getter;
import lombok.NonNull;
import lombok.RequiredArgsConstructor;
import org.opensearch.sql.ast.AbstractNodeVisitor;
import org.opensearch.sql.ast.Node;
import org.opensearch.sql.ast.expression.Field;
Expand All @@ -8,76 +11,58 @@
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
import java.util.stream.Stream;

public class FillNull extends UnresolvedPlan { ;
@RequiredArgsConstructor
public class FillNull extends UnresolvedPlan {

/**
 * A single null-replacement rule: a field reference paired with the literal to use for it.
 *
 * <p>Lombok generates the getters ({@code @Getter}) and a constructor taking both final
 * fields ({@code @RequiredArgsConstructor}); the {@code @NonNull} markers make that
 * constructor reject {@code null} arguments.
 */
@Getter
@RequiredArgsConstructor
public static class NullableFieldFill {
// Reference to the field this rule applies to.
@NonNull
private final Field nullableFieldReference;
// Literal paired with the field as its null replacement.
@NonNull
private final Literal replaceNullWithMe;
}

public NullableFieldFill(Field nullableFieldReference, Literal replaceNullWithMe) {
this.nullableFieldReference = Objects.requireNonNull(nullableFieldReference, "Field to replace is required");
this.replaceNullWithMe = Objects.requireNonNull(replaceNullWithMe, "Null replacement is required");
}
public interface ContainNullableFieldFill {
List<NullableFieldFill> getNullFieldFill();

public Field getNullableFieldReference() {
return nullableFieldReference;
static ContainNullableFieldFill ofVariousValue(List<NullableFieldFill> replacements) {
return new VariousValueNullFill(replacements);
}

public Literal getReplaceNullWithMe() {
return replaceNullWithMe;
static ContainNullableFieldFill ofSameValue(Literal replaceNullWithMe, List<Field> nullableFieldReferences) {
return new SameValueNullFill(replaceNullWithMe, nullableFieldReferences);
}
}

private interface ContainNullableFieldFill {
Stream<NullableFieldFill> getNullFieldFill();
}

public static class SameValueNullFill implements ContainNullableFieldFill {
private final List<NullableFieldFill> replacements;
private static class SameValueNullFill implements ContainNullableFieldFill {
@Getter(onMethod_ = @Override)
private final List<NullableFieldFill> nullFieldFill;

public SameValueNullFill(Literal replaceNullWithMe, List<Field> nullableFieldReferences) {
Objects.requireNonNull(replaceNullWithMe, "Null replacement is required");
this.replacements = Objects.requireNonNull(nullableFieldReferences, "Nullable field reference is required")
this.nullFieldFill = Objects.requireNonNull(nullableFieldReferences, "Nullable field reference is required")
.stream()
.map(nullableReference -> new NullableFieldFill(nullableReference, replaceNullWithMe))
.collect(Collectors.toList());
}

@Override
public Stream<NullableFieldFill> getNullFieldFill() {
return replacements.stream();
}
}

public static class VariousValueNullFill implements ContainNullableFieldFill {
private final List<NullableFieldFill> replacements;

public VariousValueNullFill(List<NullableFieldFill> replacements) {
this.replacements = replacements;
}

@Override
public Stream<NullableFieldFill> getNullFieldFill() {
return replacements.stream();
}
/**
 * {@link ContainNullableFieldFill} implementation that stores the list of
 * {@link NullableFieldFill} entries handed to its constructor as-is.
 *
 * <p>The constructor is generated by Lombok ({@code @RequiredArgsConstructor}) and rejects a
 * {@code null} list because of {@code @NonNull}.
 */
@RequiredArgsConstructor
private static class VariousValueNullFill implements ContainNullableFieldFill {
// The generated getter is getNullFieldFill(); onMethod_ = @Override marks it as the
// implementation of the interface method.
@NonNull
@Getter(onMethod_ = @Override)
private final List<NullableFieldFill> nullFieldFill;
}

private UnresolvedPlan child;
private final SameValueNullFill sameValueNullFill;
private final VariousValueNullFill variousValueNullFill;

public FillNull(SameValueNullFill sameValueNullFill, VariousValueNullFill variousValueNullFill) {
this.sameValueNullFill = sameValueNullFill;
this.variousValueNullFill = variousValueNullFill;
}
@NonNull
private final ContainNullableFieldFill containNullableFieldFill;

public List<NullableFieldFill> getNullableFieldFills() {
return Stream.of(sameValueNullFill, variousValueNullFill)
.filter(Objects::nonNull)
.flatMap(ContainNullableFieldFill::getNullFieldFill)
.collect(Collectors.toList());
return containNullableFieldFill.getNullFieldFill();
}

@Override
Expand All @@ -88,7 +73,6 @@ public UnresolvedPlan attach(UnresolvedPlan child) {

@Override
public List<? extends Node> getChild() {

return child == null ? List.of() : List.of(child);
}

Expand Down
Loading

0 comments on commit 8151931

Please sign in to comment.