Skip to content

Commit

Permalink
add more unit tests and increase test coverage
Browse files Browse the repository at this point in the history
Signed-off-by: Lantao Jin <[email protected]>
  • Loading branch information
LantaoJin committed May 27, 2024
1 parent 2f5eea8 commit b0fbd8d
Show file tree
Hide file tree
Showing 14 changed files with 250 additions and 21 deletions.
4 changes: 4 additions & 0 deletions core/src/main/java/org/opensearch/sql/expression/DSL.java
Original file line number Diff line number Diff line change
Expand Up @@ -735,6 +735,10 @@ public static Aggregator max(Expression... expressions) {
return aggregate(BuiltinFunctionName.MAX, expressions);
}

/**
 * Creates a PERCENTILE aggregator. OpenSearch approximates percentiles with T-Digest, so
 * PERCENTILE is an alias of PERCENTILE_APPROX and both resolve to the same aggregator.
 *
 * @param expressions the arguments, handed on unchanged to {@code percentileApprox}
 * @return the aggregator produced by {@code percentileApprox}
 */
public static Aggregator percentile(Expression... expressions) {
  final Aggregator approximate = percentileApprox(expressions);
  return approximate;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -247,18 +247,34 @@ private static DefaultFunctionResolver percentileApprox() {
new FunctionSignature(functionName, ImmutableList.of(INTEGER, DOUBLE)),
(functionProperties, arguments) ->
PercentileApproximateAggregator.percentileApprox(arguments, INTEGER))
.put(
new FunctionSignature(functionName, ImmutableList.of(INTEGER, DOUBLE, DOUBLE)),
(functionProperties, arguments) ->
PercentileApproximateAggregator.percentileApprox(arguments, INTEGER))
.put(
new FunctionSignature(functionName, ImmutableList.of(LONG, DOUBLE)),
(functionProperties, arguments) ->
PercentileApproximateAggregator.percentileApprox(arguments, LONG))
.put(
new FunctionSignature(functionName, ImmutableList.of(LONG, DOUBLE, DOUBLE)),
(functionProperties, arguments) ->
PercentileApproximateAggregator.percentileApprox(arguments, LONG))
.put(
new FunctionSignature(functionName, ImmutableList.of(FLOAT, DOUBLE)),
(functionProperties, arguments) ->
PercentileApproximateAggregator.percentileApprox(arguments, FLOAT))
.put(
new FunctionSignature(functionName, ImmutableList.of(FLOAT, DOUBLE, DOUBLE)),
(functionProperties, arguments) ->
PercentileApproximateAggregator.percentileApprox(arguments, FLOAT))
.put(
new FunctionSignature(functionName, ImmutableList.of(DOUBLE, DOUBLE)),
(functionProperties, arguments) ->
PercentileApproximateAggregator.percentileApprox(arguments, DOUBLE))
.put(
new FunctionSignature(functionName, ImmutableList.of(DOUBLE, DOUBLE, DOUBLE)),
(functionProperties, arguments) ->
PercentileApproximateAggregator.percentileApprox(arguments, DOUBLE))
.build());
return functionResolver;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import org.opensearch.sql.expression.Expression;
import org.opensearch.sql.expression.function.BuiltinFunctionName;

/** Aggregator to calculate approximate percentile. */
public class PercentileApproximateAggregator
extends Aggregator<PercentileApproximateAggregator.PercentileApproximateState> {

Expand All @@ -26,11 +27,21 @@ public static Aggregator percentileApprox(List<Expression> arguments, ExprCoreTy

public PercentileApproximateAggregator(List<Expression> arguments, ExprCoreType returnType) {
super(BuiltinFunctionName.PERCENTILE_APPROX.getName(), arguments, returnType);
if (!ExprCoreType.numberTypes().contains(returnType)) {
throw new IllegalArgumentException(
String.format("percentile aggregation over %s type is not supported", returnType));
}
}

@Override
public PercentileApproximateState create() {
return new PercentileApproximateState(getArguments().get(1).valueOf().doubleValue());
if (getArguments().size() == 2) {
return new PercentileApproximateState(getArguments().get(1).valueOf().doubleValue());
} else {
return new PercentileApproximateState(
getArguments().get(1).valueOf().doubleValue(),
getArguments().get(2).valueOf().doubleValue());
}
}

@Override
Expand All @@ -44,18 +55,35 @@ public String toString() {
return StringUtils.format("%s(%s)", "percentile", format(getArguments()));
}

/**
* PercentileApproximateState is used to store the AVLTreeDigest state for percentile estimation.
*/
protected static class PercentileApproximateState extends AVLTreeDigest
implements AggregationState {

private static final double DEFAULT_COMPRESSION = 100.0;
private final double p;
// The compression level for the AVLTreeDigest; it keeps the same default value as OpenSearch core.
public static final double DEFAULT_COMPRESSION = 100.0;
private final double quantileRatio;

PercentileApproximateState(double quantile) {
super(DEFAULT_COMPRESSION);
if (quantile < 0.0 || quantile > 100.0) {
throw new IllegalArgumentException("out of bounds quantile value, must be in [0, 100]");
}
this.p = quantile / 100.0;
this.quantileRatio = quantile / 100.0;
}

/**
* Constructor for specifying both quantile and compression level.
*
* @param quantile the quantile to compute, must be in [0, 100]
* @param compression the compression factor of the t-digest sketches used
*/
PercentileApproximateState(double quantile, double compression) {
super(compression);
if (quantile < 0.0 || quantile > 100.0) {
throw new IllegalArgumentException("out of bounds quantile value, must be in [0, 100]");
}
this.quantileRatio = quantile / 100.0;
}

public void evaluate(ExprValue value) {
Expand All @@ -64,7 +92,7 @@ public void evaluate(ExprValue value) {

@Override
public ExprValue result() {
return this.size() == 0 ? ExprNullValue.of() : doubleValue(this.quantile(p));
return this.size() == 0 ? ExprNullValue.of() : doubleValue(this.quantile(quantileRatio));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ public enum BuiltinFunctionName {
STDDEV_POP(FunctionName.of("stddev_pop")),
// take top documents from aggregation bucket.
TAKE(FunctionName.of("take")),
// t-digest percentile
// t-digest percentile which is used in OpenSearch core by default.
PERCENTILE_APPROX(FunctionName.of("percentile_approx")),
// Not always an aggregation query
NESTED(FunctionName.of("nested")),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@

package org.opensearch.sql.expression.aggregation;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.*;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.when;
import static org.opensearch.sql.data.model.ExprValueUtils.*;
Expand All @@ -27,6 +26,7 @@
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import org.opensearch.sql.data.model.ExprValue;
import org.opensearch.sql.data.model.ExprValueUtils;
import org.opensearch.sql.exception.ExpressionEvaluationException;
import org.opensearch.sql.expression.DSL;
import org.opensearch.sql.expression.Expression;
Expand Down Expand Up @@ -55,6 +55,29 @@ public void test_percentile_field_expression() {
assertEquals(3.0, result.value());
}

/** Expects the 50th percentile with compression 0.1 to be 2.5 for each numeric field. */
@Test
public void test_percentile_field_expression_with_user_defined_compression() {
  var numericFields =
      List.of(
          DSL.ref("integer_value", INTEGER),
          DSL.ref("long_value", LONG),
          DSL.ref("double_value", DOUBLE),
          DSL.ref("float_value", FLOAT));
  for (var field : numericFields) {
    ExprValue median =
        aggregation(DSL.percentile(field, DSL.literal(50), DSL.literal(0.1)), tuples);
    assertEquals(2.5, median.value());
  }
}

@Test
public void test_percentile_expression() {
ExprValue result =
Expand Down Expand Up @@ -129,14 +152,39 @@ public void test_percentile_value() {

@Test
public void test_percentile_with_invalid_size() {
IllegalArgumentException exception =
var exception =
assertThrows(
IllegalArgumentException.class,
() ->
aggregation(
DSL.percentile(DSL.ref("double_value", DOUBLE), DSL.literal(-1)), tuples));
assertEquals("out of bounds quantile value, must be in [0, 100]", exception.getMessage());
ExpressionEvaluationException exception2 =
exception =
assertThrows(
IllegalArgumentException.class,
() ->
aggregation(
DSL.percentile(DSL.ref("double_value", DOUBLE), DSL.literal(200)), tuples));
assertEquals("out of bounds quantile value, must be in [0, 100]", exception.getMessage());
exception =
assertThrows(
IllegalArgumentException.class,
() ->
aggregation(
DSL.percentile(
DSL.ref("double_value", DOUBLE), DSL.literal(-1), DSL.literal(100)),
tuples));
assertEquals("out of bounds quantile value, must be in [0, 100]", exception.getMessage());
exception =
assertThrows(
IllegalArgumentException.class,
() ->
aggregation(
DSL.percentile(
DSL.ref("double_value", DOUBLE), DSL.literal(200), DSL.literal(100)),
tuples));
assertEquals("out of bounds quantile value, must be in [0, 100]", exception.getMessage());
var exception2 =
assertThrows(
ExpressionEvaluationException.class,
() ->
Expand All @@ -145,11 +193,83 @@ public void test_percentile_with_invalid_size() {
tuples));
assertEquals(
"percentile_approx function expected"
+ " {[INTEGER,DOUBLE],[LONG,DOUBLE],[FLOAT,DOUBLE],[DOUBLE,DOUBLE]}, but get"
+ " [DOUBLE,STRING]",
+ " {[INTEGER,DOUBLE],[INTEGER,DOUBLE,DOUBLE],[LONG,DOUBLE],[LONG,DOUBLE,DOUBLE],"
+ "[FLOAT,DOUBLE],[FLOAT,DOUBLE,DOUBLE],[DOUBLE,DOUBLE],[DOUBLE,DOUBLE,DOUBLE]},"
+ " but get [DOUBLE,STRING]",
exception2.getMessage());
}

/** Expects the 50th percentile of {@code integer_value * 10} over {@code tuples} to be 30.0. */
@Test
public void test_arithmetic_expression() {
  var scaledField =
      DSL.multiply(
          DSL.ref("integer_value", INTEGER), DSL.literal(ExprValueUtils.integerValue(10)));
  ExprValue median = aggregation(DSL.percentile(scaledField, DSL.literal(50)), tuples);
  assertEquals(30.0, median.value());
}

/** Expects 3.0 for the 50th percentile restricted by the condition {@code integer_value > 1}. */
@Test
public void test_filtered_percentile() {
  var filteredMedian =
      DSL.percentile(DSL.ref("integer_value", INTEGER), DSL.literal(50))
          .condition(DSL.greater(DSL.ref("integer_value", INTEGER), DSL.literal(1)));
  ExprValue outcome = aggregation(filteredMedian, tuples);
  assertEquals(3.0, outcome.value());
}

/** Expects 2.0 for the median of integer_value over {@code tuples_with_null_and_missing}. */
@Test
public void test_with_missing() {
  var medianOfIntegers = DSL.percentile(DSL.ref("integer_value", INTEGER), DSL.literal(50));
  ExprValue outcome = aggregation(medianOfIntegers, tuples_with_null_and_missing);
  assertEquals(2.0, outcome.value());
}

/** Expects 4.0 for the median of double_value over {@code tuples_with_null_and_missing}. */
@Test
public void test_with_null() {
  var medianOfDoubles = DSL.percentile(DSL.ref("double_value", DOUBLE), DSL.literal(50));
  ExprValue outcome = aggregation(medianOfDoubles, tuples_with_null_and_missing);
  assertEquals(4.0, outcome.value());
}

/** Expects a null result when aggregating over {@code tuples_with_all_null_or_missing}. */
@Test
public void test_with_all_missing_or_null() {
  var medianOfIntegers = DSL.percentile(DSL.ref("integer_value", INTEGER), DSL.literal(50));
  ExprValue outcome = aggregation(medianOfIntegers, tuples_with_all_null_or_missing);
  assertTrue(outcome.isNull());
}

/** Expects the aggregator constructor to reject a STRING return type. */
@Test
public void test_unsupported_type() {
  var thrown =
      assertThrows(
          IllegalArgumentException.class,
          () -> {
            // Explicit element type keeps the list a List<Expression>, as the constructor requires.
            List<Expression> stringArguments =
                List.of(DSL.ref("string", STRING), DSL.ref("string", STRING));
            new PercentileApproximateAggregator(stringArguments, STRING);
          });
  assertEquals("percentile aggregation over STRING type is not supported", thrown.getMessage());
}

/** Checks the rendered form of the aggregator with two and with three arguments. */
@Test
public void test_to_string() {
  Aggregator twoArgument = DSL.percentile(DSL.ref("integer_value", INTEGER), DSL.literal(50));
  assertEquals("percentile(integer_value,50)", twoArgument.toString());

  Aggregator threeArgument =
      DSL.percentile(DSL.ref("integer_value", INTEGER), DSL.literal(50), DSL.literal(0.1));
  assertEquals("percentile(integer_value,50,0.1)", threeArgument.toString());
}

private ExprValue[] percentiles(ExprValue value, ExprValue... values) {
return new ExprValue[] {
percentile(DSL.literal(1.0), value, values),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,7 @@
import static org.mockito.Mockito.when;
import static org.opensearch.sql.data.model.ExprValueUtils.integerValue;
import static org.opensearch.sql.data.model.ExprValueUtils.longValue;
import static org.opensearch.sql.data.type.ExprCoreType.INTEGER;
import static org.opensearch.sql.data.type.ExprCoreType.LONG;
import static org.opensearch.sql.data.type.ExprCoreType.STRING;
import static org.opensearch.sql.data.type.ExprCoreType.*;
import static org.opensearch.sql.planner.logical.LogicalPlanDSL.aggregation;
import static org.opensearch.sql.planner.logical.LogicalPlanDSL.filter;
import static org.opensearch.sql.planner.logical.LogicalPlanDSL.highlight;
Expand Down Expand Up @@ -180,6 +178,22 @@ void table_scan_builder_support_aggregation_push_down_can_apply_its_rule() {
ImmutableList.of(DSL.named("longV", DSL.ref("longV", LONG))))));
}

/**
 * With {@code pushDownAggregation} stubbed to return true, expects optimizing a percentile
 * aggregation over the relation to yield the table scan builder itself.
 */
@Test
void table_scan_builder_support_percentile_aggregation_push_down_can_apply_its_rule() {
  when(tableScanBuilder.pushDownAggregation(any())).thenReturn(true);

  var percentileAggregator =
      DSL.named(
          "PERCENTILE(intV, 1)",
          DSL.percentile(DSL.ref("intV", INTEGER), DSL.ref("percentile", DOUBLE)));
  var groupByField = DSL.named("longV", DSL.ref("longV", LONG));
  var aggregationPlan =
      aggregation(
          relation("schema", table),
          ImmutableList.of(percentileAggregator),
          ImmutableList.of(groupByField));

  assertEquals(tableScanBuilder, optimize(aggregationPlan));
}

@Test
void table_scan_builder_support_sort_push_down_can_apply_its_rule() {
when(tableScanBuilder.pushDownSort(any())).thenReturn(true);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -190,7 +190,6 @@ public void testStatsAliasedSpan() throws IOException {
verifyDataRows(response, rows(1, 20), rows(6, 30));
}

// TODO need fix below two unit tests
@Test
public void testStatsPercentile() throws IOException {
JSONObject response =
Expand All @@ -208,4 +207,13 @@ public void testStatsPercentileWithNull() throws IOException {
verifySchema(response, schema("percentile(balance, 50)", null, "long"));
verifyDataRows(response, rows(39225));
}

/** Runs a PPL percentile query with an explicit compression argument and checks the result. */
@Test
public void testStatsPercentileWithCompression() throws IOException {
  String query = String.format("source=%s | stats percentile(balance, 50, 1)", TEST_INDEX_BANK);
  JSONObject response = executeQuery(query);
  verifySchema(response, schema("percentile(balance, 50, 1)", null, "long"));
  verifyDataRows(response, rows(32838));
}
}
3 changes: 2 additions & 1 deletion ppl/src/main/antlr/OpenSearchPPLParser.g4
Original file line number Diff line number Diff line change
Expand Up @@ -238,7 +238,8 @@ takeAggFunction
;

percentileApproxFunction
: (PERCENTILE | PERCENTILE_APPROX) LT_PRTHS aggField = valueExpression COMMA quantile = numericLiteral RT_PRTHS
: (PERCENTILE | PERCENTILE_APPROX) LT_PRTHS aggField = valueExpression
COMMA quantile = numericLiteral (COMMA compression = numericLiteral)? RT_PRTHS
;

numericLiteral
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@
public class AstExpressionBuilder extends OpenSearchPPLParserBaseVisitor<UnresolvedExpression> {

private static final int DEFAULT_TAKE_FUNCTION_SIZE_VALUE = 10;
private static final double DEFAULT_PERCENTILE_COMPRESSION = 100.0;

/** The function name mapping between frontend and core engine. */
private static Map<String, String> FUNCTION_NAME_MAPPING =
Expand Down Expand Up @@ -163,6 +162,10 @@ public UnresolvedExpression visitPercentileApproxFunctionCall(
OpenSearchPPLParser.PercentileApproxFunctionCallContext ctx) {
ImmutableList.Builder<UnresolvedExpression> builder = ImmutableList.builder();
builder.add(new UnresolvedArgument("quantile", visit(ctx.percentileApproxFunction().quantile)));
if (ctx.percentileApproxFunction().compression != null) {
builder.add(
new UnresolvedArgument("compression", visit(ctx.percentileApproxFunction().compression)));
}
return new AggregateFunction(
"percentile", visit(ctx.percentileApproxFunction().aggField), builder.build());
}
Expand Down
Loading

0 comments on commit b0fbd8d

Please sign in to comment.