From f79c62157dda38e176d8f2795ca820d060b47984 Mon Sep 17 00:00:00 2001 From: Pablo Machado Date: Wed, 31 Jul 2024 12:08:28 +0200 Subject: [PATCH] ESQL: Add `MV_PSERIES_WEIGHTED_SUM` for score calculations used by security solution (#109017) * Create MV_RIEMANN_ZETA scalar multivalue function --------- Co-authored-by: Nik Everett --- docs/changelog/109017.yaml | 6 + .../mv_pseries_weighted_sum.asciidoc | 5 + .../examples/mv_pseries_weighted_sum.asciidoc | 13 ++ .../definition/mv_pseries_weighted_sum.json | 29 +++ .../kibana/docs/mv_pseries_weighted_sum.md | 12 ++ .../layout/mv_pseries_weighted_sum.asciidoc | 15 ++ .../esql/functions/mv-functions.asciidoc | 2 + .../mv_pseries_weighted_sum.asciidoc | 9 + .../signature/mv_pseries_weighted_sum.svg | 1 + .../types/mv_pseries_weighted_sum.asciidoc | 9 + .../xpack/esql/CsvTestsDataLoader.java | 2 + .../src/main/resources/alerts.csv | 11 ++ .../src/main/resources/mapping-alerts.json | 10 + .../src/main/resources/meta.csv-spec | 6 +- .../mv_pseries_weighted_sum.csv-spec | 89 +++++++++ .../MvPSeriesWeightedSumDoubleEvaluator.java | 105 +++++++++++ .../xpack/esql/action/EsqlCapabilities.java | 9 + .../function/EsqlFunctionRegistry.java | 3 + .../AbstractMultivalueFunction.java | 1 + .../multivalue/MvPSeriesWeightedSum.java | 174 ++++++++++++++++++ .../function/scalar/package-info.java | 10 + .../expression/function/TestCaseSupplier.java | 10 +- ...vPSeriesWeightedSumSerializationTests.java | 39 ++++ .../multivalue/MvPSeriesWeightedSumTests.java | 71 +++++++ 24 files changed, 636 insertions(+), 5 deletions(-) create mode 100644 docs/changelog/109017.yaml create mode 100644 docs/reference/esql/functions/description/mv_pseries_weighted_sum.asciidoc create mode 100644 docs/reference/esql/functions/examples/mv_pseries_weighted_sum.asciidoc create mode 100644 docs/reference/esql/functions/kibana/definition/mv_pseries_weighted_sum.json create mode 100644 docs/reference/esql/functions/kibana/docs/mv_pseries_weighted_sum.md create mode 100644 docs/reference/esql/functions/layout/mv_pseries_weighted_sum.asciidoc create mode 100644 docs/reference/esql/functions/parameters/mv_pseries_weighted_sum.asciidoc create mode 100644 docs/reference/esql/functions/signature/mv_pseries_weighted_sum.svg create mode 100644 docs/reference/esql/functions/types/mv_pseries_weighted_sum.asciidoc create mode 100644 x-pack/plugin/esql/qa/testFixtures/src/main/resources/alerts.csv create mode 100644 x-pack/plugin/esql/qa/testFixtures/src/main/resources/mapping-alerts.json create mode 100644 x-pack/plugin/esql/qa/testFixtures/src/main/resources/mv_pseries_weighted_sum.csv-spec create mode 100644 x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvPSeriesWeightedSumDoubleEvaluator.java create mode 100644 x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvPSeriesWeightedSum.java create mode 100644 x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvPSeriesWeightedSumSerializationTests.java create mode 100644 x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvPSeriesWeightedSumTests.java diff --git a/docs/changelog/109017.yaml b/docs/changelog/109017.yaml new file mode 100644 index 0000000000000..80bcdd6fc0e25 --- /dev/null +++ b/docs/changelog/109017.yaml @@ -0,0 +1,6 @@ +pr: 109017 +summary: "ESQL: Add `MV_PSERIES_WEIGHTED_SUM` for score calculations used by security\ + \ solution" +area: ES|QL +type: "feature" +issues: [ ] diff --git a/docs/reference/esql/functions/description/mv_pseries_weighted_sum.asciidoc b/docs/reference/esql/functions/description/mv_pseries_weighted_sum.asciidoc new file mode 100644 index 0000000000000..d464689f40a01 --- /dev/null +++ b/docs/reference/esql/functions/description/mv_pseries_weighted_sum.asciidoc @@ -0,0 +1,5 @@ +// This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it. + +*Description* + +Converts a multivalued expression into a single-valued column by multiplying every element on the input list by its corresponding term in P-Series and computing the sum. diff --git a/docs/reference/esql/functions/examples/mv_pseries_weighted_sum.asciidoc b/docs/reference/esql/functions/examples/mv_pseries_weighted_sum.asciidoc new file mode 100644 index 0000000000000..bce4deb1f5225 --- /dev/null +++ b/docs/reference/esql/functions/examples/mv_pseries_weighted_sum.asciidoc @@ -0,0 +1,13 @@ +// This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it. + +*Example* + +[source.merge.styled,esql] +---- +include::{esql-specs}/mv_pseries_weighted_sum.csv-spec[tag=example] +---- +[%header.monospaced.styled,format=dsv,separator=|] +|=== +include::{esql-specs}/mv_pseries_weighted_sum.csv-spec[tag=example-result] +|=== + diff --git a/docs/reference/esql/functions/kibana/definition/mv_pseries_weighted_sum.json b/docs/reference/esql/functions/kibana/definition/mv_pseries_weighted_sum.json new file mode 100644 index 0000000000000..626f7befbb12e --- /dev/null +++ b/docs/reference/esql/functions/kibana/definition/mv_pseries_weighted_sum.json @@ -0,0 +1,29 @@ +{ + "comment" : "This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it.", + "type" : "eval", + "name" : "mv_pseries_weighted_sum", + "description" : "Converts a multivalued expression into a single-valued column by multiplying every element on the input list by its corresponding term in P-Series and computing the sum.", + "signatures" : [ + { + "params" : [ + { + "name" : "number", + "type" : "double", + "optional" : false, + "description" : "Multivalue expression." + }, + { + "name" : "p", + "type" : "double", + "optional" : false, + "description" : "It is a constant number that represents the 'p' parameter in the P-Series. It impacts every element's contribution to the weighted sum." + } + ], + "variadic" : false, + "returnType" : "double" + } + ], + "examples" : [ + "ROW a = [70.0, 45.0, 21.0, 21.0, 21.0]\n| EVAL sum = MV_PSERIES_WEIGHTED_SUM(a, 1.5)\n| KEEP sum" + ] +} diff --git a/docs/reference/esql/functions/kibana/docs/mv_pseries_weighted_sum.md b/docs/reference/esql/functions/kibana/docs/mv_pseries_weighted_sum.md new file mode 100644 index 0000000000000..fbeb310449b9b --- /dev/null +++ b/docs/reference/esql/functions/kibana/docs/mv_pseries_weighted_sum.md @@ -0,0 +1,12 @@ + + +### MV_PSERIES_WEIGHTED_SUM +Converts a multivalued expression into a single-valued column by multiplying every element on the input list by its corresponding term in P-Series and computing the sum. + +``` +ROW a = [70.0, 45.0, 21.0, 21.0, 21.0] +| EVAL sum = MV_PSERIES_WEIGHTED_SUM(a, 1.5) +| KEEP sum +``` diff --git a/docs/reference/esql/functions/layout/mv_pseries_weighted_sum.asciidoc b/docs/reference/esql/functions/layout/mv_pseries_weighted_sum.asciidoc new file mode 100644 index 0000000000000..7c14ecbc3c935 --- /dev/null +++ b/docs/reference/esql/functions/layout/mv_pseries_weighted_sum.asciidoc @@ -0,0 +1,15 @@ +// This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it. + +[discrete] +[[esql-mv_pseries_weighted_sum]] +=== `MV_PSERIES_WEIGHTED_SUM` + +*Syntax* + +[.text-center] +image::esql/functions/signature/mv_pseries_weighted_sum.svg[Embedded,opts=inline] + +include::../parameters/mv_pseries_weighted_sum.asciidoc[] +include::../description/mv_pseries_weighted_sum.asciidoc[] +include::../types/mv_pseries_weighted_sum.asciidoc[] +include::../examples/mv_pseries_weighted_sum.asciidoc[] diff --git a/docs/reference/esql/functions/mv-functions.asciidoc b/docs/reference/esql/functions/mv-functions.asciidoc index 0f4f6233d446c..bd5f14cdd3557 100644 --- a/docs/reference/esql/functions/mv-functions.asciidoc +++ b/docs/reference/esql/functions/mv-functions.asciidoc @@ -18,6 +18,7 @@ * <> * <> * <> +* <> * <> * <> * <> @@ -34,6 +35,7 @@ include::layout/mv_last.asciidoc[] include::layout/mv_max.asciidoc[] include::layout/mv_median.asciidoc[] include::layout/mv_min.asciidoc[] +include::layout/mv_pseries_weighted_sum.asciidoc[] include::layout/mv_slice.asciidoc[] include::layout/mv_sort.asciidoc[] include::layout/mv_sum.asciidoc[] diff --git a/docs/reference/esql/functions/parameters/mv_pseries_weighted_sum.asciidoc b/docs/reference/esql/functions/parameters/mv_pseries_weighted_sum.asciidoc new file mode 100644 index 0000000000000..3a828f1464824 --- /dev/null +++ b/docs/reference/esql/functions/parameters/mv_pseries_weighted_sum.asciidoc @@ -0,0 +1,9 @@ +// This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it. + +*Parameters* + +`number`:: +Multivalue expression. + +`p`:: +It is a constant number that represents the 'p' parameter in the P-Series. It impacts every element's contribution to the weighted sum. diff --git a/docs/reference/esql/functions/signature/mv_pseries_weighted_sum.svg b/docs/reference/esql/functions/signature/mv_pseries_weighted_sum.svg new file mode 100644 index 0000000000000..7e3b42161e52c --- /dev/null +++ b/docs/reference/esql/functions/signature/mv_pseries_weighted_sum.svg @@ -0,0 +1 @@ +MV_PSERIES_WEIGHTED_SUM(number,p) \ No newline at end of file diff --git a/docs/reference/esql/functions/types/mv_pseries_weighted_sum.asciidoc b/docs/reference/esql/functions/types/mv_pseries_weighted_sum.asciidoc new file mode 100644 index 0000000000000..f28e61f17aa33 --- /dev/null +++ b/docs/reference/esql/functions/types/mv_pseries_weighted_sum.asciidoc @@ -0,0 +1,9 @@ +// This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it. + +*Supported types* + +[%header.monospaced.styled,format=dsv,separator=|] +|=== +number | p | result +double | double | double +|=== diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/CsvTestsDataLoader.java b/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/CsvTestsDataLoader.java index f9b768d67d574..628c321425f72 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/CsvTestsDataLoader.java +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/CsvTestsDataLoader.java @@ -55,6 +55,7 @@ public class CsvTestsDataLoader { private static final TestsDataset HOSTS = new TestsDataset("hosts", "mapping-hosts.json", "hosts.csv"); private static final TestsDataset APPS = new TestsDataset("apps", "mapping-apps.json", "apps.csv"); private static final TestsDataset LANGUAGES = new TestsDataset("languages", "mapping-languages.json", "languages.csv"); + private static final TestsDataset ALERTS = new TestsDataset("alerts", "mapping-alerts.json", "alerts.csv"); private static final TestsDataset UL_LOGS = new TestsDataset("ul_logs", "mapping-ul_logs.json", "ul_logs.csv"); private static final TestsDataset SAMPLE_DATA = new TestsDataset("sample_data", "mapping-sample_data.json", "sample_data.csv"); private static final TestsDataset SAMPLE_DATA_STR = new TestsDataset( @@ -106,6 +107,7 @@ public class CsvTestsDataLoader { Map.entry(LANGUAGES.indexName, LANGUAGES), Map.entry(UL_LOGS.indexName, UL_LOGS), Map.entry(SAMPLE_DATA.indexName, SAMPLE_DATA), + Map.entry(ALERTS.indexName, ALERTS), Map.entry(SAMPLE_DATA_STR.indexName, SAMPLE_DATA_STR), Map.entry(SAMPLE_DATA_TS_LONG.indexName, SAMPLE_DATA_TS_LONG), Map.entry(CLIENT_IPS.indexName, CLIENT_IPS), diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/alerts.csv b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/alerts.csv new file mode 100644 index 0000000000000..dbf13fd89176c --- /dev/null +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/alerts.csv @@ -0,0 +1,11 @@ +host.name:keyword,kibana.alert.risk_score:double +test-host-1,21.0 +test-host-2,17.0 +test-host-2,23.0 +test-host-1,45.0 +test-host-2,12.0 +test-host-2,16.0 +test-host-1,21.0 +test-host-1,70.0 +test-host-1,21.0 +test-host-2,5.0 diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/mapping-alerts.json b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/mapping-alerts.json new file mode 100644 index 0000000000000..a5ceac91abe50 --- /dev/null +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/mapping-alerts.json @@ -0,0 +1,10 @@ +{ + "properties": { + "host.name": { + "type": "keyword" + }, + "kibana.alert.risk_score": { + "type": "double" + } + } +} diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/meta.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/meta.csv-spec index 6912b77be4c58..e2ad918fbb5f3 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/meta.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/meta.csv-spec @@ -53,6 +53,7 @@ double e() "boolean|date|double|integer|ip|keyword|long|text|unsigned_long|version mv_max(field:boolean|date|double|integer|ip|keyword|long|text|unsigned_long|version)" "double|integer|long|unsigned_long mv_median(number:double|integer|long|unsigned_long)" "boolean|date|double|integer|ip|keyword|long|text|unsigned_long|version mv_min(field:boolean|date|double|integer|ip|keyword|long|text|unsigned_long|version)" +"double mv_pseries_weighted_sum(number:double, p:double)" "boolean|cartesian_point|cartesian_shape|date|double|geo_point|geo_shape|integer|ip|keyword|long|text|version mv_slice(field:boolean|cartesian_point|cartesian_shape|date|double|geo_point|geo_shape|integer|ip|keyword|long|text|version, start:integer, ?end:integer)" "boolean|date|double|integer|ip|keyword|long|text|version mv_sort(field:boolean|date|double|integer|ip|keyword|long|text|version, ?order:keyword)" "double|integer|long|unsigned_long mv_sum(number:double|integer|long|unsigned_long)" @@ -174,6 +175,7 @@ mv_last |field |"boolean|cartesian_point|car mv_max |field |"boolean|date|double|integer|ip|keyword|long|text|unsigned_long|version" |Multivalue expression. mv_median |number |"double|integer|long|unsigned_long" |Multivalue expression. mv_min |field |"boolean|date|double|integer|ip|keyword|long|text|unsigned_long|version" |Multivalue expression. +mv_pseries_wei|[number, p] |[double, double] |[Multivalue expression., It is a constant number that represents the 'p' parameter in the P-Series. It impacts every element's contribution to the weighted sum.] mv_slice |[field, start, end] |["boolean|cartesian_point|cartesian_shape|date|double|geo_point|geo_shape|integer|ip|keyword|long|text|version", integer, integer]|[Multivalue expression. If `null`\, the function returns `null`., Start position. If `null`\, the function returns `null`. The start argument can be negative. An index of -1 is used to specify the last value in the list., End position(included). Optional; if omitted\, the position at `start` is returned. The end argument can be negative. An index of -1 is used to specify the last value in the list.] mv_sort |[field, order] |["boolean|date|double|integer|ip|keyword|long|text|version", keyword] |[Multivalue expression. If `null`\, the function returns `null`., Sort order. The valid options are ASC and DESC\, the default is ASC.] mv_sum |number |"double|integer|long|unsigned_long" |Multivalue expression. @@ -296,6 +298,7 @@ mv_last |Converts a multivalue expression into a single valued column cont mv_max |Converts a multivalued expression into a single valued column containing the maximum value. mv_median |Converts a multivalued field into a single valued field containing the median value. mv_min |Converts a multivalued expression into a single valued column containing the minimum value. +mv_pseries_wei|Converts a multivalued expression into a single-valued column by multiplying every element on the input list by its corresponding term in P-Series and computing the sum. mv_slice |Returns a subset of the multivalued field using the start and end index values. mv_sort |Sorts a multivalued field in lexicographical order. mv_sum |Converts a multivalued field into a single valued field containing the sum of all of the values. @@ -419,6 +422,7 @@ mv_last |"boolean|cartesian_point|cartesian_shape|date|double|geo_point|ge mv_max |"boolean|date|double|integer|ip|keyword|long|text|unsigned_long|version" |false |false |false mv_median |"double|integer|long|unsigned_long" |false |false |false mv_min |"boolean|date|double|integer|ip|keyword|long|text|unsigned_long|version" |false |false |false +mv_pseries_wei|"double" |[false, false] |false |false mv_slice |"boolean|cartesian_point|cartesian_shape|date|double|geo_point|geo_shape|integer|ip|keyword|long|text|version" |[false, false, true] |false |false mv_sort |"boolean|date|double|integer|ip|keyword|long|text|version" |[false, true] |false |false mv_sum |"double|integer|long|unsigned_long" |false |false |false @@ -497,5 +501,5 @@ countFunctions#[skip:-8.15.99] meta functions | stats a = count(*), b = count(*), c = count(*) | mv_expand c; a:long | b:long | c:long -113 | 113 | 113 +114 | 114 | 114 ; diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/mv_pseries_weighted_sum.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/mv_pseries_weighted_sum.csv-spec new file mode 100644 index 0000000000000..4d8ffd1136908 --- /dev/null +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/mv_pseries_weighted_sum.csv-spec @@ -0,0 +1,89 @@ +default +required_capability: mv_pseries_weighted_sum + +// tag::example[] +ROW a = [70.0, 45.0, 21.0, 21.0, 21.0] +| EVAL sum = MV_PSERIES_WEIGHTED_SUM(a, 1.5) +| KEEP sum +// end::example[] +; + +// tag::example-result[] +sum:double +94.45465156212452 +// end::example-result[] +; + +oneElement +required_capability: mv_pseries_weighted_sum + +ROW data = [3.0] +| EVAL score = MV_PSERIES_WEIGHTED_SUM(data, 9999.9) +| KEEP score; + +score:double +3.0 +; + +zeroP +required_capability: mv_pseries_weighted_sum + +ROW data = [3.0, 10.0, 15.0] +| EVAL score = MV_PSERIES_WEIGHTED_SUM(data, 0.0) +| KEEP score; + +score:double +28.0 +; + +negativeP +required_capability: mv_pseries_weighted_sum + +ROW data = [10.0, 5.0, 3.0] +| EVAL score = MV_PSERIES_WEIGHTED_SUM(data, -2.0) +| KEEP score; + +score:double +57.0 +; + +composed +required_capability: mv_pseries_weighted_sum + +ROW data = [21.0, 45.0, 21.0, 70.0, 21.0] +| EVAL sorted = MV_SORT(data, "desc") +| EVAL score = MV_PSERIES_WEIGHTED_SUM(sorted, 1.5) +| EVAL normalized_score = ROUND(100 * score / 261.2, 2) +| KEEP normalized_score, score; + +normalized_score:double|score:double +36.16 |94.45465156212452 +; + +multivalueAggregation +required_capability: mv_pseries_weighted_sum + +FROM alerts +| WHERE host.name is not null +| SORT host.name, kibana.alert.risk_score +| STATS score=MV_PSERIES_WEIGHTED_SUM( + TOP(kibana.alert.risk_score, 10000, "desc"), 1.5 +) BY host.name +| EVAL normalized_score = ROUND(100 * score / 261.2, 2) +| KEEP host.name, normalized_score, score; + +host.name:keyword|normalized_score:double|score:double +test-host-1 |36.16 |94.45465156212452 +test-host-2 |13.03 |34.036822671263614 +; + +asArgument +required_capability: mv_pseries_weighted_sum + +ROW data = [70.0, 45.0, 21.0, 21.0, 21.0] +| EVAL score = ROUND(MV_PSERIES_WEIGHTED_SUM(data, 1.5), 1) +| KEEP score; + +score:double +94.5 +; diff --git a/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvPSeriesWeightedSumDoubleEvaluator.java b/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvPSeriesWeightedSumDoubleEvaluator.java new file mode 100644 index 0000000000000..c96599eaf8236 --- /dev/null +++ b/x-pack/plugin/esql/src/main/generated/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvPSeriesWeightedSumDoubleEvaluator.java @@ -0,0 +1,105 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.xpack.esql.expression.function.scalar.multivalue; + +import java.lang.Override; +import java.lang.String; +import java.util.function.Function; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.compute.operator.EvalOperator; +import org.elasticsearch.core.Releasables; +import org.elasticsearch.search.aggregations.metrics.CompensatedSum; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.expression.function.Warnings; + +/** + * {@link EvalOperator.ExpressionEvaluator} implementation for {@link MvPSeriesWeightedSum}. + * This class is generated. Do not edit it. + */ +public final class MvPSeriesWeightedSumDoubleEvaluator implements EvalOperator.ExpressionEvaluator { + private final Warnings warnings; + + private final EvalOperator.ExpressionEvaluator block; + + private final CompensatedSum sum; + + private final double p; + + private final DriverContext driverContext; + + public MvPSeriesWeightedSumDoubleEvaluator(Source source, EvalOperator.ExpressionEvaluator block, + CompensatedSum sum, double p, DriverContext driverContext) { + this.block = block; + this.sum = sum; + this.p = p; + this.driverContext = driverContext; + this.warnings = Warnings.createWarnings(driverContext.warningsMode(), source); + } + + @Override + public Block eval(Page page) { + try (DoubleBlock blockBlock = (DoubleBlock) block.eval(page)) { + return eval(page.getPositionCount(), blockBlock); + } + } + + public DoubleBlock eval(int positionCount, DoubleBlock blockBlock) { + try(DoubleBlock.Builder result = driverContext.blockFactory().newDoubleBlockBuilder(positionCount)) { + position: for (int p = 0; p < positionCount; p++) { + boolean allBlocksAreNulls = true; + if (!blockBlock.isNull(p)) { + allBlocksAreNulls = false; + } + if (allBlocksAreNulls) { + result.appendNull(); + continue position; + } + MvPSeriesWeightedSum.process(result, p, blockBlock, this.sum, this.p); + } + return result.build(); + } + } + + @Override + public String toString() { + return "MvPSeriesWeightedSumDoubleEvaluator[" + "block=" + block + ", p=" + p + "]"; + } + + @Override + public void close() { + Releasables.closeExpectNoException(block); + } + + static class Factory implements EvalOperator.ExpressionEvaluator.Factory { + private final Source source; + + private final EvalOperator.ExpressionEvaluator.Factory block; + + private final Function sum; + + private final double p; + + public Factory(Source source, EvalOperator.ExpressionEvaluator.Factory block, + Function sum, double p) { + this.source = source; + this.block = block; + this.sum = sum; + this.p = p; + } + + @Override + public MvPSeriesWeightedSumDoubleEvaluator get(DriverContext context) { + return new MvPSeriesWeightedSumDoubleEvaluator(source, block.get(context), sum.apply(context), p, context); + } + + @Override + public String toString() { + return "MvPSeriesWeightedSumDoubleEvaluator[" + "block=" + block + ", p=" + p + "]"; + } + } +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java index 76562bbe6ebf0..6e57794bbd6aa 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java @@ -183,6 +183,15 @@ public enum Cap { */ FIXED_PUSHDOWN_PAST_PROJECT, + /** + * Adds the {@code MV_PSERIES_WEIGHTED_SUM} function for converting sorted lists of numbers into + * a bounded score. This is a generalization of the + * riemann zeta function but we + * don't name it that because we don't support complex numbers and don't want to make folks think + * of mystical number theory things. This is just a weighted sum that is adjacent to magic. + */ + MV_PSERIES_WEIGHTED_SUM, + /** * Support for match operator */ diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java index de7e63e16d53e..5a5c407c86306 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/EsqlFunctionRegistry.java @@ -95,6 +95,7 @@ import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvMax; import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvMedian; import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvMin; +import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvPSeriesWeightedSum; import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvSlice; import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvSort; import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvSum; @@ -360,11 +361,13 @@ private FunctionDefinition[][] functions() { def(MvMax.class, MvMax::new, "mv_max"), def(MvMedian.class, MvMedian::new, "mv_median"), def(MvMin.class, MvMin::new, "mv_min"), + def(MvPSeriesWeightedSum.class, MvPSeriesWeightedSum::new, "mv_pseries_weighted_sum"), def(MvSort.class, MvSort::new, "mv_sort"), def(MvSlice.class, MvSlice::new, "mv_slice"), def(MvZip.class, MvZip::new, "mv_zip"), def(MvSum.class, MvSum::new, "mv_sum"), def(Split.class, Split::new, "split") } }; + } private static FunctionDefinition[][] snapshotFunctions() { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/AbstractMultivalueFunction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/AbstractMultivalueFunction.java index cffb208940aa5..90810d282ca52 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/AbstractMultivalueFunction.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/AbstractMultivalueFunction.java @@ -44,6 +44,7 @@ public static List getNamedWriteables() { MvMax.ENTRY, MvMedian.ENTRY, MvMin.ENTRY, + MvPSeriesWeightedSum.ENTRY, MvSlice.ENTRY, MvSort.ENTRY, MvSum.ENTRY, diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvPSeriesWeightedSum.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvPSeriesWeightedSum.java new file mode 100644 index 0000000000000..60eab9fd4ad74 --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvPSeriesWeightedSum.java @@ -0,0 +1,174 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.scalar.multivalue; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.compute.ann.Evaluator; +import org.elasticsearch.compute.ann.Fixed; +import org.elasticsearch.compute.data.DoubleBlock; +import org.elasticsearch.compute.operator.EvalOperator; +import org.elasticsearch.compute.operator.EvalOperator.ExpressionEvaluator; +import org.elasticsearch.search.aggregations.metrics.CompensatedSum; +import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.TypeResolutions; +import org.elasticsearch.xpack.esql.core.tree.NodeInfo; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.evaluator.mapper.EvaluatorMapper; +import org.elasticsearch.xpack.esql.expression.function.Example; +import org.elasticsearch.xpack.esql.expression.function.FunctionInfo; +import org.elasticsearch.xpack.esql.expression.function.Param; +import org.elasticsearch.xpack.esql.expression.function.scalar.EsqlScalarFunction; +import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; +import org.elasticsearch.xpack.esql.planner.PlannerUtils; + +import java.io.IOException; +import java.util.Arrays; +import java.util.List; +import java.util.function.Function; + +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.FIRST; +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.SECOND; +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.isNotNullAndFoldable; +import static org.elasticsearch.xpack.esql.core.type.DataType.DOUBLE; + +/** + * Reduce a multivalued field to a single valued field containing the weighted sum of all element applying the P series function. + */ +public class MvPSeriesWeightedSum extends EsqlScalarFunction implements EvaluatorMapper { + public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry( + Expression.class, + "MvPSeriesWeightedSum", + MvPSeriesWeightedSum::new + ); + + private final Expression field, p; + + @FunctionInfo( + returnType = { "double" }, + + description = "Converts a multivalued expression into a single-valued column by multiplying every " + + "element on the input list by its corresponding term in P-Series and computing the sum.", + examples = @Example(file = "mv_pseries_weighted_sum", tag = "example") + ) + public MvPSeriesWeightedSum( + Source source, + @Param(name = "number", type = { "double" }, description = "Multivalue expression.") Expression field, + @Param( + name = "p", + type = { "double" }, + description = "It is a constant number that represents the 'p' parameter in the P-Series. " + + "It impacts every element's contribution to the weighted sum." + ) Expression p + ) { + super(source, Arrays.asList(field, p)); + this.field = field; + this.p = p; + } + + private MvPSeriesWeightedSum(StreamInput in) throws IOException { + this(Source.readFrom((PlanStreamInput) in), in.readNamedWriteable(Expression.class), in.readNamedWriteable(Expression.class)); + } + + @Override + protected TypeResolution resolveType() { + if (childrenResolved() == false) { + return new TypeResolution("Unresolved children"); + } + + TypeResolution resolution = TypeResolutions.isType(field, dt -> dt == DOUBLE, sourceText(), FIRST, "double"); + if (resolution.unresolved()) { + return resolution; + } + + resolution = TypeResolutions.isType(p, dt -> dt == DOUBLE, sourceText(), SECOND, "double") + .and(isNotNullAndFoldable(p, sourceText(), SECOND)); + + if (resolution.unresolved()) { + return resolution; + } + + return resolution; + } + + @Override + public boolean foldable() { + return field.foldable() && p.foldable(); + } + + @Override + public EvalOperator.ExpressionEvaluator.Factory toEvaluator(Function toEvaluator) { + return switch (PlannerUtils.toElementType(field.dataType())) { + case DOUBLE -> new MvPSeriesWeightedSumDoubleEvaluator.Factory( + source(), + toEvaluator.apply(field), + ctx -> new CompensatedSum(), + (Double) p.fold() + ); + case NULL -> EvalOperator.CONSTANT_NULL_FACTORY; + default -> throw EsqlIllegalArgumentException.illegalDataType(field.dataType()); + }; + } + + @Override + public Expression replaceChildren(List newChildren) { + return new MvPSeriesWeightedSum(source(), newChildren.get(0), newChildren.get(1)); + } + + @Override + protected NodeInfo info() { + return NodeInfo.create(this, MvPSeriesWeightedSum::new, field, p); + } + + @Override + public DataType dataType() { + return field.dataType(); + } + + @Evaluator(extraName = "Double") + static void process( + DoubleBlock.Builder builder, + int position, + DoubleBlock block, + @Fixed(includeInToString = false, build = true) CompensatedSum sum, + @Fixed double p + ) { + sum.reset(0, 0); + int start = block.getFirstValueIndex(position); + int end = block.getValueCount(position) + start; + + for (int i = start; i < end; i++) { + double current_score = block.getDouble(i) / Math.pow(i - start + 1, p); + sum.add(current_score); + } + builder.appendDouble(sum.value()); + } + + @Override + public String getWriteableName() { + return ENTRY.name; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + source().writeTo(out); + out.writeNamedWriteable(field); + out.writeNamedWriteable(p); + } + + Expression field() { + return field; + } + + Expression p() { + return p; + } +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/package-info.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/package-info.java index 4f9219247d5c2..b4781c7e41f98 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/package-info.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/package-info.java @@ -184,6 +184,16 @@ * looks ok. * *
  • + * Let's finish up the code by making the tests backwards compatible. Since this is a new + * feature we just have to convince the tests not to run in a cluster that includes older + * versions of Elasticsearch. We do that with a {@link org.elasticsearch.rest.RestHandler#supportedCapabilities capability} + * on the REST handler. ESQL has a ton of capabilities so we list them + * all in {@link org.elasticsearch.xpack.esql.action.EsqlCapabilities}. Add a new one + * for your function. Now add something like {@code required_capability: my_function} + * to all of your csv-spec tests. Run those csv-spec tests as integration tests to double + * check that they run on the main branch. + *
  • + *
  • * Open the PR. The subject and description of the PR are important because those'll turn * into the commit message we see in the commit history. Good PR descriptions make me very * happy. But functions don't need an essay. diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/TestCaseSupplier.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/TestCaseSupplier.java index cd375b8c53595..3585e58bf97ab 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/TestCaseSupplier.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/TestCaseSupplier.java @@ -1305,11 +1305,11 @@ public static class TestCase { private final Class foldingExceptionClass; private final String foldingExceptionMessage; - public TestCase(List data, String evaluatorToString, DataType expectedType, Matcher matcher) { + public TestCase(List data, String evaluatorToString, DataType expectedType, Matcher matcher) { this(data, equalTo(evaluatorToString), expectedType, matcher); } - public TestCase(List data, Matcher evaluatorToString, DataType expectedType, Matcher matcher) { + public TestCase(List data, Matcher evaluatorToString, DataType expectedType, Matcher matcher) { this(data, evaluatorToString, expectedType, matcher, null, null, null, null); } @@ -1321,7 +1321,7 @@ public static TestCase typeError(List data, String expectedTypeError) List data, Matcher evaluatorToString, DataType expectedType, - Matcher matcher, + Matcher matcher, String[] expectedWarnings, String expectedTypeError, Class foldingExceptionClass, @@ -1331,7 +1331,9 @@ public static TestCase typeError(List data, String expectedTypeError) this.data = data; this.evaluatorToString = evaluatorToString; this.expectedType = expectedType; - this.matcher = matcher; + @SuppressWarnings("unchecked") + Matcher downcast = (Matcher) matcher; + this.matcher = downcast; this.expectedWarnings = expectedWarnings; this.expectedTypeError = expectedTypeError; this.canBuildEvaluator = data.stream().allMatch(d -> d.forceLiteral || DataType.isRepresentable(d.type)); diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvPSeriesWeightedSumSerializationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvPSeriesWeightedSumSerializationTests.java new file mode 100644 index 0000000000000..dcb525c79b272 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvPSeriesWeightedSumSerializationTests.java @@ -0,0 +1,39 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.scalar.multivalue; + +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.expression.AbstractExpressionSerializationTests; + +import java.io.IOException; + +public class MvPSeriesWeightedSumSerializationTests extends AbstractExpressionSerializationTests { + @Override + protected MvPSeriesWeightedSum createTestInstance() { + Source source = randomSource(); + Expression field = randomChild(); + Expression p = randomChild(); + + return new MvPSeriesWeightedSum(source, field, p); + } + + @Override + protected MvPSeriesWeightedSum mutateInstance(MvPSeriesWeightedSum instance) throws IOException { + Source source = instance.source(); + Expression field = instance.field(); + Expression p = instance.p(); + + switch (between(0, 1)) { + case 0 -> field = randomValueOtherThan(field, AbstractExpressionSerializationTests::randomChild); + case 1 -> p = randomValueOtherThan(p, AbstractExpressionSerializationTests::randomChild); + + } + return new MvPSeriesWeightedSum(source, field, p); + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvPSeriesWeightedSumTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvPSeriesWeightedSumTests.java new file mode 100644 index 0000000000000..d7a2b530007ad --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvPSeriesWeightedSumTests.java @@ -0,0 +1,71 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.scalar.multivalue; + +import com.carrotsearch.randomizedtesting.annotations.Name; +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; + +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.expression.function.AbstractScalarFunctionTestCase; +import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier; + +import java.util.ArrayList; +import java.util.List; +import java.util.function.Supplier; + +import static org.hamcrest.Matchers.closeTo; + +public class MvPSeriesWeightedSumTests extends AbstractScalarFunctionTestCase { + public MvPSeriesWeightedSumTests(@Name("TestCase") Supplier testCaseSupplier) { + this.testCase = testCaseSupplier.get(); + } + + @ParametersFactory + public static Iterable parameters() { + List cases = new ArrayList<>(); + + doubles(cases); + + // TODO use parameterSuppliersFromTypedDataWithDefaultChecks instead of parameterSuppliersFromTypedData and fix errors + return parameterSuppliersFromTypedData(cases); + } + + @Override + protected Expression build(Source source, List args) { + return new MvPSeriesWeightedSum(source, args.get(0), args.get(1)); + } + + private static void doubles(List cases) { + + cases.add(new TestCaseSupplier(List.of(DataType.DOUBLE, DataType.DOUBLE), () -> { + List field = randomList(1, 10, () -> randomDouble()); + double p = randomDoubleBetween(-100.0, 100.0, true); + + return new TestCaseSupplier.TestCase( + List.of( + new TestCaseSupplier.TypedData(field, DataType.DOUBLE, "field"), + new TestCaseSupplier.TypedData(p, DataType.DOUBLE, "p").forceLiteral() + ), + "MvPSeriesWeightedSumDoubleEvaluator[block=Attribute[channel=0], p=" + p + "]", + DataType.DOUBLE, + closeTo(calcPSeriesWeightedSum(field, p), 0.00000001) + ); + })); + } + + private static double calcPSeriesWeightedSum(List field, double p) { + double sum = 0; + for (int i = 0; i < field.size(); i++) { + double current = field.get(i) / Math.pow(i + 1, p); + sum += current; + } + return sum; + } +}