From d767868f42f012595003eb40be05056c1d5b4804 Mon Sep 17 00:00:00 2001 From: Lantao Jin <jinlantao@gmail.com> Date: Thu, 6 Jun 2024 11:02:20 +0800 Subject: [PATCH] Support Percentile in PPL * Support Percentile in PPL Signed-off-by: Lantao Jin <ltjin@amazon.com> * Remove ANSI SQL percentile syntax Signed-off-by: Lantao Jin <ltjin@amazon.com> * add more unit tests and increase test coverage Signed-off-by: Lantao Jin <ltjin@amazon.com> * increase test coverage Signed-off-by: Lantao Jin <ltjin@amazon.com> * address comments and add docs Signed-off-by: Lantao Jin <ltjin@amazon.com> * add examples in doc Signed-off-by: Lantao Jin <ltjin@amazon.com> * fix doctest failure and add more integ tests Signed-off-by: Lantao Jin <ltjin@amazon.com> * remove useless code and antlr4 files Signed-off-by: Lantao Jin <ltjin@amazon.com> --------- Signed-off-by: Lantao Jin <ltjin@amazon.com> --- core/build.gradle | 1 + .../org/opensearch/sql/expression/DSL.java | 12 + .../aggregation/AggregatorFunction.java | 43 +++ .../PercentileApproximateAggregator.java | 98 ++++++ .../function/BuiltinFunctionName.java | 4 + .../PercentileApproxAggregatorTest.java | 318 ++++++++++++++++++ .../optimizer/LogicalPlanOptimizerTest.java | 20 +- docs/user/dql/aggregations.rst | 19 ++ docs/user/ppl/cmd/stats.rst | 70 ++++ .../opensearch/sql/ppl/StatsCommandIT.java | 67 ++++ .../org/opensearch/sql/sql/AggregationIT.java | 40 ++- .../opensearch/sql/sql/WindowFunctionIT.java | 60 ++++ .../response/agg/PercentilesParser.java | 44 +++ .../response/agg/SinglePercentileParser.java | 40 +++ .../dsl/MetricAggregationBuilder.java | 39 ++- .../request/OpenSearchRequestBuilderTest.java | 20 ++ .../response/AggregationResponseUtils.java | 5 + ...enSearchAggregationResponseParserTest.java | 287 ++++++++++++++++ .../dsl/MetricAggregationBuilderTest.java | 90 +++++ ppl/src/main/antlr/OpenSearchPPLLexer.g4 | 1 + ppl/src/main/antlr/OpenSearchPPLParser.g4 | 13 +- .../sql/ppl/parser/AstExpressionBuilder.java | 40 +-- .../ppl/parser/AstExpressionBuilderTest.java | 36 +- sql/src/main/antlr/OpenSearchSQLLexer.g4 | 2 + sql/src/main/antlr/OpenSearchSQLParser.g4 | 14 +- .../sql/sql/parser/AstExpressionBuilder.java | 43 +-- .../sql/parser/AstExpressionBuilderTest.java | 20 ++ 27 files changed, 1375 insertions(+), 71 deletions(-) create mode 100644 core/src/main/java/org/opensearch/sql/expression/aggregation/PercentileApproximateAggregator.java create mode 100644 core/src/test/java/org/opensearch/sql/expression/aggregation/PercentileApproxAggregatorTest.java create mode 100644 opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/PercentilesParser.java create mode 100644 opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/SinglePercentileParser.java diff --git a/core/build.gradle b/core/build.gradle index a5fa4683ba..655e7d92c2 100644 --- a/core/build.gradle +++ b/core/build.gradle @@ -55,6 +55,7 @@ dependencies { api "com.fasterxml.jackson.core:jackson-databind:${versions.jackson_databind}" api "com.fasterxml.jackson.core:jackson-annotations:${versions.jackson}" api group: 'com.google.code.gson', name: 'gson', version: '2.8.9' + api group: 'com.tdunning', name: 't-digest', version: '3.3' api project(':common') testImplementation('org.junit.jupiter:junit-jupiter:5.9.3') diff --git a/core/src/main/java/org/opensearch/sql/expression/DSL.java b/core/src/main/java/org/opensearch/sql/expression/DSL.java index 12a7faafb2..9975afac7f 100644 --- a/core/src/main/java/org/opensearch/sql/expression/DSL.java +++ b/core/src/main/java/org/opensearch/sql/expression/DSL.java @@ -735,6 +735,18 @@ public static Aggregator max(Expression... expressions) { return aggregate(BuiltinFunctionName.MAX, expressions); } + /** + * OpenSearch uses T-Digest to approximate percentile, so PERCENTILE and PERCENTILE_APPROX are the + * same function. + */ + public static Aggregator percentile(Expression... expressions) { + return percentileApprox(expressions); + } + + public static Aggregator percentileApprox(Expression... expressions) { + return aggregate(BuiltinFunctionName.PERCENTILE_APPROX, expressions); + } + private static Aggregator aggregate(BuiltinFunctionName functionName, Expression... expressions) { return compile(FunctionProperties.None, functionName, expressions); } diff --git a/core/src/main/java/org/opensearch/sql/expression/aggregation/AggregatorFunction.java b/core/src/main/java/org/opensearch/sql/expression/aggregation/AggregatorFunction.java index bfc92d73c6..631eb2e613 100644 --- a/core/src/main/java/org/opensearch/sql/expression/aggregation/AggregatorFunction.java +++ b/core/src/main/java/org/opensearch/sql/expression/aggregation/AggregatorFunction.java @@ -57,6 +57,7 @@ public static void register(BuiltinFunctionRepository repository) { repository.register(stddevSamp()); repository.register(stddevPop()); repository.register(take()); + repository.register(percentileApprox()); } private static DefaultFunctionResolver avg() { @@ -235,4 +236,46 @@ private static DefaultFunctionResolver take() { .build()); return functionResolver; } + + private static DefaultFunctionResolver percentileApprox() { + FunctionName functionName = BuiltinFunctionName.PERCENTILE_APPROX.getName(); + DefaultFunctionResolver functionResolver = + new DefaultFunctionResolver( + functionName, + new ImmutableMap.Builder<FunctionSignature, FunctionBuilder>() + .put( + new FunctionSignature(functionName, ImmutableList.of(INTEGER, DOUBLE)), + (functionProperties, arguments) -> + PercentileApproximateAggregator.percentileApprox(arguments, INTEGER)) + .put( + new FunctionSignature(functionName, ImmutableList.of(INTEGER, DOUBLE, DOUBLE)), + (functionProperties, arguments) -> + PercentileApproximateAggregator.percentileApprox(arguments, INTEGER)) + .put( + new FunctionSignature(functionName, ImmutableList.of(LONG, DOUBLE)), + (functionProperties, arguments) -> + PercentileApproximateAggregator.percentileApprox(arguments, LONG)) + .put( + new FunctionSignature(functionName, ImmutableList.of(LONG, DOUBLE, DOUBLE)), + (functionProperties, arguments) -> + PercentileApproximateAggregator.percentileApprox(arguments, LONG)) + .put( + new FunctionSignature(functionName, ImmutableList.of(FLOAT, DOUBLE)), + (functionProperties, arguments) -> + PercentileApproximateAggregator.percentileApprox(arguments, FLOAT)) + .put( + new FunctionSignature(functionName, ImmutableList.of(FLOAT, DOUBLE, DOUBLE)), + (functionProperties, arguments) -> + PercentileApproximateAggregator.percentileApprox(arguments, FLOAT)) + .put( + new FunctionSignature(functionName, ImmutableList.of(DOUBLE, DOUBLE)), + (functionProperties, arguments) -> + PercentileApproximateAggregator.percentileApprox(arguments, DOUBLE)) + .put( + new FunctionSignature(functionName, ImmutableList.of(DOUBLE, DOUBLE, DOUBLE)), + (functionProperties, arguments) -> + PercentileApproximateAggregator.percentileApprox(arguments, DOUBLE)) + .build()); + return functionResolver; + } } diff --git a/core/src/main/java/org/opensearch/sql/expression/aggregation/PercentileApproximateAggregator.java b/core/src/main/java/org/opensearch/sql/expression/aggregation/PercentileApproximateAggregator.java new file mode 100644 index 0000000000..8ec5df2d45 --- /dev/null +++ b/core/src/main/java/org/opensearch/sql/expression/aggregation/PercentileApproximateAggregator.java @@ -0,0 +1,98 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.expression.aggregation; + +import static org.opensearch.sql.data.model.ExprValueUtils.doubleValue; +import static org.opensearch.sql.utils.ExpressionUtils.format; + +import com.tdunning.math.stats.AVLTreeDigest; +import java.util.List; +import org.opensearch.sql.common.utils.StringUtils; +import org.opensearch.sql.data.model.ExprNullValue; +import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.data.type.ExprCoreType; +import org.opensearch.sql.expression.Expression; +import org.opensearch.sql.expression.function.BuiltinFunctionName; + +/** Aggregator to calculate approximate percentile. */ +public class PercentileApproximateAggregator + extends Aggregator<PercentileApproximateAggregator.PercentileApproximateState> { + + public static Aggregator percentileApprox(List<Expression> arguments, ExprCoreType returnType) { + return new PercentileApproximateAggregator(arguments, returnType); + } + + public PercentileApproximateAggregator(List<Expression> arguments, ExprCoreType returnType) { + super(BuiltinFunctionName.PERCENTILE_APPROX.getName(), arguments, returnType); + if (!ExprCoreType.numberTypes().contains(returnType)) { + throw new IllegalArgumentException( + String.format("percentile aggregation over %s type is not supported", returnType)); + } + } + + @Override + public PercentileApproximateState create() { + if (getArguments().size() == 2) { + return new PercentileApproximateState(getArguments().get(1).valueOf().doubleValue()); + } else { + return new PercentileApproximateState( + getArguments().get(1).valueOf().doubleValue(), + getArguments().get(2).valueOf().doubleValue()); + } + } + + @Override + protected PercentileApproximateState iterate(ExprValue value, PercentileApproximateState state) { + state.evaluate(value); + return state; + } + + @Override + public String toString() { + return StringUtils.format("%s(%s)", "percentile", format(getArguments())); + } + + /** + * PercentileApproximateState is used to store the AVLTreeDigest state for percentile estimation. + */ + protected static class PercentileApproximateState extends AVLTreeDigest + implements AggregationState { + // The compression level for the AVLTreeDigest, keep the same default value as OpenSearch core. + public static final double DEFAULT_COMPRESSION = 100.0; + private final double percent; + + PercentileApproximateState(double percent) { + super(DEFAULT_COMPRESSION); + if (percent < 0.0 || percent > 100.0) { + throw new IllegalArgumentException("out of bounds percent value, must be in [0, 100]"); + } + this.percent = percent / 100.0; + } + + /** + * Constructor for specifying both percent and compression level. + * + * @param percent the percent to compute, must be in [0, 100] + * @param compression the compression factor of the t-digest sketches used + */ + PercentileApproximateState(double percent, double compression) { + super(compression); + if (percent < 0.0 || percent > 100.0) { + throw new IllegalArgumentException("out of bounds percent value, must be in [0, 100]"); + } + this.percent = percent / 100.0; + } + + public void evaluate(ExprValue value) { + this.add(value.doubleValue()); + } + + @Override + public ExprValue result() { + return this.size() == 0 ? ExprNullValue.of() : doubleValue(this.quantile(percent)); + } + } +} diff --git a/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java b/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java index f50fa927b8..fd5ea14a2e 100644 --- a/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java +++ b/core/src/main/java/org/opensearch/sql/expression/function/BuiltinFunctionName.java @@ -175,6 +175,8 @@ public enum BuiltinFunctionName { STDDEV_POP(FunctionName.of("stddev_pop")), // take top documents from aggregation bucket. TAKE(FunctionName.of("take")), + // t-digest percentile which is used in OpenSearch core by default. + PERCENTILE_APPROX(FunctionName.of("percentile_approx")), // Not always an aggregation query NESTED(FunctionName.of("nested")), @@ -279,6 +281,8 @@ public enum BuiltinFunctionName { .put("stddev_pop", BuiltinFunctionName.STDDEV_POP) .put("stddev_samp", BuiltinFunctionName.STDDEV_SAMP) .put("take", BuiltinFunctionName.TAKE) + .put("percentile", BuiltinFunctionName.PERCENTILE_APPROX) + .put("percentile_approx", BuiltinFunctionName.PERCENTILE_APPROX) .build(); public static Optional<BuiltinFunctionName> of(String str) { diff --git a/core/src/test/java/org/opensearch/sql/expression/aggregation/PercentileApproxAggregatorTest.java b/core/src/test/java/org/opensearch/sql/expression/aggregation/PercentileApproxAggregatorTest.java new file mode 100644 index 0000000000..ac617e7b32 --- /dev/null +++ b/core/src/test/java/org/opensearch/sql/expression/aggregation/PercentileApproxAggregatorTest.java @@ -0,0 +1,318 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package org.opensearch.sql.expression.aggregation; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.when; +import static org.opensearch.sql.data.model.ExprValueUtils.*; +import static org.opensearch.sql.data.type.ExprCoreType.*; + +import java.util.ArrayList; +import java.util.List; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.data.model.ExprValue; +import org.opensearch.sql.data.model.ExprValueUtils; +import org.opensearch.sql.exception.ExpressionEvaluationException; +import org.opensearch.sql.expression.DSL; +import org.opensearch.sql.expression.Expression; +import org.opensearch.sql.expression.LiteralExpression; +import org.opensearch.sql.storage.bindingtuple.BindingTuple; + +@ExtendWith(MockitoExtension.class) +public class PercentileApproxAggregatorTest extends AggregationTest { + + @Mock Expression expression; + + @Mock ExprValue tupleValue; + + @Mock BindingTuple tuple; + + @Test + public void test_percentile_field_expression() { + ExprValue result = + aggregation(DSL.percentile(DSL.ref("integer_value", INTEGER), DSL.literal(50)), tuples); + assertEquals(3.0, result.value()); + result = aggregation(DSL.percentile(DSL.ref("long_value", LONG), DSL.literal(50)), tuples); + assertEquals(3.0, result.value()); + result = aggregation(DSL.percentile(DSL.ref("double_value", DOUBLE), DSL.literal(50)), tuples); + assertEquals(3.0, result.value()); + result = aggregation(DSL.percentile(DSL.ref("float_value", FLOAT), DSL.literal(50)), tuples); + assertEquals(3.0, result.value()); + } + + @Test + public void test_percentile_field_expression_with_user_defined_compression() { + ExprValue result = + aggregation( + DSL.percentile(DSL.ref("integer_value", INTEGER), DSL.literal(50), DSL.literal(0.1)), + tuples); + assertEquals(2.5, result.value()); + result = + aggregation( + DSL.percentile(DSL.ref("long_value", LONG), DSL.literal(50), DSL.literal(0.1)), tuples); + assertEquals(2.5, result.value()); + result = + aggregation( + DSL.percentile(DSL.ref("double_value", DOUBLE), DSL.literal(50), DSL.literal(0.1)), + tuples); + assertEquals(2.5, result.value()); + result = + aggregation( + DSL.percentile(DSL.ref("float_value", FLOAT), DSL.literal(50), DSL.literal(0.1)), + tuples); + assertEquals(2.5, result.value()); + } + + @Test + public void test_percentile_expression() { + ExprValue result = + percentile( + DSL.literal(50), + integerValue(0), + integerValue(1), + integerValue(2), + integerValue(3), + integerValue(4)); + assertEquals(2.0, result.value()); + result = percentile(DSL.literal(30), integerValue(2012), integerValue(2013)); + assertEquals(2012, result.integerValue()); + } + + @Test + public void test_percentile_with_negative() { + ExprValue result = + percentile( + DSL.literal(50), + longValue(-100000L), + longValue(-50000L), + longValue(40000L), + longValue(50000L)); + assertEquals(40000.0, result.value()); + ExprValue[] results = + percentiles(longValue(-100000L), longValue(-50000L), longValue(40000L), longValue(50000L)); + assertPercentileValues( + results, -100000.0, // p=1.0 + -100000.0, // p=5.0 + -100000.0, // p=10.0 + -100000.0, // p=20.0 + -50000.0, // p=25.0 + -50000.0, // p=30.0 + -50000.0, // p=40.0 + 40000.0, // p=50.0 + 40000.0, // p=60.0 + 40000.0, // p=70.0 + 50000.0, // p=75.0 + 50000.0, // p=80.0 + 50000.0, // p=90.0 + 50000.0, // p=95.0 + 50000.0, // p=99.0 + 50000.0, // p=99.9 + 50000.0); // p=100.0 + } + + @Test + public void test_percentile_value() { + ExprValue[] results = + percentiles( + integerValue(0), integerValue(1), integerValue(2), integerValue(3), integerValue(4)); + assertPercentileValues( + results, 0.0, // p=1.0 + 0.0, // p=5.0 + 0.0, // p=10.0 + 1.0, // p=20.0 + 1.0, // p=25.0 + 1.0, // p=30.0 + 2.0, // p=40.0 + 2.0, // p=50.0 + 3.0, // p=60.0 + 3.0, // p=70.0 + 3.0, // p=75.0 + 4.0, // p=80.0 + 4.0, // p=90.0 + 4.0, // p=95.0 + 4.0, // p=99.0 + 4.0, // p=99.9 + 4.0); // p=100.0 + } + + @Test + public void test_percentile_with_invalid_size() { + var exception = + assertThrows( + IllegalArgumentException.class, + () -> + aggregation( + DSL.percentile(DSL.ref("double_value", DOUBLE), DSL.literal(-1)), tuples)); + assertEquals("out of bounds percent value, must be in [0, 100]", exception.getMessage()); + exception = + assertThrows( + IllegalArgumentException.class, + () -> + aggregation( + DSL.percentile(DSL.ref("double_value", DOUBLE), DSL.literal(200)), tuples)); + assertEquals("out of bounds percent value, must be in [0, 100]", exception.getMessage()); + exception = + assertThrows( + IllegalArgumentException.class, + () -> + aggregation( + DSL.percentile( + DSL.ref("double_value", DOUBLE), DSL.literal(-1), DSL.literal(100)), + tuples)); + assertEquals("out of bounds percent value, must be in [0, 100]", exception.getMessage()); + exception = + assertThrows( + IllegalArgumentException.class, + () -> + aggregation( + DSL.percentile( + DSL.ref("double_value", DOUBLE), DSL.literal(200), DSL.literal(100)), + tuples)); + assertEquals("out of bounds percent value, must be in [0, 100]", exception.getMessage()); + var exception2 = + assertThrows( + ExpressionEvaluationException.class, + () -> + aggregation( + DSL.percentile(DSL.ref("double_value", DOUBLE), DSL.literal("string")), + tuples)); + assertEquals( + "percentile_approx function expected" + + " {[INTEGER,DOUBLE],[INTEGER,DOUBLE,DOUBLE],[LONG,DOUBLE],[LONG,DOUBLE,DOUBLE]," + + "[FLOAT,DOUBLE],[FLOAT,DOUBLE,DOUBLE],[DOUBLE,DOUBLE],[DOUBLE,DOUBLE,DOUBLE]}," + + " but get [DOUBLE,STRING]", + exception2.getMessage()); + } + + @Test + public void test_arithmetic_expression() { + ExprValue result = + aggregation( + DSL.percentile( + DSL.multiply( + DSL.ref("integer_value", INTEGER), + DSL.literal(ExprValueUtils.integerValue(10))), + DSL.literal(50)), + tuples); + assertEquals(30.0, result.value()); + } + + @Test + public void test_filtered_percentile() { + ExprValue result = + aggregation( + DSL.percentile(DSL.ref("integer_value", INTEGER), DSL.literal(50)) + .condition(DSL.greater(DSL.ref("integer_value", INTEGER), DSL.literal(1))), + tuples); + assertEquals(3.0, result.value()); + } + + @Test + public void test_with_missing() { + ExprValue result = + aggregation( + DSL.percentile(DSL.ref("integer_value", INTEGER), DSL.literal(50)), + tuples_with_null_and_missing); + assertEquals(2.0, result.value()); + } + + @Test + public void test_with_null() { + ExprValue result = + aggregation( + DSL.percentile(DSL.ref("double_value", DOUBLE), DSL.literal(50)), + tuples_with_null_and_missing); + assertEquals(4.0, result.value()); + } + + @Test + public void test_with_all_missing_or_null() { + ExprValue result = + aggregation( + DSL.percentile(DSL.ref("integer_value", INTEGER), DSL.literal(50)), + tuples_with_all_null_or_missing); + assertTrue(result.isNull()); + } + + @Test + public void test_unsupported_type() { + var exception = + assertThrows( + IllegalArgumentException.class, + () -> + new PercentileApproximateAggregator( + List.of(DSL.ref("string", STRING), DSL.ref("string", STRING)), STRING)); + assertEquals( + "percentile aggregation over STRING type is not supported", exception.getMessage()); + } + + @Test + public void test_to_string() { + Aggregator aggregator = DSL.percentile(DSL.ref("integer_value", INTEGER), DSL.literal(50)); + assertEquals("percentile(integer_value,50)", aggregator.toString()); + aggregator = + DSL.percentile(DSL.ref("integer_value", INTEGER), DSL.literal(50), DSL.literal(0.1)); + assertEquals("percentile(integer_value,50,0.1)", aggregator.toString()); + } + + private ExprValue[] percentiles(ExprValue value, ExprValue... values) { + return new ExprValue[] { + percentile(DSL.literal(1.0), value, values), + percentile(DSL.literal(5.0), value, values), + percentile(DSL.literal(10.0), value, values), + percentile(DSL.literal(20.0), value, values), + percentile(DSL.literal(25.0), value, values), + percentile(DSL.literal(30.0), value, values), + percentile(DSL.literal(40.0), value, values), + percentile(DSL.literal(50.0), value, values), + percentile(DSL.literal(60.0), value, values), + percentile(DSL.literal(70.0), value, values), + percentile(DSL.literal(75.0), value, values), + percentile(DSL.literal(80.0), value, values), + percentile(DSL.literal(90.0), value, values), + percentile(DSL.literal(95.0), value, values), + percentile(DSL.literal(99.0), value, values), + percentile(DSL.literal(99.9), value, values), + percentile(DSL.literal(100.0), value, values) + }; + } + + private void assertPercentileValues(ExprValue[] actualValues, Double... expectedValues) { + int i = 0; + for (Double expected : expectedValues) { + assertEquals(expected, actualValues[i].value()); + i++; + } + } + + private ExprValue percentile(LiteralExpression p, ExprValue value, ExprValue... values) { + when(expression.valueOf(any())).thenReturn(value, values); + when(expression.type()).thenReturn(DOUBLE); + return aggregation(DSL.percentile(expression, p), mockTuples(value, values)); + } + + private List<ExprValue> mockTuples(ExprValue value, ExprValue... values) { + List<ExprValue> mockTuples = new ArrayList<>(); + when(tupleValue.bindingTuples()).thenReturn(tuple); + mockTuples.add(tupleValue); + for (ExprValue exprValue : values) { + mockTuples.add(tupleValue); + } + return mockTuples; + } +} diff --git a/core/src/test/java/org/opensearch/sql/planner/optimizer/LogicalPlanOptimizerTest.java b/core/src/test/java/org/opensearch/sql/planner/optimizer/LogicalPlanOptimizerTest.java index 2cdcb76e71..c25e415cfa 100644 --- a/core/src/test/java/org/opensearch/sql/planner/optimizer/LogicalPlanOptimizerTest.java +++ b/core/src/test/java/org/opensearch/sql/planner/optimizer/LogicalPlanOptimizerTest.java @@ -13,9 +13,7 @@ import static org.mockito.Mockito.when; import static org.opensearch.sql.data.model.ExprValueUtils.integerValue; import static org.opensearch.sql.data.model.ExprValueUtils.longValue; -import static org.opensearch.sql.data.type.ExprCoreType.INTEGER; -import static org.opensearch.sql.data.type.ExprCoreType.LONG; -import static org.opensearch.sql.data.type.ExprCoreType.STRING; +import static org.opensearch.sql.data.type.ExprCoreType.*; import static org.opensearch.sql.planner.logical.LogicalPlanDSL.aggregation; import static org.opensearch.sql.planner.logical.LogicalPlanDSL.filter; import static org.opensearch.sql.planner.logical.LogicalPlanDSL.highlight; @@ -180,6 +178,22 @@ void table_scan_builder_support_aggregation_push_down_can_apply_its_rule() { ImmutableList.of(DSL.named("longV", DSL.ref("longV", LONG)))))); } + @Test + void table_scan_builder_support_percentile_aggregation_push_down_can_apply_its_rule() { + when(tableScanBuilder.pushDownAggregation(any())).thenReturn(true); + + assertEquals( + tableScanBuilder, + optimize( + aggregation( + relation("schema", table), + ImmutableList.of( + DSL.named( + "PERCENTILE(intV, 1)", + DSL.percentile(DSL.ref("intV", INTEGER), DSL.ref("percentile", DOUBLE)))), + ImmutableList.of(DSL.named("longV", DSL.ref("longV", LONG)))))); + } + @Test void table_scan_builder_support_sort_push_down_can_apply_its_rule() { when(tableScanBuilder.pushDownSort(any())).thenReturn(true); diff --git a/docs/user/dql/aggregations.rst b/docs/user/dql/aggregations.rst index d0cbb28f62..42db4cdb4f 100644 --- a/docs/user/dql/aggregations.rst +++ b/docs/user/dql/aggregations.rst @@ -370,6 +370,25 @@ To get the count of distinct values of a field, you can add a keyword ``DISTINCT | 2 | 4 | +--------------------------+-----------------+ +PERCENTILE or PERCENTILE_APPROX +------------------------------- + +Description +>>>>>>>>>>> + +Usage: PERCENTILE(expr, percent) or PERCENTILE_APPROX(expr, percent). Returns the approximate percentile value of `expr` at the specified percentage. `percent` must be a constant between 0 and 100. + +Example:: + + os> SELECT gender, percentile(age, 90) as p90 FROM accounts GROUP BY gender; + fetched rows / total rows = 2/2 + +----------+-------+ + | gender | p90 | + |----------+-------| + | F | 28 | + | M | 36 | + +----------+-------+ + HAVING Clause ============= diff --git a/docs/user/ppl/cmd/stats.rst b/docs/user/ppl/cmd/stats.rst index d9cca9e314..096d3eacfc 100644 --- a/docs/user/ppl/cmd/stats.rst +++ b/docs/user/ppl/cmd/stats.rst @@ -259,6 +259,27 @@ Example:: | [Amber,Hattie,Nanette,Dale] | +-----------------------------+ +PERCENTILE or PERCENTILE_APPROX +------------------------------- + +Description +>>>>>>>>>>> + +Usage: PERCENTILE(expr, percent) or PERCENTILE_APPROX(expr, percent). Return the approximate percentile value of expr at the specified percentage. + +* percent: The number must be a constant between 0 and 100. + +Example:: + + os> source=accounts | stats percentile(age, 90) by gender; + fetched rows / total rows = 2/2 + +-----------------------+----------+ + | percentile(age, 90) | gender | + |-----------------------+----------| + | 28 | F | + | 36 | M | + +-----------------------+----------+ + Example 1: Calculate the count of events ======================================== @@ -419,3 +440,52 @@ PPL query:: | 2 | [amberduke@pyrami.com,daleadams@boink.com] | 30 | M | | 1 | [hattiebond@netagy.com] | 35 | M | +-------+--------------------------------------------+------------+----------+ + +Example 11: Calculate the percentile of a field +=============================================== + +The example show calculate the percentile 90th age of all the accounts. + +PPL query:: + + os> source=accounts | stats percentile(age, 90); + fetched rows / total rows = 1/1 + +-----------------------+ + | percentile(age, 90) | + |-----------------------| + | 36 | + +-----------------------+ + + +Example 12: Calculate the percentile of a field by group +======================================================== + +The example show calculate the percentile 90th age of all the accounts group by gender. + +PPL query:: + + os> source=accounts | stats percentile(age, 90) by gender; + fetched rows / total rows = 2/2 + +-----------------------+----------+ + | percentile(age, 90) | gender | + |-----------------------+----------| + | 28 | F | + | 36 | M | + +-----------------------+----------+ + +Example 13: Calculate the percentile by a gender and span +========================================================= + +The example gets the percentile 90th age by the interval of 10 years and group by gender. + +PPL query:: + + os> source=accounts | stats percentile(age, 90) as p90 by span(age, 10) as age_span, gender + fetched rows / total rows = 2/2 + +-------+------------+----------+ + | p90 | age_span | gender | + |-------+------------+----------| + | 28 | 20 | F | + | 36 | 30 | M | + +-------+------------+----------+ + diff --git a/integ-test/src/test/java/org/opensearch/sql/ppl/StatsCommandIT.java b/integ-test/src/test/java/org/opensearch/sql/ppl/StatsCommandIT.java index 92b9e309b8..40acd2f093 100644 --- a/integ-test/src/test/java/org/opensearch/sql/ppl/StatsCommandIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/ppl/StatsCommandIT.java @@ -189,4 +189,71 @@ public void testStatsAliasedSpan() throws IOException { response, schema("count()", null, "integer"), schema("age_bucket", null, "integer")); verifyDataRows(response, rows(1, 20), rows(6, 30)); } + + @Test + public void testStatsPercentile() throws IOException { + JSONObject response = + executeQuery(String.format("source=%s | stats percentile(balance, 50)", TEST_INDEX_BANK)); + verifySchema(response, schema("percentile(balance, 50)", null, "long")); + verifyDataRows(response, rows(32838)); + } + + @Test + public void testStatsPercentileWithNull() throws IOException { + JSONObject response = + executeQuery( + String.format( + "source=%s | stats percentile(balance, 50)", TEST_INDEX_BANK_WITH_NULL_VALUES)); + verifySchema(response, schema("percentile(balance, 50)", null, "long")); + verifyDataRows(response, rows(39225)); + } + + @Test + public void testStatsPercentileWithCompression() throws IOException { + JSONObject response = + executeQuery( + String.format("source=%s | stats percentile(balance, 50, 1)", TEST_INDEX_BANK)); + verifySchema(response, schema("percentile(balance, 50, 1)", null, "long")); + verifyDataRows(response, rows(32838)); + } + + @Test + public void testStatsPercentileWhere() throws IOException { + JSONObject response = + executeQuery( + String.format( + "source=%s | stats percentile(balance, 50) as p50 by state | where p50 > 40000", + TEST_INDEX_BANK)); + verifySchema(response, schema("p50", null, "long"), schema("state", null, "string")); + verifyDataRows(response, rows(48086, "IN"), rows(40540, "PA")); + } + + @Test + public void testStatsPercentileByNullValue() throws IOException { + JSONObject response = + executeQuery( + String.format( + "source=%s | stats percentile(balance, 50) as p50 by age", + TEST_INDEX_BANK_WITH_NULL_VALUES)); + verifySchema(response, schema("p50", null, "long"), schema("age", null, "integer")); + verifyDataRows( + response, + rows(0, null), + rows(32838, 28), + rows(39225, 32), + rows(4180, 33), + rows(48086, 34), + rows(0, 36)); + } + + @Test + public void testStatsPercentileBySpan() throws IOException { + JSONObject response = + executeQuery( + String.format( + "source=%s | stats percentile(balance, 50) as p50 by span(age, 10) as age_bucket", + TEST_INDEX_BANK)); + verifySchema(response, schema("p50", null, "long"), schema("age_bucket", null, "integer")); + verifyDataRows(response, rows(32838, 20), rows(39225, 30)); + } } diff --git a/integ-test/src/test/java/org/opensearch/sql/sql/AggregationIT.java b/integ-test/src/test/java/org/opensearch/sql/sql/AggregationIT.java index 3f71499f97..29358bd1c3 100644 --- a/integ-test/src/test/java/org/opensearch/sql/sql/AggregationIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/sql/AggregationIT.java @@ -5,9 +5,7 @@ package org.opensearch.sql.sql; -import static org.opensearch.sql.legacy.TestsConstants.TEST_INDEX_BANK; -import static org.opensearch.sql.legacy.TestsConstants.TEST_INDEX_CALCS; -import static org.opensearch.sql.legacy.TestsConstants.TEST_INDEX_NULL_MISSING; +import static org.opensearch.sql.legacy.TestsConstants.*; import static org.opensearch.sql.legacy.plugin.RestSqlAction.QUERY_API_ENDPOINT; import static org.opensearch.sql.util.MatcherUtils.rows; import static org.opensearch.sql.util.MatcherUtils.schema; @@ -713,6 +711,42 @@ public void testAvgTimeStampInMemory() throws IOException { verifySome(response.getJSONArray("datarows"), rows("2004-07-20 10:38:09.705")); } + @Test + public void testPercentilePushedDown() throws IOException { + var response = + executeQuery(String.format("SELECT percentile(balance, 50)" + " FROM %s", TEST_INDEX_BANK)); + verifySchema(response, schema("percentile(balance, 50)", null, "long")); + verifyDataRows(response, rows(32838)); + } + + @Test + public void testFilteredPercentilePushDown() throws IOException { + JSONObject response = + executeQuery( + "SELECT percentile(balance, 50) FILTER(WHERE balance > 40000) FROM " + TEST_INDEX_BANK); + verifySchema( + response, schema("percentile(balance, 50) FILTER(WHERE balance > 40000)", null, "long")); + verifyDataRows(response, rows(48086)); + } + + @Test + public void testPercentileGroupByPushDown() throws IOException { + var response = + executeQuery( + String.format( + "SELECT percentile(balance, 50), age" + " FROM %s GROUP BY age", TEST_INDEX_BANK)); + verifySchema( + response, schema("percentile(balance, 50)", null, "long"), schema("age", null, "integer")); + verifyDataRows( + response, + rows(32838, 28), + rows(39225, 32), + rows(4180, 33), + rows(48086, 34), + rows(16418, 36), + rows(40540, 39)); + } + protected JSONObject executeQuery(String query) throws IOException { Request request = new Request("POST", QUERY_API_ENDPOINT); request.setJsonEntity(String.format(Locale.ROOT, "{\n" + " \"query\": \"%s\"\n" + "}", query)); diff --git a/integ-test/src/test/java/org/opensearch/sql/sql/WindowFunctionIT.java b/integ-test/src/test/java/org/opensearch/sql/sql/WindowFunctionIT.java index 86257e6a22..95c1f7433d 100644 --- a/integ-test/src/test/java/org/opensearch/sql/sql/WindowFunctionIT.java +++ b/integ-test/src/test/java/org/opensearch/sql/sql/WindowFunctionIT.java @@ -123,4 +123,64 @@ public void testDistinctCountPartition() { rows("Duke Willmington", 1), rows("Ratliff", 1)); } + + @Test + public void testPercentileOverNull() { + JSONObject response = + new JSONObject( + executeQuery( + "SELECT lastname, percentile(balance, 50) OVER() " + + "FROM " + + TestsConstants.TEST_INDEX_BANK, + "jdbc")); + verifyDataRows( + response, + rows("Duke Willmington", 32838), + rows("Bond", 32838), + rows("Bates", 32838), + rows("Adams", 32838), + rows("Ratliff", 32838), + rows("Ayala", 32838), + rows("Mcpherson", 32838)); + } + + @Test + public void testPercentileOver() { + JSONObject response = + new JSONObject( + executeQuery( + "SELECT lastname, percentile(balance, 50) OVER(ORDER BY lastname) " + + "FROM " + + TestsConstants.TEST_INDEX_BANK, + "jdbc")); + verifyDataRowsInOrder( + response, + rows("Adams", 4180), + rows("Ayala", 40540), + rows("Bates", 32838), + rows("Bond", 32838), + rows("Duke Willmington", 32838), + rows("Mcpherson", 39225), + rows("Ratliff", 32838)); + } + + @Test + public void testPercentilePartition() { + JSONObject response = + new JSONObject( + executeQuery( + "SELECT lastname, percentile(balance, 50) OVER(PARTITION BY gender ORDER BY" + + " lastname) FROM " + + TestsConstants.TEST_INDEX_BANK, + "jdbc")); + verifyDataRowsInOrder( + response, + rows("Ayala", 40540), + rows("Bates", 40540), + rows("Mcpherson", 40540), + rows("Adams", 4180), + rows("Bond", 5686), + rows("Duke Willmington", 5686), + rows("Ratliff", 16418)); + } } diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/PercentilesParser.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/PercentilesParser.java new file mode 100644 index 0000000000..86ed735b4a --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/PercentilesParser.java @@ -0,0 +1,44 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package org.opensearch.sql.opensearch.response.agg; + +import com.google.common.collect.Streams; +import java.util.Collections; +import java.util.Map; +import java.util.stream.Collectors; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import org.opensearch.search.aggregations.Aggregation; +import org.opensearch.search.aggregations.metrics.Percentile; +import org.opensearch.search.aggregations.metrics.Percentiles; + +@EqualsAndHashCode +@RequiredArgsConstructor +public class PercentilesParser implements MetricParser { + + @Getter private final String name; + + @Override + public Map<String, Object> parse(Aggregation agg) { + return Collections.singletonMap( + agg.getName(), + // TODO a better implementation here is providing a class `MultiValueParser` + // similar to `SingleValueParser`. However, there is no method `values()` available + // in `org.opensearch.search.aggregations.metrics.MultiValue`. + Streams.stream(((Percentiles) agg).iterator()) + .map(Percentile::getValue) + .collect(Collectors.toList())); + } +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/SinglePercentileParser.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/SinglePercentileParser.java new file mode 100644 index 0000000000..94a70302af --- /dev/null +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/response/agg/SinglePercentileParser.java @@ -0,0 +1,40 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"). + * You may not use this file except in compliance with the License. + * A copy of the License is located at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * or in the "license" file accompanying this file. This file is distributed + * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either + * express or implied. See the License for the specific language governing + * permissions and limitations under the License. + */ + +package org.opensearch.sql.opensearch.response.agg; + +import com.google.common.collect.Streams; +import java.util.Collections; +import java.util.Map; +import lombok.EqualsAndHashCode; +import lombok.Getter; +import lombok.RequiredArgsConstructor; +import org.opensearch.search.aggregations.Aggregation; +import org.opensearch.search.aggregations.metrics.Percentiles; + +@EqualsAndHashCode +@RequiredArgsConstructor +public class SinglePercentileParser implements MetricParser { + + @Getter private final String name; + + @Override + public Map<String, Object> parse(Aggregation agg) { + return Collections.singletonMap( + agg.getName(), + // TODO `Percentiles` implements interface + // `org.opensearch.search.aggregations.metrics.MultiValue`, but there is not + // method `values()` available in this interface. So we + Streams.stream(((Percentiles) agg).iterator()).findFirst().get().getValue()); + } +} diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java index c99fbfdc49..779fe2f1c9 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilder.java @@ -17,6 +17,7 @@ import org.opensearch.search.aggregations.bucket.filter.FilterAggregationBuilder; import org.opensearch.search.aggregations.metrics.CardinalityAggregationBuilder; import org.opensearch.search.aggregations.metrics.ExtendedStats; +import org.opensearch.search.aggregations.metrics.PercentilesAggregationBuilder; import org.opensearch.search.aggregations.metrics.TopHitsAggregationBuilder; import org.opensearch.search.aggregations.support.ValuesSourceAggregationBuilder; import org.opensearch.sql.expression.Expression; @@ -24,11 +25,7 @@ import org.opensearch.sql.expression.LiteralExpression; import org.opensearch.sql.expression.ReferenceExpression; import org.opensearch.sql.expression.aggregation.NamedAggregator; -import org.opensearch.sql.opensearch.response.agg.FilterParser; -import org.opensearch.sql.opensearch.response.agg.MetricParser; -import org.opensearch.sql.opensearch.response.agg.SingleValueParser; -import org.opensearch.sql.opensearch.response.agg.StatsParser; -import org.opensearch.sql.opensearch.response.agg.TopHitsParser; +import org.opensearch.sql.opensearch.response.agg.*; import org.opensearch.sql.opensearch.storage.script.filter.FilterQueryBuilder; import org.opensearch.sql.opensearch.storage.serialization.ExpressionSerializer; @@ -160,6 +157,16 @@ public Pair<AggregationBuilder, MetricParser> visitNamedAggregator( condition, name, new TopHitsParser(name)); + case "percentile": + case "percentile_approx": + return make( + AggregationBuilders.percentiles(name), + expression, + node.getArguments().get(1), // percent + node.getArguments().size() >= 3 ? node.getArguments().get(2) : null, // compression + condition, + name, + new SinglePercentileParser(name)); default: throw new IllegalStateException( String.format("unsupported aggregator %s", node.getFunctionName().getFunctionName())); @@ -219,6 +226,28 @@ private Pair<AggregationBuilder, MetricParser> make( return Pair.of(builder, parser); } + private Pair<AggregationBuilder, MetricParser> make( + PercentilesAggregationBuilder builder, + Expression expression, + Expression percent, + Expression compression, + Expression condition, + String name, + MetricParser parser) { + PercentilesAggregationBuilder aggregationBuilder = + helper.build(expression, builder::field, builder::script); + if (compression != null) { + aggregationBuilder.compression(compression.valueOf().doubleValue()); + } + aggregationBuilder.percentiles(percent.valueOf().doubleValue()); + if (condition != null) { + return Pair.of( + makeFilterAggregation(aggregationBuilder, condition, name), + FilterParser.builder().name(name).metricsParser(parser).build()); + } + return Pair.of(aggregationBuilder, parser); + } + /** * Replace star or literal with OpenSearch metadata field "_index". Because: 1) Analyzer already * converts * to string literal, literal check here can handle both COUNT(*) and COUNT(1). 2) diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/request/OpenSearchRequestBuilderTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/request/OpenSearchRequestBuilderTest.java index 5bb0a2207b..742e76cbd0 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/request/OpenSearchRequestBuilderTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/request/OpenSearchRequestBuilderTest.java @@ -58,6 +58,7 @@ import org.opensearch.sql.opensearch.data.value.OpenSearchExprValueFactory; import org.opensearch.sql.opensearch.response.agg.CompositeAggregationParser; import org.opensearch.sql.opensearch.response.agg.OpenSearchAggregationResponseParser; +import org.opensearch.sql.opensearch.response.agg.SinglePercentileParser; import org.opensearch.sql.opensearch.response.agg.SingleValueParser; import org.opensearch.sql.planner.logical.LogicalNested; @@ -165,6 +166,25 @@ void test_push_down_aggregation() { verify(exprValueFactory).setParser(responseParser); } + @Test + void test_push_down_percentile_aggregation() { + AggregationBuilder aggBuilder = + AggregationBuilders.composite( + "composite_buckets", Collections.singletonList(new TermsValuesSourceBuilder("longA"))); + OpenSearchAggregationResponseParser responseParser = + new CompositeAggregationParser(new SinglePercentileParser("PERCENTILE(intA, 50)")); + requestBuilder.pushDownAggregation(Pair.of(List.of(aggBuilder), responseParser)); + + assertEquals( + new SearchSourceBuilder() + .from(DEFAULT_OFFSET) + .size(0) + .timeout(DEFAULT_QUERY_TIMEOUT) + .aggregation(aggBuilder), + requestBuilder.getSourceBuilder()); + verify(exprValueFactory).setParser(responseParser); + } + @Test void test_push_down_query_and_sort() { QueryBuilder query = QueryBuilders.termQuery("intA", 1); diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/response/AggregationResponseUtils.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/response/AggregationResponseUtils.java index 76148b9395..ccdfdce7a4 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/response/AggregationResponseUtils.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/response/AggregationResponseUtils.java @@ -41,8 +41,10 @@ import org.opensearch.search.aggregations.metrics.ParsedMax; import org.opensearch.search.aggregations.metrics.ParsedMin; import org.opensearch.search.aggregations.metrics.ParsedSum; +import org.opensearch.search.aggregations.metrics.ParsedTDigestPercentiles; import org.opensearch.search.aggregations.metrics.ParsedTopHits; import org.opensearch.search.aggregations.metrics.ParsedValueCount; +import org.opensearch.search.aggregations.metrics.PercentilesAggregationBuilder; import org.opensearch.search.aggregations.metrics.SumAggregationBuilder; import org.opensearch.search.aggregations.metrics.TopHitsAggregationBuilder; import org.opensearch.search.aggregations.metrics.ValueCountAggregationBuilder; @@ -56,6 +58,9 @@ public class AggregationResponseUtils { .put(MaxAggregationBuilder.NAME, (p, c) -> ParsedMax.fromXContent(p, (String) c)) .put(SumAggregationBuilder.NAME, (p, c) -> ParsedSum.fromXContent(p, (String) c)) .put(AvgAggregationBuilder.NAME, (p, c) -> ParsedAvg.fromXContent(p, (String) c)) + .put( + PercentilesAggregationBuilder.NAME, + (p, c) -> ParsedTDigestPercentiles.fromXContent(p, (String) c)) .put( ExtendedStatsAggregationBuilder.NAME, (p, c) -> ParsedExtendedStats.fromXContent(p, (String) c)) diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/response/OpenSearchAggregationResponseParserTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/response/OpenSearchAggregationResponseParserTest.java index 1a15e57c55..9ae76f8843 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/response/OpenSearchAggregationResponseParserTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/response/OpenSearchAggregationResponseParserTest.java @@ -26,6 +26,8 @@ import org.opensearch.sql.opensearch.response.agg.FilterParser; import org.opensearch.sql.opensearch.response.agg.NoBucketAggregationParser; import org.opensearch.sql.opensearch.response.agg.OpenSearchAggregationResponseParser; +import org.opensearch.sql.opensearch.response.agg.PercentilesParser; +import org.opensearch.sql.opensearch.response.agg.SinglePercentileParser; import org.opensearch.sql.opensearch.response.agg.SingleValueParser; import org.opensearch.sql.opensearch.response.agg.StatsParser; import org.opensearch.sql.opensearch.response.agg.TopHitsParser; @@ -309,6 +311,291 @@ void top_hits_aggregation_should_pass() { contains(ImmutableMap.of("type", "take", "take", ImmutableList.of("m", "f")))); } + /** SELECT PERCENTILE(age, 50) FROM accounts. */ + @Test + void no_bucket_one_metric_percentile_should_pass() { + String response = + "{\n" + + " \"percentiles#percentile\": {\n" + + " \"values\": {\n" + + " \"50.0\": 35.0\n" + + " }\n" + + " }\n" + + " }"; + NoBucketAggregationParser parser = + new NoBucketAggregationParser(new SinglePercentileParser("percentile")); + assertThat(parse(parser, response), contains(entry("percentile", 35.0))); + } + + /** SELECT PERCENTILE(age, 50), MAX(age) FROM accounts. */ + @Test + void no_bucket_two_metric_percentile_should_pass() { + String response = + "{\n" + + " \"percentiles#percentile\": {\n" + + " \"values\": {\n" + + " \"50.0\": 35.0\n" + + " }\n" + + " },\n" + + " \"max#max\": {\n" + + " \"value\": 40\n" + + " }\n" + + " }"; + NoBucketAggregationParser parser = + new NoBucketAggregationParser( + new SinglePercentileParser("percentile"), new SingleValueParser("max")); + assertThat(parse(parser, response), contains(entry("percentile", 35.0, "max", 40.0))); + } + + /** SELECT PERCENTILE(age, 50) FROM accounts GROUP BY type. */ + @Test + void one_bucket_one_metric_percentile_should_pass() { + String response = + "{\n" + + " \"composite#composite_buckets\": {\n" + + " \"after_key\": {\n" + + " \"type\": \"sale\"\n" + + " },\n" + + " \"buckets\": [\n" + + " {\n" + + " \"key\": {\n" + + " \"type\": \"cost\"\n" + + " },\n" + + " \"doc_count\": 2,\n" + + " \"percentiles#percentile\": {\n" + + " \"values\": {\n" + + " \"50.0\": 40.0\n" + + " }\n" + + " }\n" + + " },\n" + + " {\n" + + " \"key\": {\n" + + " \"type\": \"sale\"\n" + + " },\n" + + " \"doc_count\": 2,\n" + + " \"percentiles#percentile\": {\n" + + " \"values\": {\n" + + " \"50.0\": 100.0\n" + + " }\n" + + " }\n" + + " }\n" + + " ]\n" + + " }\n" + + "}"; + + OpenSearchAggregationResponseParser parser = + new CompositeAggregationParser(new SinglePercentileParser("percentile")); + assertThat( + parse(parser, response), + containsInAnyOrder( + ImmutableMap.of("type", "cost", "percentile", 40d), + ImmutableMap.of("type", "sale", "percentile", 100d))); + } + + /** SELECT PERCENTILE(age, 50) FROM accounts GROUP BY type, region. */ + @Test + void two_bucket_one_metric_percentile_should_pass() { + String response = + "{\n" + + " \"composite#composite_buckets\": {\n" + + " \"after_key\": {\n" + + " \"type\": \"sale\",\n" + + " \"region\": \"us\"\n" + + " },\n" + + " \"buckets\": [\n" + + " {\n" + + " \"key\": {\n" + + " \"type\": \"cost\",\n" + + " \"region\": \"us\"\n" + + " },\n" + + " \"doc_count\": 2,\n" + + " \"percentiles#percentile\": {\n" + + " \"values\": {\n" + + " \"50.0\": 40.0\n" + + " }\n" + + " }\n" + + " },\n" + + " {\n" + + " \"key\": {\n" + + " \"type\": \"sale\",\n" + + " \"region\": \"uk\"\n" + + " },\n" + + " \"doc_count\": 2,\n" + + " \"percentiles#percentile\": {\n" + + " \"values\": {\n" + + " \"50.0\": 100.0\n" + + " }\n" + + " }\n" + + " }\n" + + " ]\n" + + " }\n" + + "}"; + + OpenSearchAggregationResponseParser parser = + new CompositeAggregationParser( + new SinglePercentileParser("percentile"), new SingleValueParser("max")); + assertThat( + parse(parser, response), + containsInAnyOrder( + ImmutableMap.of("type", "cost", "region", "us", "percentile", 40d), + ImmutableMap.of("type", "sale", "region", "uk", "percentile", 100d))); + } + + /** SELECT PERCENTILES(age) FROM accounts. */ + @Test + void no_bucket_percentiles_should_pass() { + String response = + "{\n" + + " \"percentiles#percentiles\": {\n" + + " \"values\": {\n" + + " \"1.0\": 21.0,\n" + + " \"5.0\": 27.0,\n" + + " \"25.0\": 30.0,\n" + + " \"50.0\": 35.0,\n" + + " \"75.0\": 55.0,\n" + + " \"95.0\": 58.0,\n" + + " \"99.0\": 60.0\n" + + " }\n" + + " }\n" + + " }"; + NoBucketAggregationParser parser = + new NoBucketAggregationParser(new PercentilesParser("percentiles")); + assertThat( + parse(parser, response), + contains(entry("percentiles", List.of(21.0, 27.0, 30.0, 35.0, 55.0, 58.0, 60.0)))); + } + + /** SELECT PERCENTILES(age) FROM accounts GROUP BY type. */ + @Test + void one_bucket_percentiles_should_pass() { + String response = + "{\n" + + " \"composite#composite_buckets\": {\n" + + " \"after_key\": {\n" + + " \"type\": \"sale\"\n" + + " },\n" + + " \"buckets\": [\n" + + " {\n" + + " \"key\": {\n" + + " \"type\": \"cost\"\n" + + " },\n" + + " \"doc_count\": 2,\n" + + " \"percentiles#percentiles\": {\n" + + " \"values\": {\n" + + " \"1.0\": 21.0,\n" + + " \"5.0\": 27.0,\n" + + " \"25.0\": 30.0,\n" + + " \"50.0\": 35.0,\n" + + " \"75.0\": 55.0,\n" + + " \"95.0\": 58.0,\n" + + " \"99.0\": 60.0\n" + + " }\n" + + " }\n" + + " },\n" + + " {\n" + + " \"key\": {\n" + + " \"type\": \"sale\"\n" + + " },\n" + + " \"doc_count\": 2,\n" + + " \"percentiles#percentiles\": {\n" + + " \"values\": {\n" + + " \"1.0\": 21.0,\n" + + " \"5.0\": 27.0,\n" + + " \"25.0\": 30.0,\n" + + " \"50.0\": 35.0,\n" + + " \"75.0\": 55.0,\n" + + " \"95.0\": 58.0,\n" + + " \"99.0\": 60.0\n" + + " }\n" + + " }\n" + + " }\n" + + " ]\n" + + " }\n" + + "}"; + + OpenSearchAggregationResponseParser parser = + new CompositeAggregationParser(new PercentilesParser("percentiles")); + assertThat( + parse(parser, response), + containsInAnyOrder( + ImmutableMap.of( + "type", "cost", "percentiles", List.of(21.0, 27.0, 30.0, 35.0, 55.0, 58.0, 60.0)), + ImmutableMap.of( + "type", "sale", "percentiles", List.of(21.0, 27.0, 30.0, 35.0, 55.0, 58.0, 60.0)))); + } + + /** SELECT PERCENTILES(age) FROM accounts GROUP BY type, region. */ + @Test + void two_bucket_percentiles_should_pass() { + String response = + "{\n" + + " \"composite#composite_buckets\": {\n" + + " \"after_key\": {\n" + + " \"type\": \"sale\",\n" + + " \"region\": \"us\"\n" + + " },\n" + + " \"buckets\": [\n" + + " {\n" + + " \"key\": {\n" + + " \"type\": \"cost\",\n" + + " \"region\": \"us\"\n" + + " },\n" + + " \"doc_count\": 2,\n" + + " \"percentiles#percentiles\": {\n" + + " \"values\": {\n" + + " \"1.0\": 21.0,\n" + + " \"5.0\": 27.0,\n" + + " \"25.0\": 30.0,\n" + + " \"50.0\": 35.0,\n" + + " \"75.0\": 55.0,\n" + + " \"95.0\": 58.0,\n" + + " \"99.0\": 60.0\n" + + " }\n" + + " }\n" + + " },\n" + + " {\n" + + " \"key\": {\n" + + " \"type\": \"sale\",\n" + + " \"region\": \"uk\"\n" + + " },\n" + + " \"doc_count\": 2,\n" + + " \"percentiles#percentiles\": {\n" + + " \"values\": {\n" + + " \"1.0\": 21.0,\n" + + " \"5.0\": 27.0,\n" + + " \"25.0\": 30.0,\n" + + " \"50.0\": 35.0,\n" + + " \"75.0\": 55.0,\n" + + " \"95.0\": 58.0,\n" + + " \"99.0\": 60.0\n" + + " }\n" + + " }\n" + + " }\n" + + " ]\n" + + " }\n" + + "}"; + + OpenSearchAggregationResponseParser parser = + new CompositeAggregationParser(new PercentilesParser("percentiles")); + assertThat( + parse(parser, response), + containsInAnyOrder( + ImmutableMap.of( + "type", + "cost", + "region", + "us", + "percentiles", + List.of(21.0, 27.0, 30.0, 35.0, 55.0, 58.0, 60.0)), + ImmutableMap.of( + "type", + "sale", + "region", + "uk", + "percentiles", + List.of(21.0, 27.0, 30.0, 35.0, 55.0, 58.0, 60.0)))); + } + public List<Map<String, Object>> parse(OpenSearchAggregationResponseParser parser, String json) { return parser.parse(fromJson(json)); } diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilderTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilderTest.java index 7f302c9c53..6d792dec25 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilderTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/script/aggregation/dsl/MetricAggregationBuilderTest.java @@ -10,6 +10,7 @@ import static org.mockito.Mockito.when; import static org.opensearch.sql.common.utils.StringUtils.format; import static org.opensearch.sql.data.type.ExprCoreType.ARRAY; +import static org.opensearch.sql.data.type.ExprCoreType.DOUBLE; import static org.opensearch.sql.data.type.ExprCoreType.INTEGER; import static org.opensearch.sql.data.type.ExprCoreType.STRING; import static org.opensearch.sql.expression.DSL.literal; @@ -39,6 +40,7 @@ import org.opensearch.sql.expression.aggregation.MaxAggregator; import org.opensearch.sql.expression.aggregation.MinAggregator; import org.opensearch.sql.expression.aggregation.NamedAggregator; +import org.opensearch.sql.expression.aggregation.PercentileApproximateAggregator; import org.opensearch.sql.expression.aggregation.SumAggregator; import org.opensearch.sql.expression.aggregation.TakeAggregator; import org.opensearch.sql.expression.function.FunctionName; @@ -215,6 +217,94 @@ void should_build_varSamp_aggregation() { varianceSample(Arrays.asList(ref("age", INTEGER)), INTEGER))))); } + @Test + void should_build_percentile_aggregation() { + assertEquals( + format( + "{%n" + + " \"percentile(age, 50)\" : {%n" + + " \"percentiles\" : {%n" + + " \"field\" : \"age\",%n" + + " \"percents\" : [ 50.0 ],%n" + + " \"keyed\" : true,%n" + + " \"tdigest\" : {%n" + + " \"compression\" : 100.0%n" + + " }%n" + + " }%n" + + " }%n" + + "}"), + buildQuery( + Arrays.asList( + named( + "percentile(age, 50)", + new PercentileApproximateAggregator( + Arrays.asList(ref("age", INTEGER), literal(50)), DOUBLE))))); + } + + @Test + void should_build_percentile_with_compression_aggregation() { + assertEquals( + format( + "{%n" + + " \"percentile(age, 50)\" : {%n" + + " \"percentiles\" : {%n" + + " \"field\" : \"age\",%n" + + " \"percents\" : [ 50.0 ],%n" + + " \"keyed\" : true,%n" + + " \"tdigest\" : {%n" + + " \"compression\" : 0.1%n" + + " }%n" + + " }%n" + + " }%n" + + "}"), + buildQuery( + Arrays.asList( + named( + "percentile(age, 50)", + new PercentileApproximateAggregator( + Arrays.asList(ref("age", INTEGER), literal(50), literal(0.1)), DOUBLE))))); + } + + @Test + void should_build_filtered_percentile_aggregation() { + assertEquals( + format( + "{%n" + + " \"percentile(age, 50)\" : {%n" + + " \"filter\" : {%n" + + " \"range\" : {%n" + + " \"age\" : {%n" + + " \"from\" : 30,%n" + + " \"to\" : null,%n" + + " \"include_lower\" : false,%n" + + " \"include_upper\" : true,%n" + + " \"boost\" : 1.0%n" + + " }%n" + + " }%n" + + " },%n" + + " \"aggregations\" : {%n" + + " \"percentile(age, 50)\" : {%n" + + " \"percentiles\" : {%n" + + " \"field\" : \"age\",%n" + + " \"percents\" : [ 50.0 ],%n" + + " \"keyed\" : true,%n" + + " \"tdigest\" : {%n" + + " \"compression\" : 100.0%n" + + " }%n" + + " }%n" + + " }%n" + + " }%n" + + " }%n" + + "}"), + buildQuery( + Arrays.asList( + named( + "percentile(age, 50)", + new PercentileApproximateAggregator( + Arrays.asList(ref("age", INTEGER), literal(50)), DOUBLE) + .condition(DSL.greater(ref("age", INTEGER), literal(30))))))); + } + @Test void should_build_stddevPop_aggregation() { assertEquals( diff --git a/ppl/src/main/antlr/OpenSearchPPLLexer.g4 b/ppl/src/main/antlr/OpenSearchPPLLexer.g4 index e74aed30eb..9f707c13cd 100644 --- a/ppl/src/main/antlr/OpenSearchPPLLexer.g4 +++ b/ppl/src/main/antlr/OpenSearchPPLLexer.g4 @@ -188,6 +188,7 @@ VAR_POP: 'VAR_POP'; STDDEV_SAMP: 'STDDEV_SAMP'; STDDEV_POP: 'STDDEV_POP'; PERCENTILE: 'PERCENTILE'; +PERCENTILE_APPROX: 'PERCENTILE_APPROX'; TAKE: 'TAKE'; FIRST: 'FIRST'; LAST: 'LAST'; diff --git a/ppl/src/main/antlr/OpenSearchPPLParser.g4 b/ppl/src/main/antlr/OpenSearchPPLParser.g4 index 21cfc65aa1..5a9c179d1a 100644 --- a/ppl/src/main/antlr/OpenSearchPPLParser.g4 +++ b/ppl/src/main/antlr/OpenSearchPPLParser.g4 @@ -216,8 +216,8 @@ statsFunction : statsFunctionName LT_PRTHS valueExpression RT_PRTHS # statsFunctionCall | COUNT LT_PRTHS RT_PRTHS # countAllFunctionCall | (DISTINCT_COUNT | DC) LT_PRTHS valueExpression RT_PRTHS # distinctCountFunctionCall - | percentileAggFunction # percentileAggFunctionCall | takeAggFunction # takeAggFunctionCall + | percentileApproxFunction # percentileApproxFunctionCall ; statsFunctionName @@ -230,16 +230,23 @@ statsFunctionName | VAR_POP | STDDEV_SAMP | STDDEV_POP + | PERCENTILE ; takeAggFunction : TAKE LT_PRTHS fieldExpression (COMMA size = integerLiteral)? RT_PRTHS ; -percentileAggFunction - : PERCENTILE LESS value = integerLiteral GREATER LT_PRTHS aggField = fieldExpression RT_PRTHS +percentileApproxFunction + : (PERCENTILE | PERCENTILE_APPROX) LT_PRTHS aggField = valueExpression + COMMA percent = numericLiteral (COMMA compression = numericLiteral)? RT_PRTHS ; +numericLiteral + : integerLiteral + | decimalLiteral + ; + // expressions expression : logicalExpression diff --git a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java index 690e45d67c..47db10c99b 100644 --- a/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java +++ b/ppl/src/main/java/org/opensearch/sql/ppl/parser/AstExpressionBuilder.java @@ -33,7 +33,6 @@ import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.LogicalXorContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.MultiFieldRelevanceFunctionContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.ParentheticValueExprContext; -import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.PercentileAggFunctionContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.SingleFieldRelevanceFunctionContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.SortFieldContext; import static org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser.SpanClauseContext; @@ -45,7 +44,6 @@ import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import java.util.Arrays; -import java.util.Collections; import java.util.List; import java.util.Map; import java.util.stream.Collectors; @@ -53,30 +51,7 @@ import org.antlr.v4.runtime.ParserRuleContext; import org.antlr.v4.runtime.RuleContext; import org.opensearch.sql.ast.dsl.AstDSL; -import org.opensearch.sql.ast.expression.AggregateFunction; -import org.opensearch.sql.ast.expression.Alias; -import org.opensearch.sql.ast.expression.AllFields; -import org.opensearch.sql.ast.expression.And; -import org.opensearch.sql.ast.expression.Argument; -import org.opensearch.sql.ast.expression.Cast; -import org.opensearch.sql.ast.expression.Compare; -import org.opensearch.sql.ast.expression.DataType; -import org.opensearch.sql.ast.expression.Field; -import org.opensearch.sql.ast.expression.Function; -import org.opensearch.sql.ast.expression.In; -import org.opensearch.sql.ast.expression.Interval; -import org.opensearch.sql.ast.expression.IntervalUnit; -import org.opensearch.sql.ast.expression.Let; -import org.opensearch.sql.ast.expression.Literal; -import org.opensearch.sql.ast.expression.Not; -import org.opensearch.sql.ast.expression.Or; -import org.opensearch.sql.ast.expression.QualifiedName; -import org.opensearch.sql.ast.expression.RelevanceFieldList; -import org.opensearch.sql.ast.expression.Span; -import org.opensearch.sql.ast.expression.SpanUnit; -import org.opensearch.sql.ast.expression.UnresolvedArgument; -import org.opensearch.sql.ast.expression.UnresolvedExpression; -import org.opensearch.sql.ast.expression.Xor; +import org.opensearch.sql.ast.expression.*; import org.opensearch.sql.common.utils.StringUtils; import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParser; import org.opensearch.sql.ppl.antlr.parser.OpenSearchPPLParserBaseVisitor; @@ -183,11 +158,16 @@ public UnresolvedExpression visitDistinctCountFunctionCall(DistinctCountFunction } @Override - public UnresolvedExpression visitPercentileAggFunction(PercentileAggFunctionContext ctx) { + public UnresolvedExpression visitPercentileApproxFunctionCall( + OpenSearchPPLParser.PercentileApproxFunctionCallContext ctx) { + ImmutableList.Builder<UnresolvedExpression> builder = ImmutableList.builder(); + builder.add(new UnresolvedArgument("percent", visit(ctx.percentileApproxFunction().percent))); + if (ctx.percentileApproxFunction().compression != null) { + builder.add( + new UnresolvedArgument("compression", visit(ctx.percentileApproxFunction().compression))); + } return new AggregateFunction( - ctx.PERCENTILE().getText(), - visit(ctx.aggField), - Collections.singletonList(new Argument("rank", (Literal) visit(ctx.value)))); + "percentile", visit(ctx.percentileApproxFunction().aggField), builder.build()); } @Override diff --git a/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstExpressionBuilderTest.java b/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstExpressionBuilderTest.java index aa25a6fcc6..7bcb87d193 100644 --- a/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstExpressionBuilderTest.java +++ b/ppl/src/test/java/org/opensearch/sql/ppl/parser/AstExpressionBuilderTest.java @@ -333,13 +333,40 @@ public void testStdDevPAggregationShouldPass() { @Test public void testPercentileAggFuncExpr() { assertEqual( - "source=t | stats percentile<1>(a)", + "source=t | stats percentile(a, 1)", agg( relation("t"), exprList( alias( - "percentile<1>(a)", - aggregate("percentile", field("a"), argument("rank", intLiteral(1))))), + "percentile(a, 1)", + aggregate("percentile", field("a"), unresolvedArg("percent", intLiteral(1))))), + emptyList(), + emptyList(), + defaultStatsArgs())); + assertEqual( + "source=t | stats percentile(a, 1.0)", + agg( + relation("t"), + exprList( + alias( + "percentile(a, 1.0)", + aggregate( + "percentile", field("a"), unresolvedArg("percent", doubleLiteral(1D))))), + emptyList(), + emptyList(), + defaultStatsArgs())); + assertEqual( + "source=t | stats percentile(a, 1.0, 100)", + agg( + relation("t"), + exprList( + alias( + "percentile(a, 1.0, 100)", + aggregate( + "percentile", + field("a"), + unresolvedArg("percent", doubleLiteral(1D)), + unresolvedArg("compression", intLiteral(100))))), emptyList(), emptyList(), defaultStatsArgs())); @@ -569,7 +596,8 @@ public void canBuildQuery_stringRelevanceFunctionWithArguments() { @Test public void functionNameCanBeUsedAsIdentifier() { assertFunctionNameCouldBeId( - "AVG | COUNT | SUM | MIN | MAX | VAR_SAMP | VAR_POP | STDDEV_SAMP | STDDEV_POP"); + "AVG | COUNT | SUM | MIN | MAX | VAR_SAMP | VAR_POP | STDDEV_SAMP | STDDEV_POP |" + + " PERCENTILE"); assertFunctionNameCouldBeId( "CURRENT_DATE | CURRENT_TIME | CURRENT_TIMESTAMP | LOCALTIME | LOCALTIMESTAMP | " + "UTC_TIMESTAMP | UTC_DATE | UTC_TIME | CURDATE | CURTIME | NOW"); diff --git a/sql/src/main/antlr/OpenSearchSQLLexer.g4 b/sql/src/main/antlr/OpenSearchSQLLexer.g4 index b65f60e289..ba7c5be85a 100644 --- a/sql/src/main/antlr/OpenSearchSQLLexer.g4 +++ b/sql/src/main/antlr/OpenSearchSQLLexer.g4 @@ -322,6 +322,8 @@ MULTI_MATCH: 'MULTI_MATCH'; MULTIMATCHQUERY: 'MULTIMATCHQUERY'; NESTED: 'NESTED'; PERCENTILES: 'PERCENTILES'; +PERCENTILE: 'PERCENTILE'; +PERCENTILE_APPROX: 'PERCENTILE_APPROX'; REGEXP_QUERY: 'REGEXP_QUERY'; REVERSE_NESTED: 'REVERSE_NESTED'; QUERY: 'QUERY'; diff --git a/sql/src/main/antlr/OpenSearchSQLParser.g4 b/sql/src/main/antlr/OpenSearchSQLParser.g4 index c16bc9805e..5f7361160b 100644 --- a/sql/src/main/antlr/OpenSearchSQLParser.g4 +++ b/sql/src/main/antlr/OpenSearchSQLParser.g4 @@ -185,6 +185,11 @@ decimalLiteral | TWO_DECIMAL ; +numericLiteral + : decimalLiteral + | realLiteral + ; + stringLiteral : STRING_LITERAL | DOUBLE_QUOTE_ID @@ -470,6 +475,12 @@ aggregateFunction : functionName = aggregationFunctionName LR_BRACKET functionArg RR_BRACKET # regularAggregateFunctionCall | COUNT LR_BRACKET STAR RR_BRACKET # countStarFunctionCall | COUNT LR_BRACKET DISTINCT functionArg RR_BRACKET # distinctCountFunctionCall + | percentileApproxFunction # percentileApproxFunctionCall + ; + +percentileApproxFunction + : (PERCENTILE | PERCENTILE_APPROX) LR_BRACKET aggField = functionArg + COMMA percent = numericLiteral (COMMA compression = numericLiteral)? RR_BRACKET ; filterClause @@ -752,8 +763,7 @@ relevanceFieldAndWeight ; relevanceFieldWeight - : realLiteral - | decimalLiteral + : numericLiteral ; relevanceField diff --git a/sql/src/main/java/org/opensearch/sql/sql/parser/AstExpressionBuilder.java b/sql/src/main/java/org/opensearch/sql/sql/parser/AstExpressionBuilder.java index 6dd1e02a1d..d1c0be98b2 100644 --- a/sql/src/main/java/org/opensearch/sql/sql/parser/AstExpressionBuilder.java +++ b/sql/src/main/java/org/opensearch/sql/sql/parser/AstExpressionBuilder.java @@ -78,30 +78,11 @@ import org.apache.commons.lang3.tuple.ImmutablePair; import org.apache.commons.lang3.tuple.Pair; import org.opensearch.sql.ast.dsl.AstDSL; -import org.opensearch.sql.ast.expression.AggregateFunction; -import org.opensearch.sql.ast.expression.AllFields; -import org.opensearch.sql.ast.expression.And; -import org.opensearch.sql.ast.expression.Case; -import org.opensearch.sql.ast.expression.Cast; -import org.opensearch.sql.ast.expression.DataType; -import org.opensearch.sql.ast.expression.Function; -import org.opensearch.sql.ast.expression.HighlightFunction; -import org.opensearch.sql.ast.expression.Interval; -import org.opensearch.sql.ast.expression.IntervalUnit; -import org.opensearch.sql.ast.expression.Literal; -import org.opensearch.sql.ast.expression.NestedAllTupleFields; -import org.opensearch.sql.ast.expression.Not; -import org.opensearch.sql.ast.expression.Or; -import org.opensearch.sql.ast.expression.QualifiedName; -import org.opensearch.sql.ast.expression.RelevanceFieldList; -import org.opensearch.sql.ast.expression.ScoreFunction; -import org.opensearch.sql.ast.expression.UnresolvedArgument; -import org.opensearch.sql.ast.expression.UnresolvedExpression; -import org.opensearch.sql.ast.expression.When; -import org.opensearch.sql.ast.expression.WindowFunction; +import org.opensearch.sql.ast.expression.*; import org.opensearch.sql.ast.tree.Sort.SortOption; import org.opensearch.sql.common.utils.StringUtils; import org.opensearch.sql.expression.function.BuiltinFunctionName; +import org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser; import org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.AlternateMultiMatchQueryContext; import org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.AndExpressionContext; import org.opensearch.sql.sql.antlr.parser.OpenSearchSQLParser.ColumnNameContext; @@ -411,6 +392,26 @@ public UnresolvedExpression visitConvertedDataType(ConvertedDataTypeContext ctx) return AstDSL.stringLiteral(ctx.getText()); } + @Override + public UnresolvedExpression visitPercentileApproxFunctionCall( + OpenSearchSQLParser.PercentileApproxFunctionCallContext ctx) { + ImmutableList.Builder<UnresolvedExpression> builder = ImmutableList.builder(); + builder.add( + new UnresolvedArgument( + "percent", + AstDSL.doubleLiteral( + Double.valueOf(ctx.percentileApproxFunction().percent.getText())))); + if (ctx.percentileApproxFunction().compression != null) { + builder.add( + new UnresolvedArgument( + "compression", + AstDSL.doubleLiteral( + Double.valueOf(ctx.percentileApproxFunction().compression.getText())))); + } + return new AggregateFunction( + "percentile", visit(ctx.percentileApproxFunction().aggField), builder.build()); + } + @Override public UnresolvedExpression visitNoFieldRelevanceFunction(NoFieldRelevanceFunctionContext ctx) { return new Function( diff --git a/sql/src/test/java/org/opensearch/sql/sql/parser/AstExpressionBuilderTest.java b/sql/src/test/java/org/opensearch/sql/sql/parser/AstExpressionBuilderTest.java index f2e7fdb2d8..e89f2af9b0 100644 --- a/sql/src/test/java/org/opensearch/sql/sql/parser/AstExpressionBuilderTest.java +++ b/sql/src/test/java/org/opensearch/sql/sql/parser/AstExpressionBuilderTest.java @@ -408,6 +408,26 @@ public void filteredDistinctCount() { buildExprAst("count(distinct name) filter(where age > 30)")); } + @Test + public void canBuildPercentile() { + Object expected = + aggregate("percentile", qualifiedName("age"), unresolvedArg("percent", doubleLiteral(50D))); + assertEquals(expected, buildExprAst("percentile(age, 50)")); + assertEquals(expected, buildExprAst("percentile(age, 50.0)")); + } + + @Test + public void canBuildPercentileWithCompression() { + Object expected = + aggregate( + "percentile", + qualifiedName("age"), + unresolvedArg("percent", doubleLiteral(50D)), + unresolvedArg("compression", doubleLiteral(100D))); + assertEquals(expected, buildExprAst("percentile(age, 50, 100)")); + assertEquals(expected, buildExprAst("percentile(age, 50.0, 100.0)")); + } + @Test public void matchPhraseQueryAllParameters() { assertEquals(