diff --git a/.buildkite/pipelines/dra-workflow.yml b/.buildkite/pipelines/dra-workflow.yml index 32a2b7d22134a..bcc6c9c57d756 100644 --- a/.buildkite/pipelines/dra-workflow.yml +++ b/.buildkite/pipelines/dra-workflow.yml @@ -7,7 +7,7 @@ steps: image: family/elasticsearch-ubuntu-2204 machineType: custom-32-98304 buildDirectory: /dev/shm/bk - diskSizeGb: 250 + diskSizeGb: 350 - wait # The hadoop build depends on the ES artifact # So let's trigger the hadoop build any time we build a new staging artifact diff --git a/.buildkite/pipelines/pull-request/bwc-snapshots-windows.yml b/.buildkite/pipelines/pull-request/bwc-snapshots-windows.yml deleted file mode 100644 index d37bdf380f926..0000000000000 --- a/.buildkite/pipelines/pull-request/bwc-snapshots-windows.yml +++ /dev/null @@ -1,20 +0,0 @@ -config: - allow-labels: test-windows -steps: - - group: bwc-snapshots-windows - steps: - - label: "{{matrix.BWC_VERSION}} / bwc-snapshots-windows" - key: "bwc-snapshots-windows" - command: .\.buildkite\scripts\run-script.ps1 bash .buildkite/scripts/windows-run-gradle.sh - env: - GRADLE_TASK: "v{{matrix.BWC_VERSION}}#bwcTest" - timeout_in_minutes: 300 - matrix: - setup: - BWC_VERSION: $SNAPSHOT_BWC_VERSIONS - agents: - provider: gcp - image: family/elasticsearch-windows-2022 - machineType: custom-32-98304 - diskType: pd-ssd - diskSizeGb: 350 diff --git a/benchmarks/build.gradle b/benchmarks/build.gradle index 8753d4a4762b7..49e81a67e85f9 100644 --- a/benchmarks/build.gradle +++ b/benchmarks/build.gradle @@ -47,8 +47,8 @@ dependencies { api "org.openjdk.jmh:jmh-core:$versions.jmh" annotationProcessor "org.openjdk.jmh:jmh-generator-annprocess:$versions.jmh" // Dependencies of JMH - runtimeOnly 'net.sf.jopt-simple:jopt-simple:4.6' - runtimeOnly 'org.apache.commons:commons-math3:3.2' + runtimeOnly 'net.sf.jopt-simple:jopt-simple:5.0.4' + runtimeOnly 'org.apache.commons:commons-math3:3.6.1' } // enable the JMH's BenchmarkProcessor to generate the final benchmark classes diff --git a/benchmarks/src/main/java/org/elasticsearch/benchmark/search/aggregations/TermsReduceBenchmark.java b/benchmarks/src/main/java/org/elasticsearch/benchmark/search/aggregations/TermsReduceBenchmark.java index 230e0c7e546c2..691874c775302 100644 --- a/benchmarks/src/main/java/org/elasticsearch/benchmark/search/aggregations/TermsReduceBenchmark.java +++ b/benchmarks/src/main/java/org/elasticsearch/benchmark/search/aggregations/TermsReduceBenchmark.java @@ -71,7 +71,7 @@ public class TermsReduceBenchmark { private final SearchPhaseController controller = new SearchPhaseController((task, req) -> new AggregationReduceContext.Builder() { @Override public AggregationReduceContext forPartialReduction() { - return new AggregationReduceContext.ForPartial(null, null, task, builder); + return new AggregationReduceContext.ForPartial(null, null, task, builder, b -> {}); } @Override diff --git a/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/toolchain/OracleOpenJdkToolchainResolver.java b/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/toolchain/OracleOpenJdkToolchainResolver.java index d0c7e9316d996..ec86798e653f1 100644 --- a/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/toolchain/OracleOpenJdkToolchainResolver.java +++ b/build-tools-internal/src/main/java/org/elasticsearch/gradle/internal/toolchain/OracleOpenJdkToolchainResolver.java @@ -88,7 +88,7 @@ public String url(String os, String arch, String extension) { List builds = List.of( getBundledJdkBuild(), // 23 early access - new EarlyAccessJdkBuild(JavaLanguageVersion.of(23), "23", "23") + new EarlyAccessJdkBuild(JavaLanguageVersion.of(23), "23", "24") ); private JdkBuild getBundledJdkBuild() { diff --git a/build-tools-internal/version.properties b/build-tools-internal/version.properties index 728f44a365974..1dd9fb95bd17b 100644 --- a/build-tools-internal/version.properties +++ b/build-tools-internal/version.properties @@ -49,7 +49,7 @@ commonsCompress = 1.24.0 reflections = 0.10.2 # benchmark dependencies -jmh = 1.26 +jmh = 1.37 # test dependencies # when updating this version, also update :qa:evil-tests diff --git a/docs/changelog/106520.yaml b/docs/changelog/106520.yaml new file mode 100644 index 0000000000000..c3fe69a4c3dbd --- /dev/null +++ b/docs/changelog/106520.yaml @@ -0,0 +1,6 @@ +pr: 106520 +summary: Updated the transport CA name in Security Auto-Configuration. +area: Security +type: bug +issues: + - 106455 diff --git a/docs/changelog/107047.yaml b/docs/changelog/107047.yaml new file mode 100644 index 0000000000000..89caed6f55074 --- /dev/null +++ b/docs/changelog/107047.yaml @@ -0,0 +1,6 @@ +pr: 107047 +summary: "Search/Mapping: KnnVectorQueryBuilder support for allowUnmappedFields" +area: Search +type: bug +issues: + - 106846 diff --git a/docs/changelog/109501.yaml b/docs/changelog/109501.yaml new file mode 100644 index 0000000000000..6e81f98816cbf --- /dev/null +++ b/docs/changelog/109501.yaml @@ -0,0 +1,14 @@ +pr: 109501 +summary: Reflect latest changes in synthetic source documentation +area: Mapping +type: enhancement +issues: [] +highlight: + title: Synthetic `_source` improvements + body: |- + There are multiple improvements to synthetic `_source` functionality: + + * Synthetic `_source` is now supported for all field types including `nested` and `object`. `object` fields are supported with `enabled` set to `false`. + + * Synthetic `_source` can be enabled together with `ignore_malformed` and `ignore_above` parameters for all field types that support them. + notable: false diff --git a/docs/changelog/109667.yaml b/docs/changelog/109667.yaml new file mode 100644 index 0000000000000..782a1b1cf6c9b --- /dev/null +++ b/docs/changelog/109667.yaml @@ -0,0 +1,5 @@ +pr: 109667 +summary: Inference autoscaling +area: Machine Learning +type: feature +issues: [] diff --git a/docs/changelog/109684.yaml b/docs/changelog/109684.yaml new file mode 100644 index 0000000000000..156f568290cf5 --- /dev/null +++ b/docs/changelog/109684.yaml @@ -0,0 +1,5 @@ +pr: 109684 +summary: Avoid `ModelAssignment` deadlock +area: Machine Learning +type: bug +issues: [] diff --git a/docs/changelog/110021.yaml b/docs/changelog/110021.yaml new file mode 100644 index 0000000000000..51878b960dfd0 --- /dev/null +++ b/docs/changelog/110021.yaml @@ -0,0 +1,6 @@ +pr: 110021 +summary: "[ES|QL] validate `mv_sort` order" +area: ES|QL +type: bug +issues: + - 109910 diff --git a/docs/changelog/110061.yaml b/docs/changelog/110061.yaml new file mode 100644 index 0000000000000..1880a2a197722 --- /dev/null +++ b/docs/changelog/110061.yaml @@ -0,0 +1,6 @@ +pr: 110061 +summary: Avoiding running watch jobs in TickerScheduleTriggerEngine if it is paused +area: Watcher +type: bug +issues: + - 105933 diff --git a/docs/changelog/110399.yaml b/docs/changelog/110399.yaml new file mode 100644 index 0000000000000..9e04e2656809e --- /dev/null +++ b/docs/changelog/110399.yaml @@ -0,0 +1,6 @@ +pr: 110399 +summary: "[Inference API] Prevent inference endpoints from being deleted if they are\ + \ referenced by semantic text" +area: Machine Learning +type: enhancement +issues: [] diff --git a/docs/changelog/110400.yaml b/docs/changelog/110400.yaml new file mode 100644 index 0000000000000..f2810eba214f1 --- /dev/null +++ b/docs/changelog/110400.yaml @@ -0,0 +1,5 @@ +pr: 110400 +summary: Introduce compute listener +area: ES|QL +type: bug +issues: [] diff --git a/docs/changelog/110476.yaml b/docs/changelog/110476.yaml new file mode 100644 index 0000000000000..bc12b3711a366 --- /dev/null +++ b/docs/changelog/110476.yaml @@ -0,0 +1,7 @@ +pr: 110476 +summary: Fix bug in union-types with type-casting in grouping key of STATS +area: ES|QL +type: bug +issues: + - 109922 + - 110477 diff --git a/docs/changelog/110520.yaml b/docs/changelog/110520.yaml new file mode 100644 index 0000000000000..fba4b84e2279e --- /dev/null +++ b/docs/changelog/110520.yaml @@ -0,0 +1,5 @@ +pr: 110520 +summary: Add protection for OOM during aggregations partial reduction +area: Aggregations +type: enhancement +issues: [] diff --git a/docs/changelog/110527.yaml b/docs/changelog/110527.yaml new file mode 100644 index 0000000000000..3ab19ecaaaa76 --- /dev/null +++ b/docs/changelog/110527.yaml @@ -0,0 +1,5 @@ +pr: 110527 +summary: "ESQL: Add boolean support to Max and Min aggs" +area: ES|QL +type: feature +issues: [] diff --git a/docs/changelog/110540.yaml b/docs/changelog/110540.yaml new file mode 100644 index 0000000000000..5e4994da80704 --- /dev/null +++ b/docs/changelog/110540.yaml @@ -0,0 +1,16 @@ +pr: 110540 +summary: Deprecate using slm privileges to access ilm +area: ILM+SLM +type: deprecation +issues: [] +deprecation: + title: Deprecate using slm privileges to access ilm + area: REST API + details: The `read_slm` privilege can get the ILM status, and + the `manage_slm` privilege can start and stop ILM. Access to these + APIs should be granted using the `read_ilm` and `manage_ilm` privileges + instead. Access to ILM APIs will be removed from SLM privileges in + a future major release, and is now deprecated. + impact: Users that need access to the ILM status API should now + use the `read_ilm` privilege. Users that need to start and stop ILM, + should use the `manage_ilm` privilege. diff --git a/docs/changelog/110554.yaml b/docs/changelog/110554.yaml new file mode 100644 index 0000000000000..8c0b896a4c979 --- /dev/null +++ b/docs/changelog/110554.yaml @@ -0,0 +1,5 @@ +pr: 110554 +summary: Fix `MapperBuilderContext#isDataStream` when used in dynamic mappers +area: "Mapping" +type: bug +issues: [] diff --git a/docs/changelog/110574.yaml b/docs/changelog/110574.yaml new file mode 100644 index 0000000000000..1840838500151 --- /dev/null +++ b/docs/changelog/110574.yaml @@ -0,0 +1,6 @@ +pr: 110574 +summary: "ES|QL: better validation for GROK patterns" +area: ES|QL +type: bug +issues: + - 110533 diff --git a/docs/changelog/110586.yaml b/docs/changelog/110586.yaml new file mode 100644 index 0000000000000..cc2bcb85a2dac --- /dev/null +++ b/docs/changelog/110586.yaml @@ -0,0 +1,5 @@ +pr: 110586 +summary: "ESQL: Fix Max doubles bug with negatives and add tests for Max and Min" +area: ES|QL +type: bug +issues: [] diff --git a/docs/changelog/110651.yaml b/docs/changelog/110651.yaml new file mode 100644 index 0000000000000..c25c63ee0284a --- /dev/null +++ b/docs/changelog/110651.yaml @@ -0,0 +1,5 @@ +pr: 110651 +summary: "Remove `default_field: message` from metrics index templates" +area: Data streams +type: enhancement +issues: [] diff --git a/docs/changelog/110666.yaml b/docs/changelog/110666.yaml new file mode 100644 index 0000000000000..d96f8e2024c81 --- /dev/null +++ b/docs/changelog/110666.yaml @@ -0,0 +1,5 @@ +pr: 110666 +summary: Removing the use of Stream::peek from `GeoIpDownloader::cleanDatabases` +area: Ingest Node +type: bug +issues: [] diff --git a/docs/reference/data-streams/tsds.asciidoc b/docs/reference/data-streams/tsds.asciidoc index 460048d8ccbc9..de89fa1ca3f31 100644 --- a/docs/reference/data-streams/tsds.asciidoc +++ b/docs/reference/data-streams/tsds.asciidoc @@ -53,8 +53,9 @@ shard segments by `_tsid` and `@timestamp`. documents, the document `_id` is a hash of the document's dimensions and `@timestamp`. A TSDS doesn't support custom document `_id` values. + * A TSDS uses <>, and as a result is -subject to a number of <>. +subject to some <> and <> applied to the `_source` field. NOTE: A time series index can contain fields other than dimensions or metrics. diff --git a/docs/reference/esql/functions/aggregation-functions.asciidoc b/docs/reference/esql/functions/aggregation-functions.asciidoc index 11fcd576d336e..82931b84fd44a 100644 --- a/docs/reference/esql/functions/aggregation-functions.asciidoc +++ b/docs/reference/esql/functions/aggregation-functions.asciidoc @@ -8,13 +8,13 @@ The <> command supports these aggregate functions: // tag::agg_list[] -* <> +* <> * <> * <> -* <> +* <> * <> * <> -* <> +* <> * <> * experimental:[] <> * <> @@ -23,16 +23,16 @@ The <> command supports these aggregate functions: * experimental:[] <> // end::agg_list[] -include::avg.asciidoc[] include::count.asciidoc[] include::count-distinct.asciidoc[] -include::max.asciidoc[] include::median.asciidoc[] include::median-absolute-deviation.asciidoc[] -include::min.asciidoc[] include::percentile.asciidoc[] include::st_centroid_agg.asciidoc[] include::sum.asciidoc[] +include::layout/avg.asciidoc[] +include::layout/max.asciidoc[] +include::layout/min.asciidoc[] include::layout/top.asciidoc[] include::values.asciidoc[] include::weighted-avg.asciidoc[] diff --git a/docs/reference/esql/functions/avg.asciidoc b/docs/reference/esql/functions/avg.asciidoc deleted file mode 100644 index 7eadff29f1bfc..0000000000000 --- a/docs/reference/esql/functions/avg.asciidoc +++ /dev/null @@ -1,47 +0,0 @@ -[discrete] -[[esql-agg-avg]] -=== `AVG` - -*Syntax* - -[source,esql] ----- -AVG(expression) ----- - -`expression`:: -Numeric expression. -//If `null`, the function returns `null`. -// TODO: Remove comment when https://github.com/elastic/elasticsearch/issues/104900 is fixed. - -*Description* - -The average of a numeric expression. - -*Supported types* - -The result is always a `double` no matter the input type. - -*Examples* - -[source.merge.styled,esql] ----- -include::{esql-specs}/stats.csv-spec[tag=avg] ----- -[%header.monospaced.styled,format=dsv,separator=|] -|=== -include::{esql-specs}/stats.csv-spec[tag=avg-result] -|=== - -The expression can use inline functions. For example, to calculate the average -over a multivalued column, first use `MV_AVG` to average the multiple values per -row, and use the result with the `AVG` function: - -[source.merge.styled,esql] ----- -include::{esql-specs}/stats.csv-spec[tag=docsStatsAvgNestedExpression] ----- -[%header.monospaced.styled,format=dsv,separator=|] -|=== -include::{esql-specs}/stats.csv-spec[tag=docsStatsAvgNestedExpression-result] -|=== diff --git a/docs/reference/esql/functions/description/avg.asciidoc b/docs/reference/esql/functions/description/avg.asciidoc new file mode 100644 index 0000000000000..545d7e8394e8b --- /dev/null +++ b/docs/reference/esql/functions/description/avg.asciidoc @@ -0,0 +1,5 @@ +// This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it. + +*Description* + +The average of a numeric field. diff --git a/docs/reference/esql/functions/description/max.asciidoc b/docs/reference/esql/functions/description/max.asciidoc new file mode 100644 index 0000000000000..27a76ed69c3c0 --- /dev/null +++ b/docs/reference/esql/functions/description/max.asciidoc @@ -0,0 +1,5 @@ +// This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it. + +*Description* + +The maximum value of a field. diff --git a/docs/reference/esql/functions/description/min.asciidoc b/docs/reference/esql/functions/description/min.asciidoc new file mode 100644 index 0000000000000..406125b5761d1 --- /dev/null +++ b/docs/reference/esql/functions/description/min.asciidoc @@ -0,0 +1,5 @@ +// This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it. + +*Description* + +The minimum value of a field. diff --git a/docs/reference/esql/functions/examples/avg.asciidoc b/docs/reference/esql/functions/examples/avg.asciidoc new file mode 100644 index 0000000000000..b6193ad50ed21 --- /dev/null +++ b/docs/reference/esql/functions/examples/avg.asciidoc @@ -0,0 +1,22 @@ +// This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it. + +*Examples* + +[source.merge.styled,esql] +---- +include::{esql-specs}/stats.csv-spec[tag=avg] +---- +[%header.monospaced.styled,format=dsv,separator=|] +|=== +include::{esql-specs}/stats.csv-spec[tag=avg-result] +|=== +The expression can use inline functions. For example, to calculate the average over a multivalued column, first use `MV_AVG` to average the multiple values per row, and use the result with the `AVG` function +[source.merge.styled,esql] +---- +include::{esql-specs}/stats.csv-spec[tag=docsStatsAvgNestedExpression] +---- +[%header.monospaced.styled,format=dsv,separator=|] +|=== +include::{esql-specs}/stats.csv-spec[tag=docsStatsAvgNestedExpression-result] +|=== + diff --git a/docs/reference/esql/functions/max.asciidoc b/docs/reference/esql/functions/examples/max.asciidoc similarity index 55% rename from docs/reference/esql/functions/max.asciidoc rename to docs/reference/esql/functions/examples/max.asciidoc index f2e0d0a0205b3..dc57118931ef7 100644 --- a/docs/reference/esql/functions/max.asciidoc +++ b/docs/reference/esql/functions/examples/max.asciidoc @@ -1,24 +1,6 @@ -[discrete] -[[esql-agg-max]] -=== `MAX` +// This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it. -*Syntax* - -[source,esql] ----- -MAX(expression) ----- - -*Parameters* - -`expression`:: -Expression from which to return the maximum value. - -*Description* - -Returns the maximum value of a numeric expression. - -*Example* +*Examples* [source.merge.styled,esql] ---- @@ -28,11 +10,7 @@ include::{esql-specs}/stats.csv-spec[tag=max] |=== include::{esql-specs}/stats.csv-spec[tag=max-result] |=== - -The expression can use inline functions. For example, to calculate the maximum -over an average of a multivalued column, use `MV_AVG` to first average the -multiple values per row, and use the result with the `MAX` function: - +The expression can use inline functions. For example, to calculate the maximum over an average of a multivalued column, use `MV_AVG` to first average the multiple values per row, and use the result with the `MAX` function [source.merge.styled,esql] ---- include::{esql-specs}/stats.csv-spec[tag=docsStatsMaxNestedExpression] @@ -40,4 +18,5 @@ include::{esql-specs}/stats.csv-spec[tag=docsStatsMaxNestedExpression] [%header.monospaced.styled,format=dsv,separator=|] |=== include::{esql-specs}/stats.csv-spec[tag=docsStatsMaxNestedExpression-result] -|=== \ No newline at end of file +|=== + diff --git a/docs/reference/esql/functions/min.asciidoc b/docs/reference/esql/functions/examples/min.asciidoc similarity index 55% rename from docs/reference/esql/functions/min.asciidoc rename to docs/reference/esql/functions/examples/min.asciidoc index 313822818128c..b4088196d750b 100644 --- a/docs/reference/esql/functions/min.asciidoc +++ b/docs/reference/esql/functions/examples/min.asciidoc @@ -1,24 +1,6 @@ -[discrete] -[[esql-agg-min]] -=== `MIN` +// This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it. -*Syntax* - -[source,esql] ----- -MIN(expression) ----- - -*Parameters* - -`expression`:: -Expression from which to return the minimum value. - -*Description* - -Returns the minimum value of a numeric expression. - -*Example* +*Examples* [source.merge.styled,esql] ---- @@ -28,11 +10,7 @@ include::{esql-specs}/stats.csv-spec[tag=min] |=== include::{esql-specs}/stats.csv-spec[tag=min-result] |=== - -The expression can use inline functions. For example, to calculate the minimum -over an average of a multivalued column, use `MV_AVG` to first average the -multiple values per row, and use the result with the `MIN` function: - +The expression can use inline functions. For example, to calculate the minimum over an average of a multivalued column, use `MV_AVG` to first average the multiple values per row, and use the result with the `MIN` function [source.merge.styled,esql] ---- include::{esql-specs}/stats.csv-spec[tag=docsStatsMinNestedExpression] @@ -41,3 +19,4 @@ include::{esql-specs}/stats.csv-spec[tag=docsStatsMinNestedExpression] |=== include::{esql-specs}/stats.csv-spec[tag=docsStatsMinNestedExpression-result] |=== + diff --git a/docs/reference/esql/functions/kibana/definition/avg.json b/docs/reference/esql/functions/kibana/definition/avg.json new file mode 100644 index 0000000000000..eb0be684a468e --- /dev/null +++ b/docs/reference/esql/functions/kibana/definition/avg.json @@ -0,0 +1,48 @@ +{ + "comment" : "This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it.", + "type" : "agg", + "name" : "avg", + "description" : "The average of a numeric field.", + "signatures" : [ + { + "params" : [ + { + "name" : "number", + "type" : "double", + "optional" : false, + "description" : "" + } + ], + "variadic" : false, + "returnType" : "double" + }, + { + "params" : [ + { + "name" : "number", + "type" : "integer", + "optional" : false, + "description" : "" + } + ], + "variadic" : false, + "returnType" : "double" + }, + { + "params" : [ + { + "name" : "number", + "type" : "long", + "optional" : false, + "description" : "" + } + ], + "variadic" : false, + "returnType" : "double" + } + ], + "examples" : [ + "FROM employees\n| STATS AVG(height)", + "FROM employees\n| STATS avg_salary_change = ROUND(AVG(MV_AVG(salary_change)), 10)" + ] +} diff --git a/docs/reference/esql/functions/kibana/definition/max.json b/docs/reference/esql/functions/kibana/definition/max.json new file mode 100644 index 0000000000000..bc7380bd76dd4 --- /dev/null +++ b/docs/reference/esql/functions/kibana/definition/max.json @@ -0,0 +1,72 @@ +{ + "comment" : "This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it.", + "type" : "agg", + "name" : "max", + "description" : "The maximum value of a field.", + "signatures" : [ + { + "params" : [ + { + "name" : "field", + "type" : "boolean", + "optional" : false, + "description" : "" + } + ], + "variadic" : false, + "returnType" : "boolean" + }, + { + "params" : [ + { + "name" : "field", + "type" : "datetime", + "optional" : false, + "description" : "" + } + ], + "variadic" : false, + "returnType" : "datetime" + }, + { + "params" : [ + { + "name" : "field", + "type" : "double", + "optional" : false, + "description" : "" + } + ], + "variadic" : false, + "returnType" : "double" + }, + { + "params" : [ + { + "name" : "field", + "type" : "integer", + "optional" : false, + "description" : "" + } + ], + "variadic" : false, + "returnType" : "integer" + }, + { + "params" : [ + { + "name" : "field", + "type" : "long", + "optional" : false, + "description" : "" + } + ], + "variadic" : false, + "returnType" : "long" + } + ], + "examples" : [ + "FROM employees\n| STATS MAX(languages)", + "FROM employees\n| STATS max_avg_salary_change = MAX(MV_AVG(salary_change))" + ] +} diff --git a/docs/reference/esql/functions/kibana/definition/min.json b/docs/reference/esql/functions/kibana/definition/min.json new file mode 100644 index 0000000000000..937391bf242ac --- /dev/null +++ b/docs/reference/esql/functions/kibana/definition/min.json @@ -0,0 +1,72 @@ +{ + "comment" : "This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it.", + "type" : "agg", + "name" : "min", + "description" : "The minimum value of a field.", + "signatures" : [ + { + "params" : [ + { + "name" : "field", + "type" : "boolean", + "optional" : false, + "description" : "" + } + ], + "variadic" : false, + "returnType" : "boolean" + }, + { + "params" : [ + { + "name" : "field", + "type" : "datetime", + "optional" : false, + "description" : "" + } + ], + "variadic" : false, + "returnType" : "datetime" + }, + { + "params" : [ + { + "name" : "field", + "type" : "double", + "optional" : false, + "description" : "" + } + ], + "variadic" : false, + "returnType" : "double" + }, + { + "params" : [ + { + "name" : "field", + "type" : "integer", + "optional" : false, + "description" : "" + } + ], + "variadic" : false, + "returnType" : "integer" + }, + { + "params" : [ + { + "name" : "field", + "type" : "long", + "optional" : false, + "description" : "" + } + ], + "variadic" : false, + "returnType" : "long" + } + ], + "examples" : [ + "FROM employees\n| STATS MIN(languages)", + "FROM employees\n| STATS min_avg_salary_change = MIN(MV_AVG(salary_change))" + ] +} diff --git a/docs/reference/esql/functions/kibana/docs/avg.md b/docs/reference/esql/functions/kibana/docs/avg.md new file mode 100644 index 0000000000000..54006a0556175 --- /dev/null +++ b/docs/reference/esql/functions/kibana/docs/avg.md @@ -0,0 +1,11 @@ + + +### AVG +The average of a numeric field. + +``` +FROM employees +| STATS AVG(height) +``` diff --git a/docs/reference/esql/functions/kibana/docs/max.md b/docs/reference/esql/functions/kibana/docs/max.md new file mode 100644 index 0000000000000..80e88885e7f34 --- /dev/null +++ b/docs/reference/esql/functions/kibana/docs/max.md @@ -0,0 +1,11 @@ + + +### MAX +The maximum value of a field. + +``` +FROM employees +| STATS MAX(languages) +``` diff --git a/docs/reference/esql/functions/kibana/docs/min.md b/docs/reference/esql/functions/kibana/docs/min.md new file mode 100644 index 0000000000000..38d13b97fd344 --- /dev/null +++ b/docs/reference/esql/functions/kibana/docs/min.md @@ -0,0 +1,11 @@ + + +### MIN +The minimum value of a field. + +``` +FROM employees +| STATS MIN(languages) +``` diff --git a/docs/reference/esql/functions/layout/avg.asciidoc b/docs/reference/esql/functions/layout/avg.asciidoc new file mode 100644 index 0000000000000..8292af8e75554 --- /dev/null +++ b/docs/reference/esql/functions/layout/avg.asciidoc @@ -0,0 +1,15 @@ +// This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it. + +[discrete] +[[esql-avg]] +=== `AVG` + +*Syntax* + +[.text-center] +image::esql/functions/signature/avg.svg[Embedded,opts=inline] + +include::../parameters/avg.asciidoc[] +include::../description/avg.asciidoc[] +include::../types/avg.asciidoc[] +include::../examples/avg.asciidoc[] diff --git a/docs/reference/esql/functions/layout/max.asciidoc b/docs/reference/esql/functions/layout/max.asciidoc new file mode 100644 index 0000000000000..a4eb3d99c0d02 --- /dev/null +++ b/docs/reference/esql/functions/layout/max.asciidoc @@ -0,0 +1,15 @@ +// This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it. + +[discrete] +[[esql-max]] +=== `MAX` + +*Syntax* + +[.text-center] +image::esql/functions/signature/max.svg[Embedded,opts=inline] + +include::../parameters/max.asciidoc[] +include::../description/max.asciidoc[] +include::../types/max.asciidoc[] +include::../examples/max.asciidoc[] diff --git a/docs/reference/esql/functions/layout/min.asciidoc b/docs/reference/esql/functions/layout/min.asciidoc new file mode 100644 index 0000000000000..60ad2cc21b561 --- /dev/null +++ b/docs/reference/esql/functions/layout/min.asciidoc @@ -0,0 +1,15 @@ +// This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it. + +[discrete] +[[esql-min]] +=== `MIN` + +*Syntax* + +[.text-center] +image::esql/functions/signature/min.svg[Embedded,opts=inline] + +include::../parameters/min.asciidoc[] +include::../description/min.asciidoc[] +include::../types/min.asciidoc[] +include::../examples/min.asciidoc[] diff --git a/docs/reference/esql/functions/parameters/avg.asciidoc b/docs/reference/esql/functions/parameters/avg.asciidoc new file mode 100644 index 0000000000000..91c56709d182a --- /dev/null +++ b/docs/reference/esql/functions/parameters/avg.asciidoc @@ -0,0 +1,6 @@ +// This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it. + +*Parameters* + +`number`:: + diff --git a/docs/reference/esql/functions/parameters/max.asciidoc b/docs/reference/esql/functions/parameters/max.asciidoc new file mode 100644 index 0000000000000..8903aa1a472a3 --- /dev/null +++ b/docs/reference/esql/functions/parameters/max.asciidoc @@ -0,0 +1,6 @@ +// This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it. + +*Parameters* + +`field`:: + diff --git a/docs/reference/esql/functions/parameters/min.asciidoc b/docs/reference/esql/functions/parameters/min.asciidoc new file mode 100644 index 0000000000000..8903aa1a472a3 --- /dev/null +++ b/docs/reference/esql/functions/parameters/min.asciidoc @@ -0,0 +1,6 @@ +// This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it. + +*Parameters* + +`field`:: + diff --git a/docs/reference/esql/functions/signature/avg.svg b/docs/reference/esql/functions/signature/avg.svg new file mode 100644 index 0000000000000..f325358aff960 --- /dev/null +++ b/docs/reference/esql/functions/signature/avg.svg @@ -0,0 +1 @@ +AVG(number) \ No newline at end of file diff --git a/docs/reference/esql/functions/signature/max.svg b/docs/reference/esql/functions/signature/max.svg new file mode 100644 index 0000000000000..dda43dfbfbba2 --- /dev/null +++ b/docs/reference/esql/functions/signature/max.svg @@ -0,0 +1 @@ +MAX(field) \ No newline at end of file diff --git a/docs/reference/esql/functions/signature/min.svg b/docs/reference/esql/functions/signature/min.svg new file mode 100644 index 0000000000000..e654d3027fee8 --- /dev/null +++ b/docs/reference/esql/functions/signature/min.svg @@ -0,0 +1 @@ +MIN(field) \ No newline at end of file diff --git a/docs/reference/esql/functions/types/avg.asciidoc b/docs/reference/esql/functions/types/avg.asciidoc new file mode 100644 index 0000000000000..273dae4af76c2 --- /dev/null +++ b/docs/reference/esql/functions/types/avg.asciidoc @@ -0,0 +1,11 @@ +// This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it. + +*Supported types* + +[%header.monospaced.styled,format=dsv,separator=|] +|=== +number | result +double | double +integer | double +long | double +|=== diff --git a/docs/reference/esql/functions/types/max.asciidoc b/docs/reference/esql/functions/types/max.asciidoc new file mode 100644 index 0000000000000..6515c6bfc48d2 --- /dev/null +++ b/docs/reference/esql/functions/types/max.asciidoc @@ -0,0 +1,13 @@ +// This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it. + +*Supported types* + +[%header.monospaced.styled,format=dsv,separator=|] +|=== +field | result +boolean | boolean +datetime | datetime +double | double +integer | integer +long | long +|=== diff --git a/docs/reference/esql/functions/types/min.asciidoc b/docs/reference/esql/functions/types/min.asciidoc new file mode 100644 index 0000000000000..6515c6bfc48d2 --- /dev/null +++ b/docs/reference/esql/functions/types/min.asciidoc @@ -0,0 +1,13 @@ +// This is generated by ESQL's AbstractFunctionTestCase. Do no edit it. See ../README.md for how to regenerate it. + +*Supported types* + +[%header.monospaced.styled,format=dsv,separator=|] +|=== +field | result +boolean | boolean +datetime | datetime +double | double +integer | integer +long | long +|=== diff --git a/docs/reference/index-modules.asciidoc b/docs/reference/index-modules.asciidoc index 04bebfae2763b..24149afe802a2 100644 --- a/docs/reference/index-modules.asciidoc +++ b/docs/reference/index-modules.asciidoc @@ -81,8 +81,9 @@ breaking change]. If you are updating the compression type, the new one will be applied after segments are merged. Segment merging can be forced using <>. Experiments with indexing log datasets - have shown that `best_compression` gives up to ~18% lower storage usage - compared to `default` while only minimally affecting indexing throughput (~2%). + have shown that `best_compression` gives up to ~18% lower storage usage in + the most ideal scenario compared to `default` while only minimally affecting + indexing throughput (~2%). [[index-mode-setting]] `index.mode`:: + diff --git a/docs/reference/inference/delete-inference.asciidoc b/docs/reference/inference/delete-inference.asciidoc index 2f9d9511e6326..4df72ba672092 100644 --- a/docs/reference/inference/delete-inference.asciidoc +++ b/docs/reference/inference/delete-inference.asciidoc @@ -8,7 +8,7 @@ Deletes an {infer} endpoint. IMPORTANT: The {infer} APIs enable you to use certain services, such as built-in {ml} models (ELSER, E5), models uploaded through Eland, Cohere, OpenAI, Azure, Google AI Studio, Google Vertex AI or -Hugging Face. For built-in models and models uploaded though Eland, the {infer} +Hugging Face. For built-in models and models uploaded through Eland, the {infer} APIs offer an alternative way to use and manage trained models. However, if you do not plan to use the {infer} APIs to use these models or if you want to use non-NLP models, use the <>. diff --git a/docs/reference/inference/get-inference.asciidoc b/docs/reference/inference/get-inference.asciidoc index 7f4dc1c496837..c3fe841603bcc 100644 --- a/docs/reference/inference/get-inference.asciidoc +++ b/docs/reference/inference/get-inference.asciidoc @@ -8,7 +8,7 @@ Retrieves {infer} endpoint information. IMPORTANT: The {infer} APIs enable you to use certain services, such as built-in {ml} models (ELSER, E5), models uploaded through Eland, Cohere, OpenAI, Azure, Google AI Studio, Google Vertex AI or -Hugging Face. For built-in models and models uploaded though Eland, the {infer} +Hugging Face. For built-in models and models uploaded through Eland, the {infer} APIs offer an alternative way to use and manage trained models. However, if you do not plan to use the {infer} APIs to use these models or if you want to use non-NLP models, use the <>. diff --git a/docs/reference/inference/inference-apis.asciidoc b/docs/reference/inference/inference-apis.asciidoc index 896cb02a9e699..02a57504da1cf 100644 --- a/docs/reference/inference/inference-apis.asciidoc +++ b/docs/reference/inference/inference-apis.asciidoc @@ -6,7 +6,7 @@ experimental[] IMPORTANT: The {infer} APIs enable you to use certain services, such as built-in {ml} models (ELSER, E5), models uploaded through Eland, Cohere, OpenAI, Azure, Google AI Studio or -Hugging Face. For built-in models and models uploaded though Eland, the {infer} +Hugging Face. For built-in models and models uploaded through Eland, the {infer} APIs offer an alternative way to use and manage trained models. However, if you do not plan to use the {infer} APIs to use these models or if you want to use non-NLP models, use the <>. diff --git a/docs/reference/inference/post-inference.asciidoc b/docs/reference/inference/post-inference.asciidoc index 3ad23ac3300cc..52131c0b10776 100644 --- a/docs/reference/inference/post-inference.asciidoc +++ b/docs/reference/inference/post-inference.asciidoc @@ -8,7 +8,7 @@ Performs an inference task on an input text by using an {infer} endpoint. IMPORTANT: The {infer} APIs enable you to use certain services, such as built-in {ml} models (ELSER, E5), models uploaded through Eland, Cohere, OpenAI, Azure, Google AI Studio, Google Vertex AI or -Hugging Face. For built-in models and models uploaded though Eland, the {infer} +Hugging Face. For built-in models and models uploaded through Eland, the {infer} APIs offer an alternative way to use and manage trained models. However, if you do not plan to use the {infer} APIs to use these models or if you want to use non-NLP models, use the <>. diff --git a/docs/reference/inference/put-inference.asciidoc b/docs/reference/inference/put-inference.asciidoc index 101c0a24b66b7..656feb54ffe42 100644 --- a/docs/reference/inference/put-inference.asciidoc +++ b/docs/reference/inference/put-inference.asciidoc @@ -8,7 +8,7 @@ Creates an {infer} endpoint to perform an {infer} task. IMPORTANT: The {infer} APIs enable you to use certain services, such as built-in {ml} models (ELSER, E5), models uploaded through Eland, Cohere, OpenAI, Mistral, Azure OpenAI, Google AI Studio, Google Vertex AI or Hugging Face. -For built-in models and models uploaded though Eland, the {infer} APIs offer an alternative way to use and manage trained models. +For built-in models and models uploaded through Eland, the {infer} APIs offer an alternative way to use and manage trained models. However, if you do not plan to use the {infer} APIs to use these models or if you want to use non-NLP models, use the <>. diff --git a/docs/reference/mapping/fields/source-field.asciidoc b/docs/reference/mapping/fields/source-field.asciidoc index ec824e421e015..903b301ab1a96 100644 --- a/docs/reference/mapping/fields/source-field.asciidoc +++ b/docs/reference/mapping/fields/source-field.asciidoc @@ -6,11 +6,11 @@ at index time. The `_source` field itself is not indexed (and thus is not searchable), but it is stored so that it can be returned when executing _fetch_ requests, like <> or <>. -If disk usage is important to you then have a look at -<> which shrinks disk usage at the cost of -only supporting a subset of mappings and slower fetches or (not recommended) -<> which also shrinks disk -usage but disables many features. +If disk usage is important to you, then consider the following options: + +- Using <>, which reconstructs source content at the time of retrieval instead of storing it on disk. This shrinks disk usage, at the cost of slower access to `_source` in <> and <> queries. +- <>. This shrinks disk +usage but disables features that rely on `_source`. include::synthetic-source.asciidoc[] @@ -43,7 +43,7 @@ available then a number of features are not supported: * The <>, <>, and <> APIs. -* In the {kib} link:{kibana-ref}/discover.html[Discover] application, field data will not be displayed. +* In the {kib} link:{kibana-ref}/discover.html[Discover] application, field data will not be displayed. * On the fly <>. diff --git a/docs/reference/mapping/fields/synthetic-source.asciidoc b/docs/reference/mapping/fields/synthetic-source.asciidoc index a0e7aed177a9c..ccea38cf602da 100644 --- a/docs/reference/mapping/fields/synthetic-source.asciidoc +++ b/docs/reference/mapping/fields/synthetic-source.asciidoc @@ -28,45 +28,22 @@ PUT idx While this on the fly reconstruction is *generally* slower than saving the source documents verbatim and loading them at query time, it saves a lot of storage -space. +space. Additional latency can be avoided by not loading `_source` field in queries when it is not needed. + +[[synthetic-source-fields]] +===== Supported fields +Synthetic `_source` is supported by all field types. Depending on implementation details, field types have different properties when used with synthetic `_source`. + +<> construct synthetic `_source` using existing data, most commonly <> and <>. For these field types, no additional space is needed to store the contents of `_source` field. Due to the storage layout of <>, the generated `_source` field undergoes <> compared to original document. + +For all other field types, the original value of the field is stored as is, in the same way as the `_source` field in non-synthetic mode. In this case there are no modifications and field data in `_source` is the same as in the original document. Similarly, malformed values of fields that use <> or <> need to be stored as is. This approach is less storage efficient since data needed for `_source` reconstruction is stored in addition to other data required to index the field (like `doc_values`). [[synthetic-source-restrictions]] ===== Synthetic `_source` restrictions -There are a couple of restrictions to be aware of: +Synthetic `_source` cannot be used together with field mappings that use <>. -* When you retrieve synthetic `_source` content it undergoes minor -<> compared to the original JSON. -* Synthetic `_source` can be used with indices that contain only these field -types: - -** <> -** {plugins}/mapper-annotated-text-usage.html#annotated-text-synthetic-source[`annotated-text`] -** <> -** <> -** <> -** <> -** <> -** <> -** <> -** <> -** <> -** <> -** <> -** <> -** <> -** <> -** <> -** <> -** <> -** <> -** <> -** <> -** <> -** <> -** <> -** <> -** <> +Some field types have additional restrictions. These restrictions are documented in the **synthetic `_source`** section of the field type's <>. [[synthetic-source-modifications]] ===== Synthetic `_source` modifications @@ -178,4 +155,40 @@ that ordering. [[synthetic-source-modifications-ranges]] ====== Representation of ranges -Range field vales (e.g. `long_range`) are always represented as inclusive on both sides with bounds adjusted accordingly. See <>. +Range field values (e.g. `long_range`) are always represented as inclusive on both sides with bounds adjusted accordingly. See <>. + +[[synthetic-source-precision-loss-for-point-types]] +====== Reduced precision of `geo_point` values +Values of `geo_point` fields are represented in synthetic `_source` with reduced precision. See <>. + + +[[synthetic-source-fields-native-list]] +===== Field types that support synthetic source with no storage overhead +The following field types support synthetic source using data from <> or <>, and require no additional storage space to construct the `_source` field. + +NOTE: If you enable the <> or <> settings, then additional storage is required to store ignored field values for these types. + +** <> +** {plugins}/mapper-annotated-text-usage.html#annotated-text-synthetic-source[`annotated-text`] +** <> +** <> +** <> +** <> +** <> +** <> +** <> +** <> +** <> +** <> +** <> +** <> +** <> +** <> +** <> +** <> +** <> +** <> +** <> +** <> +** <> +** <> diff --git a/docs/reference/modules/network.asciidoc b/docs/reference/modules/network.asciidoc index 55c236ce43574..593aa79ded4d9 100644 --- a/docs/reference/modules/network.asciidoc +++ b/docs/reference/modules/network.asciidoc @@ -153,6 +153,8 @@ The only requirements are that each node must be: cluster, and by any remote clusters that will discover it using <>. +Each node must have its own distinct publish address. + If you specify the transport publish address using a hostname then {es} will resolve this hostname to an IP address once during startup, and other nodes will use the resulting IP address instead of resolving the name again diff --git a/docs/reference/release-notes/8.13.0.asciidoc b/docs/reference/release-notes/8.13.0.asciidoc index dba4fdbe5f67e..4bb2913f07be7 100644 --- a/docs/reference/release-notes/8.13.0.asciidoc +++ b/docs/reference/release-notes/8.13.0.asciidoc @@ -21,6 +21,13 @@ This affects clusters running version 8.10 or later, with an active downsampling https://www.elastic.co/guide/en/elasticsearch/reference/current/downsampling-ilm.html[configuration] or a configuration that was activated at some point since upgrading to version 8.10 or later. +* When upgrading clusters from version 8.12.2 or earlier, if your cluster contains non-master-eligible nodes, +information about the new functionality of these upgraded nodes may not be registered properly with the master node. +This can lead to some new functionality added since 8.13.0 not being accessible on the upgraded cluster. +If your cluster is running on ECK 2.12.1 and above, this may cause problems with finalizing the upgrade. +To resolve this issue, perform a rolling restart on the non-master-eligible nodes once all Elasticsearch nodes +are upgraded. + [[breaking-8.13.0]] [float] === Breaking changes diff --git a/docs/reference/release-notes/8.13.1.asciidoc b/docs/reference/release-notes/8.13.1.asciidoc index 7b3dbff74cc6e..572f9fe1172a9 100644 --- a/docs/reference/release-notes/8.13.1.asciidoc +++ b/docs/reference/release-notes/8.13.1.asciidoc @@ -3,6 +3,16 @@ Also see <>. +[[known-issues-8.13.1]] +[float] +=== Known issues +* When upgrading clusters from version 8.12.2 or earlier, if your cluster contains non-master-eligible nodes, +information about the new functionality of these upgraded nodes may not be registered properly with the master node. +This can lead to some new functionality added since 8.13.0 not being accessible on the upgraded cluster. +If your cluster is running on ECK 2.12.1 and above, this may cause problems with finalizing the upgrade. +To resolve this issue, perform a rolling restart on the non-master-eligible nodes once all Elasticsearch nodes +are upgraded. + [[bug-8.13.1]] [float] diff --git a/docs/reference/release-notes/8.13.2.asciidoc b/docs/reference/release-notes/8.13.2.asciidoc index 514118f5ea575..20ae7abbb5769 100644 --- a/docs/reference/release-notes/8.13.2.asciidoc +++ b/docs/reference/release-notes/8.13.2.asciidoc @@ -3,6 +3,16 @@ Also see <>. +[[known-issues-8.13.2]] +[float] +=== Known issues +* When upgrading clusters from version 8.12.2 or earlier, if your cluster contains non-master-eligible nodes, +information about the new functionality of these upgraded nodes may not be registered properly with the master node. +This can lead to some new functionality added since 8.13.0 not being accessible on the upgraded cluster. +If your cluster is running on ECK 2.12.1 and above, this may cause problems with finalizing the upgrade. +To resolve this issue, perform a rolling restart on the non-master-eligible nodes once all Elasticsearch nodes +are upgraded. + [[bug-8.13.2]] [float] diff --git a/docs/reference/release-notes/8.13.3.asciidoc b/docs/reference/release-notes/8.13.3.asciidoc index 9aee0dd815f6d..ea51bd6f9b743 100644 --- a/docs/reference/release-notes/8.13.3.asciidoc +++ b/docs/reference/release-notes/8.13.3.asciidoc @@ -10,6 +10,16 @@ Also see <>. SQL:: * Limit how much space some string functions can use {es-pull}107333[#107333] +[[known-issues-8.13.3]] +[float] +=== Known issues +* When upgrading clusters from version 8.12.2 or earlier, if your cluster contains non-master-eligible nodes, +information about the new functionality of these upgraded nodes may not be registered properly with the master node. +This can lead to some new functionality added since 8.13.0 not being accessible on the upgraded cluster. +If your cluster is running on ECK 2.12.1 and above, this may cause problems with finalizing the upgrade. +To resolve this issue, perform a rolling restart on the non-master-eligible nodes once all Elasticsearch nodes +are upgraded. + [[bug-8.13.3]] [float] === Bug fixes diff --git a/docs/reference/release-notes/8.13.4.asciidoc b/docs/reference/release-notes/8.13.4.asciidoc index bf3f2f497d8fc..b60c9f485bb31 100644 --- a/docs/reference/release-notes/8.13.4.asciidoc +++ b/docs/reference/release-notes/8.13.4.asciidoc @@ -3,6 +3,16 @@ Also see <>. +[[known-issues-8.13.4]] +[float] +=== Known issues +* When upgrading clusters from version 8.12.2 or earlier, if your cluster contains non-master-eligible nodes, +information about the new functionality of these upgraded nodes may not be registered properly with the master node. +This can lead to some new functionality added since 8.13.0 not being accessible on the upgraded cluster. +If your cluster is running on ECK 2.12.1 and above, this may cause problems with finalizing the upgrade. +To resolve this issue, perform a rolling restart on the non-master-eligible nodes once all Elasticsearch nodes +are upgraded. + [[bug-8.13.4]] [float] === Bug fixes diff --git a/docs/reference/release-notes/8.14.0.asciidoc b/docs/reference/release-notes/8.14.0.asciidoc index 42f2f86a123ed..5b92c49ced70a 100644 --- a/docs/reference/release-notes/8.14.0.asciidoc +++ b/docs/reference/release-notes/8.14.0.asciidoc @@ -12,6 +12,16 @@ Security:: * Apply stricter Document Level Security (DLS) rules for the validate query API with the rewrite parameter {es-pull}105709[#105709] * Apply stricter Document Level Security (DLS) rules for terms aggregations when min_doc_count is set to 0 {es-pull}105714[#105714] +[[known-issues-8.14.0]] +[float] +=== Known issues +* When upgrading clusters from version 8.12.2 or earlier, if your cluster contains non-master-eligible nodes, +information about the new functionality of these upgraded nodes may not be registered properly with the master node. +This can lead to some new functionality added since 8.13.0 not being accessible on the upgraded cluster. +If your cluster is running on ECK 2.12.1 and above, this may cause problems with finalizing the upgrade. +To resolve this issue, perform a rolling restart on the non-master-eligible nodes once all Elasticsearch nodes +are upgraded. + [[bug-8.14.0]] [float] === Bug fixes diff --git a/docs/reference/release-notes/8.14.1.asciidoc b/docs/reference/release-notes/8.14.1.asciidoc index f161c7d08099c..1cab442eb9ac1 100644 --- a/docs/reference/release-notes/8.14.1.asciidoc +++ b/docs/reference/release-notes/8.14.1.asciidoc @@ -4,6 +4,16 @@ Also see <>. +[[known-issues-8.14.1]] +[float] +=== Known issues +* When upgrading clusters from version 8.12.2 or earlier, if your cluster contains non-master-eligible nodes, +information about the new functionality of these upgraded nodes may not be registered properly with the master node. +This can lead to some new functionality added since 8.13.0 not being accessible on the upgraded cluster. +If your cluster is running on ECK 2.12.1 and above, this may cause problems with finalizing the upgrade. +To resolve this issue, perform a rolling restart on the non-master-eligible nodes once all Elasticsearch nodes +are upgraded. + [[bug-8.14.1]] [float] === Bug fixes diff --git a/docs/reference/release-notes/8.14.2.asciidoc b/docs/reference/release-notes/8.14.2.asciidoc index 2bb374451b2ac..d94067f030c61 100644 --- a/docs/reference/release-notes/8.14.2.asciidoc +++ b/docs/reference/release-notes/8.14.2.asciidoc @@ -1,10 +1,18 @@ [[release-notes-8.14.2]] == {es} version 8.14.2 -coming[8.14.2] - Also see <>. +[[known-issues-8.14.2]] +[float] +=== Known issues +* When upgrading clusters from version 8.12.2 or earlier, if your cluster contains non-master-eligible nodes, +information about the new functionality of these upgraded nodes may not be registered properly with the master node. +This can lead to some new functionality added since 8.13.0 not being accessible on the upgraded cluster. +If your cluster is running on ECK 2.12.1 and above, this may cause problems with finalizing the upgrade. +To resolve this issue, perform a rolling restart on the non-master-eligible nodes once all Elasticsearch nodes +are upgraded. + [[bug-8.14.2]] [float] === Bug fixes @@ -35,4 +43,4 @@ Ranking:: Search:: * Add hexstring support byte painless scorers {es-pull}109492[#109492] -* Fix automatic tracking of collapse with `docvalue_fields` {es-pull}110103[#110103] \ No newline at end of file +* Fix automatic tracking of collapse with `docvalue_fields` {es-pull}110103[#110103] diff --git a/docs/reference/release-notes/highlights.asciidoc b/docs/reference/release-notes/highlights.asciidoc index e70892ef25928..0ed01ff422700 100644 --- a/docs/reference/release-notes/highlights.asciidoc +++ b/docs/reference/release-notes/highlights.asciidoc @@ -30,13 +30,158 @@ Other versions: endif::[] -// The notable-highlights tag marks entries that -// should be featured in the Stack Installation and Upgrade Guide: // tag::notable-highlights[] -// [discrete] -// === Heading -// -// Description. + +[discrete] +[[stored_fields_are_compressed_with_zstandard_instead_of_lz4_deflate]] +=== Stored fields are now compressed with ZStandard instead of LZ4/DEFLATE +Stored fields are now compressed by splitting documents into blocks, which +are then compressed independently with ZStandard. `index.codec: default` +(default) uses blocks of at most 14kB or 128 documents compressed with level +0, while `index.codec: best_compression` uses blocks of at most 240kB or +2048 documents compressed at level 3. On most datasets that we tested +against, this yielded storage improvements in the order of 10%, slightly +faster indexing and similar retrieval latencies. + +{es-pull}103374[#103374] + +[discrete] +[[stricter_failure_handling_in_multi_repo_get_snapshots_request_handling]] +=== Stricter failure handling in multi-repo get-snapshots request handling +If a multi-repo get-snapshots request encounters a failure in one of the +targeted repositories then earlier versions of Elasticsearch would proceed +as if the faulty repository did not exist, except for a per-repository +failure report in a separate section of the response body. This makes it +impossible to paginate the results properly in the presence of failures. In +versions 8.15.0 and later this API's failure handling behaviour has been +made stricter, reporting an overall failure if any targeted repository's +contents cannot be listed. + +{es-pull}107191[#107191] + +[discrete] +[[add_new_int4_quantization_to_dense_vector]] +=== Add new int4 quantization to dense_vector +New int4 (half-byte) scalar quantization support via two knew index types: `int4_hnsw` and `int4_flat`. +This gives an 8x reduction from `float32` with some accuracy loss. In addition to less memory required, this +improves query and merge speed significantly when compared to raw vectors. + +{es-pull}109317[#109317] + +[discrete] +[[mark_query_rules_as_ga]] +=== Mark Query Rules as GA +This PR marks query rules as Generally Available. All APIs are no longer +in tech preview. + +{es-pull}110004[#110004] + +[discrete] +[[adds_new_bit_element_type_for_dense_vectors]] +=== Adds new `bit` `element_type` for `dense_vectors` +This adds `bit` vector support by adding `element_type: bit` for +vectors. This new element type works for indexed and non-indexed +vectors. Additionally, it works with `hnsw` and `flat` index types. No +quantization based codec works with this element type, this is +consistent with `byte` vectors. + +`bit` vectors accept up to `32768` dimensions in size and expect vectors +that are being indexed to be encoded either as a hexidecimal string or a +`byte[]` array where each element of the `byte` array represents `8` +bits of the vector. + +`bit` vectors support script usage and regular query usage. When +indexed, all comparisons done are `xor` and `popcount` summations (aka, +hamming distance), and the scores are transformed and normalized given +the vector dimensions. + +For scripts, `l1norm` is the same as `hamming` distance and `l2norm` is +`sqrt(l1norm)`. `dotProduct` and `cosineSimilarity` are not supported. + +Note, the dimensions expected by this element_type are always to be +divisible by `8`, and the `byte[]` vectors provided for index must be +have size `dim/8` size, where each byte element represents `8` bits of +the vectors. + +{es-pull}110059[#110059] + +[discrete] +[[redact_processor_generally_available]] +=== The Redact processor is Generally Available +The Redact processor uses the Grok rules engine to obscure text in the input document matching the given Grok patterns. The Redact processor was initially released as Technical Preview in `8.7.0`, and is now released as Generally Available. + +{es-pull}110395[#110395] + // end::notable-highlights[] +[discrete] +[[new_custom_parser_for_iso_8601_datetimes]] +=== New custom parser for ISO-8601 datetimes +This introduces a new custom parser for ISO-8601 datetimes, for the `iso8601`, `strict_date_optional_time`, and +`strict_date_optional_time_nanos` built-in date formats. This provides a performance improvement over the +default Java date-time parsing. Whilst it maintains much of the same behaviour, +the new parser does not accept nonsensical date-time strings that have multiple fractional seconds fields +or multiple timezone specifiers. If the new parser fails to parse a string, it will then use the previous parser +to parse it. If a large proportion of the input data consists of these invalid strings, this may cause +a small performance degradation. If you wish to force the use of the old parsers regardless, +set the JVM property `es.datetime.java_time_parsers=true` on all ES nodes. + +{es-pull}106486[#106486] + +[discrete] +[[new_custom_parser_for_more_iso_8601_date_formats]] +=== New custom parser for more ISO-8601 date formats +Following on from #106486, this extends the custom ISO-8601 datetime parser to cover the `strict_year`, +`strict_year_month`, `strict_date_time`, `strict_date_time_no_millis`, `strict_date_hour_minute_second`, +`strict_date_hour_minute_second_millis`, and `strict_date_hour_minute_second_fraction` date formats. +As before, the parser will use the existing java.time parser if there are parsing issues, and the +`es.datetime.java_time_parsers=true` JVM property will force the use of the old parsers regardless. + +{es-pull}108606[#108606] + +[discrete] +[[preview_support_for_connection_type_domain_isp_databases_in_geoip_processor]] +=== Preview: Support for the 'Connection Type, 'Domain', and 'ISP' databases in the geoip processor +As a Technical Preview, the {ref}/geoip-processor.html[`geoip`] processor can now use the commercial +https://dev.maxmind.com/geoip/docs/databases/connection-type[GeoIP2 'Connection Type'], +https://dev.maxmind.com/geoip/docs/databases/domain[GeoIP2 'Domain'], +and +https://dev.maxmind.com/geoip/docs/databases/isp[GeoIP2 'ISP'] +databases from MaxMind. + +{es-pull}108683[#108683] + +[discrete] +[[update_elasticsearch_to_lucene_9_11]] +=== Update Elasticsearch to Lucene 9.11 +Elasticsearch is now updated using the latest Lucene version 9.11. +Here are the full release notes: +But, here are some particular highlights: +- Usage of MADVISE for better memory management: https://github.com/apache/lucene/pull/13196 +- Use RWLock to access LRUQueryCache to reduce contention: https://github.com/apache/lucene/pull/13306 +- Speedup multi-segment HNSW graph search for nested kNN queries: https://github.com/apache/lucene/pull/13121 +- Add a MemorySegment Vector scorer - for scoring without copying on-heap vectors: https://github.com/apache/lucene/pull/13339 + +{es-pull}109219[#109219] + +[discrete] +[[synthetic_source_improvements]] +=== Synthetic `_source` improvements +There are multiple improvements to synthetic `_source` functionality: + +* Synthetic `_source` is now supported for all field types including `nested` and `object`. `object` fields are supported with `enabled` set to `false`. + +* Synthetic `_source` can be enabled together with `ignore_malformed` and `ignore_above` parameters for all field types that support them. + +{es-pull}109501[#109501] + +[discrete] +[[index_sorting_on_indexes_with_nested_fields]] +=== Index sorting on indexes with nested fields +Index sorting is now supported for indexes with mappings containing nested objects. +The index sort spec (as specified by `index.sort.field`) can't contain any nested +fields, still. + +{es-pull}110251[#110251] + diff --git a/docs/reference/rest-api/security.asciidoc b/docs/reference/rest-api/security.asciidoc index 04cd838c45600..82cf38e52bd80 100644 --- a/docs/reference/rest-api/security.asciidoc +++ b/docs/reference/rest-api/security.asciidoc @@ -50,6 +50,7 @@ Use the following APIs to add, remove, update, and retrieve roles in the native * <> * <> * <> +* <> [discrete] [[security-token-apis]] @@ -192,6 +193,7 @@ include::security/get-app-privileges.asciidoc[] include::security/get-builtin-privileges.asciidoc[] include::security/get-role-mappings.asciidoc[] include::security/get-roles.asciidoc[] +include::security/query-role.asciidoc[] include::security/get-service-accounts.asciidoc[] include::security/get-service-credentials.asciidoc[] include::security/get-settings.asciidoc[] diff --git a/docs/reference/rest-api/security/get-roles.asciidoc b/docs/reference/rest-api/security/get-roles.asciidoc index 3eb5a735194c6..3cc2f95c6ea7e 100644 --- a/docs/reference/rest-api/security/get-roles.asciidoc +++ b/docs/reference/rest-api/security/get-roles.asciidoc @@ -38,7 +38,10 @@ API cannot retrieve roles that are defined in roles files. ==== {api-response-body-title} A successful call returns an array of roles with the JSON representation of the -role. +role. The returned role format is a simple extension of the <> format, +only adding an extra field `transient_metadata.enabled`. +This field is `false` in case the role is automatically disabled, for example when the license +level does not allow some permissions that the role grants. [[security-api-get-role-response-codes]] ==== {api-response-codes-title} diff --git a/docs/reference/rest-api/security/query-role.asciidoc b/docs/reference/rest-api/security/query-role.asciidoc new file mode 100644 index 0000000000000..937bd263140fc --- /dev/null +++ b/docs/reference/rest-api/security/query-role.asciidoc @@ -0,0 +1,283 @@ +[role="xpack"] +[[security-api-query-role]] +=== Query Role API + +++++ +Query Role +++++ + +Retrieves roles with <> in a <> fashion. + +[[security-api-query-role-request]] +==== {api-request-title} + +`GET /_security/_query/role` + +`POST /_security/_query/role` + +[[security-api-query-role-prereqs]] +==== {api-prereq-title} + +* To use this API, you must have at least the `read_security` cluster privilege. + +[[security-api-query-role-desc]] +==== {api-description-title} + +The role management APIs are generally the preferred way to manage roles, rather than using +<>. +The query roles API does not retrieve roles that are defined in roles files, nor <> ones. +You can optionally filter the results with a query. Also, the results can be paginated and sorted. + +[[security-api-query-role-request-body]] +==== {api-request-body-title} + +You can specify the following parameters in the request body: + +`query`:: +(Optional, string) A <> to filter which roles to return. +The query supports a subset of query types, including +<>, <>, +<>, <>, +<>, <>, +<>, <>, +<>, <>, +and <>. ++ +You can query the following values associated with a role. ++ +.Valid values for `query` +[%collapsible%open] +==== +`name`:: +(keyword) The <> of the role. + +`description`:: +(text) The <> of the role. + +`metadata`:: +(flattened) Metadata field associated with the <>, such as `metadata.app_tag`. +Note that metadata is internally indexed as a <> field type. +This means that all sub-fields act like `keyword` fields when querying and sorting. +It also implies that it is not possible to refer to a subset of metadata fields using wildcard patterns, +e.g. `metadata.field*`, even for query types that support field name patterns. +Lastly, all the metadata fields can be searched together when simply mentioning the +`metadata` field (i.e. not followed by any dot and sub-field name). + +`applications`:: +The list of <> that the role grants. + +`application`::: +(keyword) The name of the application associated to the privileges and resources. + +`privileges`::: +(keyword) The names of the privileges that the role grants. + +`resources`::: +(keyword) The resources to which the privileges apply. + +==== + +include::{es-ref-dir}/rest-api/common-parms.asciidoc[tag=from] ++ +By default, you cannot page through more than 10,000 hits using the `from` and +`size` parameters. To page through more hits, use the +<> parameter. + +`size`:: +(Optional, integer) The number of hits to return. Must not be negative and defaults to `10`. ++ +By default, you cannot page through more than 10,000 hits using the `from` and +`size` parameters. To page through more hits, use the +<> parameter. + +`sort`:: +(Optional, object) <>. You can sort on `username`, `roles` or `enabled`. +In addition, sort can also be applied to the `_doc` field to sort by index order. + +`search_after`:: +(Optional, array) <> definition. + + +[[security-api-query-role-response-body]] +==== {api-response-body-title} + +This API returns the following top level fields: + +`total`:: +The total number of roles found. + +`count`:: +The number of roles returned in the response. + +`roles`:: +A list of roles that match the query. +The returned role format is an extension of the <> format. +It adds the `transient_metadata.enabled` and the `_sort` fields. +`transient_metadata.enabled` is set to `false` in case the role is automatically disabled, +for example when the role grants privileges that are not allowed by the installed license. +`_sort` is present when the search query sorts on some field. +It contains the array of values that have been used for sorting. + +[[security-api-query-role-example]] +==== {api-examples-title} + +The following request lists all roles, sorted by the role name: + +[source,console] +---- +POST /_security/_query/role +{ + "sort": ["name"] +} +---- +// TEST[setup:admin_role,user_role] + +A successful call returns a JSON structure that contains the information +retrieved for one or more roles: + +[source,console-result] +---- +{ + "total": 2, + "count": 2, + "roles": [ <1> + { + "name" : "my_admin_role", + "cluster" : [ + "all" + ], + "indices" : [ + { + "names" : [ + "index1", + "index2" + ], + "privileges" : [ + "all" + ], + "field_security" : { + "grant" : [ + "title", + "body" + ] + }, + "allow_restricted_indices" : false + } + ], + "applications" : [ ], + "run_as" : [ + "other_user" + ], + "metadata" : { + "version" : 1 + }, + "transient_metadata" : { + "enabled" : true + }, + "description" : "Grants full access to all management features within the cluster.", + "_sort" : [ + "my_admin_role" + ] + }, + { + "name" : "my_user_role", + "cluster" : [ ], + "indices" : [ + { + "names" : [ + "index1", + "index2" + ], + "privileges" : [ + "all" + ], + "field_security" : { + "grant" : [ + "title", + "body" + ] + }, + "allow_restricted_indices" : false + } + ], + "applications" : [ ], + "run_as" : [ ], + "metadata" : { + "version" : 1 + }, + "transient_metadata" : { + "enabled" : true + }, + "description" : "Grants user access to some indicies.", + "_sort" : [ + "my_user_role" + ] + } + ] +} +---- +// TEST[continued] + +<1> The list of roles that were retrieved for this request + +Similarly, the following request can be used to query only the user access role, +given its description: + +[source,console] +---- +POST /_security/_query/role +{ + "query": { + "match": { + "description": { + "query": "user access" + } + } + }, + "size": 1 <1> +} +---- +// TEST[continued] + +<1> Return only the best matching role + +[source,console-result] +---- +{ + "total": 2, + "count": 1, + "roles": [ + { + "name" : "my_user_role", + "cluster" : [ ], + "indices" : [ + { + "names" : [ + "index1", + "index2" + ], + "privileges" : [ + "all" + ], + "field_security" : { + "grant" : [ + "title", + "body" + ] + }, + "allow_restricted_indices" : false + } + ], + "applications" : [ ], + "run_as" : [ ], + "metadata" : { + "version" : 1 + }, + "transient_metadata" : { + "enabled" : true + }, + "description" : "Grants user access to some indicies." + } + ] +} +---- diff --git a/docs/reference/rest-api/security/query-user.asciidoc b/docs/reference/rest-api/security/query-user.asciidoc index 952e0f40f2a3a..23852f0f2eed7 100644 --- a/docs/reference/rest-api/security/query-user.asciidoc +++ b/docs/reference/rest-api/security/query-user.asciidoc @@ -66,13 +66,6 @@ The email of the user. Specifies whether the user is enabled. ==== -[[security-api-query-user-query-params]] -==== {api-query-parms-title} - -`with_profile_uid`:: -(Optional, boolean) Determines whether to retrieve the <> `uid`, -if exists, for the users. Defaults to `false`. - include::{es-ref-dir}/rest-api/common-parms.asciidoc[tag=from] + By default, you cannot page through more than 10,000 hits using the `from` and @@ -93,6 +86,12 @@ In addition, sort can also be applied to the `_doc` field to sort by index order `search_after`:: (Optional, array) <> definition. +[[security-api-query-user-query-params]] +==== {api-query-parms-title} + +`with_profile_uid`:: +(Optional, boolean) Determines whether to retrieve the <> `uid`, +if exists, for the users. Defaults to `false`. [[security-api-query-user-response-body]] ==== {api-response-body-title} @@ -191,7 +190,7 @@ Use the user information retrieve the user with a query: [source,console] ---- -GET /_security/_query/user +POST /_security/_query/user { "query": { "prefix": { @@ -231,7 +230,7 @@ To retrieve the user `profile_uid` as part of the response: [source,console] -------------------------------------------------- -GET /_security/_query/user?with_profile_uid=true +POST /_security/_query/user?with_profile_uid=true { "query": { "prefix": { @@ -272,7 +271,7 @@ Use a `bool` query to issue complex logical conditions and use [source,js] ---- -GET /_security/_query/user +POST /_security/_query/user { "query": { "bool": { diff --git a/docs/reference/search/multi-search-template-api.asciidoc b/docs/reference/search/multi-search-template-api.asciidoc index c8eea52a6fd9b..b1c9518b1f2bc 100644 --- a/docs/reference/search/multi-search-template-api.asciidoc +++ b/docs/reference/search/multi-search-template-api.asciidoc @@ -22,9 +22,6 @@ PUT _scripts/my-search-template }, "from": "{{from}}", "size": "{{size}}" - }, - "params": { - "query_string": "My query string" } } } diff --git a/docs/reference/search/render-search-template-api.asciidoc b/docs/reference/search/render-search-template-api.asciidoc index 1f259dddf6879..0c782f26068e6 100644 --- a/docs/reference/search/render-search-template-api.asciidoc +++ b/docs/reference/search/render-search-template-api.asciidoc @@ -22,9 +22,6 @@ PUT _scripts/my-search-template }, "from": "{{from}}", "size": "{{size}}" - }, - "params": { - "query_string": "My query string" } } } diff --git a/docs/reference/search/search-template-api.asciidoc b/docs/reference/search/search-template-api.asciidoc index 038396e558607..c60b5281c05e5 100644 --- a/docs/reference/search/search-template-api.asciidoc +++ b/docs/reference/search/search-template-api.asciidoc @@ -21,9 +21,6 @@ PUT _scripts/my-search-template }, "from": "{{from}}", "size": "{{size}}" - }, - "params": { - "query_string": "My query string" } } } diff --git a/docs/reference/search/search-your-data/search-template.asciidoc b/docs/reference/search/search-your-data/search-template.asciidoc index 7a7f09f4a37a7..489a03c0a6a2a 100644 --- a/docs/reference/search/search-your-data/search-template.asciidoc +++ b/docs/reference/search/search-your-data/search-template.asciidoc @@ -42,9 +42,6 @@ PUT _scripts/my-search-template }, "from": "{{from}}", "size": "{{size}}" - }, - "params": { - "query_string": "My query string" } } } diff --git a/docs/reference/security/authorization/managing-roles.asciidoc b/docs/reference/security/authorization/managing-roles.asciidoc index 253aa33822234..535d70cbc5e9c 100644 --- a/docs/reference/security/authorization/managing-roles.asciidoc +++ b/docs/reference/security/authorization/managing-roles.asciidoc @@ -13,7 +13,9 @@ A role is defined by the following JSON structure: "indices": [ ... ], <4> "applications": [ ... ], <5> "remote_indices": [ ... ], <6> - "remote_cluster": [ ... ] <7> + "remote_cluster": [ ... ], <7> + "metadata": { ... }, <8> + "description": "..." <9> } ----- // NOTCONSOLE @@ -40,6 +42,16 @@ A role is defined by the following JSON structure: <>. This field is optional (missing `remote_cluster` privileges effectively means no additional cluster permissions for any API key based remote clusters). +<8> Metadata field associated with the role, such as `metadata.app_tag`. + Metadata is internally indexed as a <> field type. + This means that all sub-fields act like `keyword` fields when querying and sorting. + Metadata values can be simple values, but also lists and maps. + This field is optional. +<9> A string value with the description text of the role. + The maximum length of it is `1000` chars. + The field is internally indexed as a <> field type + (with default values for all parameters). + This field is optional. [[valid-role-name]] NOTE: Role names must be at least 1 and no more than 507 characters. They can diff --git a/docs/reference/security/authorization/privileges.asciidoc b/docs/reference/security/authorization/privileges.asciidoc index cc44c97a08129..44897baa8cb4a 100644 --- a/docs/reference/security/authorization/privileges.asciidoc +++ b/docs/reference/security/authorization/privileges.asciidoc @@ -2,7 +2,7 @@ === Security privileges :frontmatter-description: A list of privileges that can be assigned to user roles. :frontmatter-tags-products: [elasticsearch] -:frontmatter-tags-content-type: [reference] +:frontmatter-tags-content-type: [reference] :frontmatter-tags-user-goals: [secure] This section lists the privileges that you can assign to a role. @@ -198,6 +198,10 @@ All {slm} ({slm-init}) actions, including creating and updating policies and starting and stopping {slm-init}. + This privilege is not available in {serverless-full}. ++ +deprecated:[8.15] Also grants the permission to start and stop {Ilm}, using +the {ref}/ilm-start.html[ILM start] and {ref}/ilm-stop.html[ILM stop] APIs. +In a future major release, this privilege will not grant any {Ilm} permissions. `manage_token`:: All security-related operations on tokens that are generated by the {es} Token @@ -285,6 +289,10 @@ All read-only {slm-init} actions, such as getting policies and checking the {slm-init} status. + This privilege is not available in {serverless-full}. ++ +deprecated:[8.15] Also grants the permission to get the {Ilm} status, using +the {ref}/ilm-get-status.html[ILM get status API]. In a future major release, +this privilege will not grant any {Ilm} permissions. `read_security`:: All read-only security-related operations, such as getting users, user profiles, diff --git a/gradle/verification-metadata.xml b/gradle/verification-metadata.xml index cd408ba75aa10..5e26d96c4ca17 100644 --- a/gradle/verification-metadata.xml +++ b/gradle/verification-metadata.xml @@ -84,6 +84,11 @@ + + + + + @@ -1694,16 +1699,16 @@ - - - - - + + + + + @@ -3832,14 +3837,14 @@ - - - + + + - - - + + + diff --git a/libs/core/src/test/java/org/elasticsearch/core/AbstractRefCountedTests.java b/libs/core/src/test/java/org/elasticsearch/core/AbstractRefCountedTests.java index 9610bae32a775..74dcd19248834 100644 --- a/libs/core/src/test/java/org/elasticsearch/core/AbstractRefCountedTests.java +++ b/libs/core/src/test/java/org/elasticsearch/core/AbstractRefCountedTests.java @@ -9,7 +9,6 @@ import org.elasticsearch.test.ESTestCase; -import java.util.concurrent.CountDownLatch; import java.util.concurrent.atomic.AtomicBoolean; import static org.hamcrest.Matchers.equalTo; @@ -62,32 +61,22 @@ public void testRefCount() { public void testMultiThreaded() throws InterruptedException { final AbstractRefCounted counted = createRefCounted(); - final Thread[] threads = new Thread[randomIntBetween(2, 5)]; - final CountDownLatch latch = new CountDownLatch(1); - for (int i = 0; i < threads.length; i++) { - threads[i] = new Thread(() -> { - try { - latch.await(); - for (int j = 0; j < 10000; j++) { - assertTrue(counted.hasReferences()); - if (randomBoolean()) { - counted.incRef(); - } else { - assertTrue(counted.tryIncRef()); - } - assertTrue(counted.hasReferences()); - counted.decRef(); + startInParallel(randomIntBetween(2, 5), i -> { + try { + for (int j = 0; j < 10000; j++) { + assertTrue(counted.hasReferences()); + if (randomBoolean()) { + counted.incRef(); + } else { + assertTrue(counted.tryIncRef()); } - } catch (Exception e) { - throw new AssertionError(e); + assertTrue(counted.hasReferences()); + counted.decRef(); } - }); - threads[i].start(); - } - latch.countDown(); - for (Thread thread : threads) { - thread.join(); - } + } catch (Exception e) { + throw new AssertionError(e); + } + }); counted.decRef(); assertFalse(counted.hasReferences()); assertThat( diff --git a/libs/native/jna/src/main/java/org/elasticsearch/nativeaccess/jna/JnaKernel32Library.java b/libs/native/jna/src/main/java/org/elasticsearch/nativeaccess/jna/JnaKernel32Library.java index 0bfdf959f7b58..2c7ec70f36eb3 100644 --- a/libs/native/jna/src/main/java/org/elasticsearch/nativeaccess/jna/JnaKernel32Library.java +++ b/libs/native/jna/src/main/java/org/elasticsearch/nativeaccess/jna/JnaKernel32Library.java @@ -13,6 +13,7 @@ import com.sun.jna.NativeLong; import com.sun.jna.Pointer; import com.sun.jna.Structure; +import com.sun.jna.Structure.ByReference; import com.sun.jna.WString; import com.sun.jna.win32.StdCallLibrary; @@ -98,6 +99,38 @@ public long Type() { } } + /** + * Basic limit information for a job object + * + * https://msdn.microsoft.com/en-us/library/windows/desktop/ms684147%28v=vs.85%29.aspx + */ + public static class JnaJobObjectBasicLimitInformation extends Structure implements ByReference, JobObjectBasicLimitInformation { + public byte[] _ignore1 = new byte[16]; + public int LimitFlags; + public byte[] _ignore2 = new byte[20]; + public int ActiveProcessLimit; + public byte[] _ignore3 = new byte[20]; + + public JnaJobObjectBasicLimitInformation() { + super(8); + } + + @Override + protected List getFieldOrder() { + return List.of("_ignore1", "LimitFlags", "_ignore2", "ActiveProcessLimit", "_ignore3"); + } + + @Override + public void setLimitFlags(int v) { + LimitFlags = v; + } + + @Override + public void setActiveProcessLimit(int v) { + ActiveProcessLimit = v; + } + } + /** * JNA adaptation of {@link ConsoleCtrlHandler} */ @@ -128,6 +161,20 @@ private interface NativeFunctions extends StdCallLibrary { int GetShortPathNameW(WString lpszLongPath, char[] lpszShortPath, int cchBuffer); boolean SetConsoleCtrlHandler(StdCallLibrary.StdCallCallback handler, boolean add); + + Pointer CreateJobObjectW(Pointer jobAttributes, String name); + + boolean AssignProcessToJobObject(Pointer job, Pointer process); + + boolean QueryInformationJobObject( + Pointer job, + int infoClass, + JnaJobObjectBasicLimitInformation info, + int infoLength, + Pointer returnLength + ); + + boolean SetInformationJobObject(Pointer job, int infoClass, JnaJobObjectBasicLimitInformation info, int infoLength); } private final NativeFunctions functions; @@ -197,4 +244,42 @@ public boolean SetConsoleCtrlHandler(ConsoleCtrlHandler handler, boolean add) { consoleCtrlHandlerCallback = new NativeHandlerCallback(handler); return functions.SetConsoleCtrlHandler(consoleCtrlHandlerCallback, true); } + + @Override + public Handle CreateJobObjectW() { + return new JnaHandle(functions.CreateJobObjectW(null, null)); + } + + @Override + public boolean AssignProcessToJobObject(Handle job, Handle process) { + assert job instanceof JnaHandle; + assert process instanceof JnaHandle; + var jnaJob = (JnaHandle) job; + var jnaProcess = (JnaHandle) process; + return functions.AssignProcessToJobObject(jnaJob.pointer, jnaProcess.pointer); + } + + @Override + public JobObjectBasicLimitInformation newJobObjectBasicLimitInformation() { + return new JnaJobObjectBasicLimitInformation(); + } + + @Override + public boolean QueryInformationJobObject(Handle job, int infoClass, JobObjectBasicLimitInformation info) { + assert job instanceof JnaHandle; + assert info instanceof JnaJobObjectBasicLimitInformation; + var jnaJob = (JnaHandle) job; + var jnaInfo = (JnaJobObjectBasicLimitInformation) info; + var ret = functions.QueryInformationJobObject(jnaJob.pointer, infoClass, jnaInfo, jnaInfo.size(), null); + return ret; + } + + @Override + public boolean SetInformationJobObject(Handle job, int infoClass, JobObjectBasicLimitInformation info) { + assert job instanceof JnaHandle; + assert info instanceof JnaJobObjectBasicLimitInformation; + var jnaJob = (JnaHandle) job; + var jnaInfo = (JnaJobObjectBasicLimitInformation) info; + return functions.SetInformationJobObject(jnaJob.pointer, infoClass, jnaInfo, jnaInfo.size()); + } } diff --git a/libs/native/jna/src/main/java/org/elasticsearch/nativeaccess/jna/JnaLinuxCLibrary.java b/libs/native/jna/src/main/java/org/elasticsearch/nativeaccess/jna/JnaLinuxCLibrary.java new file mode 100644 index 0000000000000..742c666d59c23 --- /dev/null +++ b/libs/native/jna/src/main/java/org/elasticsearch/nativeaccess/jna/JnaLinuxCLibrary.java @@ -0,0 +1,94 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.nativeaccess.jna; + +import com.sun.jna.Library; +import com.sun.jna.Memory; +import com.sun.jna.Native; +import com.sun.jna.NativeLong; +import com.sun.jna.Pointer; +import com.sun.jna.Structure; + +import org.elasticsearch.nativeaccess.lib.LinuxCLibrary; + +import java.nio.ByteBuffer; +import java.nio.ByteOrder; + +class JnaLinuxCLibrary implements LinuxCLibrary { + + @Structure.FieldOrder({ "len", "filter" }) + public static final class JnaSockFProg extends Structure implements Structure.ByReference, SockFProg { + public short len; // number of filters + public Pointer filter; // filters + + JnaSockFProg(SockFilter filters[]) { + len = (short) filters.length; + // serialize struct sock_filter * explicitly, its less confusing than the JNA magic we would need + Memory filter = new Memory(len * 8); + ByteBuffer bbuf = filter.getByteBuffer(0, len * 8); + bbuf.order(ByteOrder.nativeOrder()); // little endian + for (SockFilter f : filters) { + bbuf.putShort(f.code()); + bbuf.put(f.jt()); + bbuf.put(f.jf()); + bbuf.putInt(f.k()); + } + this.filter = filter; + } + + @Override + public long address() { + return Pointer.nativeValue(getPointer()); + } + } + + private interface NativeFunctions extends Library { + + /** + * maps to prctl(2) + */ + int prctl(int option, NativeLong arg2, NativeLong arg3, NativeLong arg4, NativeLong arg5); + + /** + * used to call seccomp(2), its too new... + * this is the only way, DON'T use it on some other architecture unless you know wtf you are doing + */ + NativeLong syscall(NativeLong number, Object... args); + } + + private final NativeFunctions functions; + + JnaLinuxCLibrary() { + try { + this.functions = Native.load("c", NativeFunctions.class); + } catch (UnsatisfiedLinkError e) { + throw new UnsupportedOperationException( + "seccomp unavailable: could not link methods. requires kernel 3.5+ " + + "with CONFIG_SECCOMP and CONFIG_SECCOMP_FILTER compiled in" + ); + } + } + + @Override + public SockFProg newSockFProg(SockFilter[] filters) { + var prog = new JnaSockFProg(filters); + prog.write(); + return prog; + } + + @Override + public int prctl(int option, long arg2, long arg3, long arg4, long arg5) { + return functions.prctl(option, new NativeLong(arg2), new NativeLong(arg3), new NativeLong(arg4), new NativeLong(arg5)); + } + + @Override + public long syscall(long number, int operation, int flags, long address) { + return functions.syscall(new NativeLong(number), operation, flags, address).longValue(); + } +} diff --git a/libs/native/jna/src/main/java/org/elasticsearch/nativeaccess/jna/JnaMacCLibrary.java b/libs/native/jna/src/main/java/org/elasticsearch/nativeaccess/jna/JnaMacCLibrary.java new file mode 100644 index 0000000000000..f416cf862b417 --- /dev/null +++ b/libs/native/jna/src/main/java/org/elasticsearch/nativeaccess/jna/JnaMacCLibrary.java @@ -0,0 +1,59 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.nativeaccess.jna; + +import com.sun.jna.Library; +import com.sun.jna.Native; +import com.sun.jna.Pointer; +import com.sun.jna.ptr.PointerByReference; + +import org.elasticsearch.nativeaccess.lib.MacCLibrary; + +class JnaMacCLibrary implements MacCLibrary { + static class JnaErrorReference implements ErrorReference { + final PointerByReference ref = new PointerByReference(); + + @Override + public String toString() { + return ref.getValue().getString(0); + } + } + + private interface NativeFunctions extends Library { + int sandbox_init(String profile, long flags, PointerByReference errorbuf); + + void sandbox_free_error(Pointer errorbuf); + } + + private final NativeFunctions functions; + + JnaMacCLibrary() { + this.functions = Native.load("c", NativeFunctions.class); + } + + @Override + public ErrorReference newErrorReference() { + return new JnaErrorReference(); + } + + @Override + public int sandbox_init(String profile, long flags, ErrorReference errorbuf) { + assert errorbuf instanceof JnaErrorReference; + var jnaErrorbuf = (JnaErrorReference) errorbuf; + return functions.sandbox_init(profile, flags, jnaErrorbuf.ref); + } + + @Override + public void sandbox_free_error(ErrorReference errorbuf) { + assert errorbuf instanceof JnaErrorReference; + var jnaErrorbuf = (JnaErrorReference) errorbuf; + functions.sandbox_free_error(jnaErrorbuf.ref.getValue()); + } + +} diff --git a/libs/native/jna/src/main/java/org/elasticsearch/nativeaccess/jna/JnaNativeLibraryProvider.java b/libs/native/jna/src/main/java/org/elasticsearch/nativeaccess/jna/JnaNativeLibraryProvider.java index 9d34b1ba617e8..454581ae70b51 100644 --- a/libs/native/jna/src/main/java/org/elasticsearch/nativeaccess/jna/JnaNativeLibraryProvider.java +++ b/libs/native/jna/src/main/java/org/elasticsearch/nativeaccess/jna/JnaNativeLibraryProvider.java @@ -10,6 +10,8 @@ import org.elasticsearch.nativeaccess.lib.JavaLibrary; import org.elasticsearch.nativeaccess.lib.Kernel32Library; +import org.elasticsearch.nativeaccess.lib.LinuxCLibrary; +import org.elasticsearch.nativeaccess.lib.MacCLibrary; import org.elasticsearch.nativeaccess.lib.NativeLibrary; import org.elasticsearch.nativeaccess.lib.NativeLibraryProvider; import org.elasticsearch.nativeaccess.lib.PosixCLibrary; @@ -30,6 +32,10 @@ public JnaNativeLibraryProvider() { JnaJavaLibrary::new, PosixCLibrary.class, JnaPosixCLibrary::new, + LinuxCLibrary.class, + JnaLinuxCLibrary::new, + MacCLibrary.class, + JnaMacCLibrary::new, Kernel32Library.class, JnaKernel32Library::new, SystemdLibrary.class, diff --git a/libs/native/jna/src/main/java/org/elasticsearch/nativeaccess/jna/JnaPosixCLibrary.java b/libs/native/jna/src/main/java/org/elasticsearch/nativeaccess/jna/JnaPosixCLibrary.java index 7e8e4f23ab034..03a7b9c0869be 100644 --- a/libs/native/jna/src/main/java/org/elasticsearch/nativeaccess/jna/JnaPosixCLibrary.java +++ b/libs/native/jna/src/main/java/org/elasticsearch/nativeaccess/jna/JnaPosixCLibrary.java @@ -39,6 +39,50 @@ public long rlim_cur() { public long rlim_max() { return rlim_max.longValue(); } + + @Override + public void rlim_cur(long v) { + rlim_cur.setValue(v); + } + + @Override + public void rlim_max(long v) { + rlim_max.setValue(v); + } + } + + public static class JnaFStore extends Structure implements Structure.ByReference, FStore { + + public int fst_flags = 0; + public int fst_posmode = 0; + public NativeLong fst_offset = new NativeLong(0); + public NativeLong fst_length = new NativeLong(0); + public NativeLong fst_bytesalloc = new NativeLong(0); + + @Override + public void set_flags(int flags) { + this.fst_flags = flags; + } + + @Override + public void set_posmode(int posmode) { + this.fst_posmode = posmode; + } + + @Override + public void set_offset(long offset) { + fst_offset.setValue(offset); + } + + @Override + public void set_length(long length) { + fst_length.setValue(length); + } + + @Override + public long bytesalloc() { + return fst_bytesalloc.longValue(); + } } private interface NativeFunctions extends Library { @@ -46,8 +90,12 @@ private interface NativeFunctions extends Library { int getrlimit(int resource, JnaRLimit rlimit); + int setrlimit(int resource, JnaRLimit rlimit); + int mlockall(int flags); + int fcntl(int fd, int cmd, JnaFStore fst); + String strerror(int errno); } @@ -74,11 +122,30 @@ public int getrlimit(int resource, RLimit rlimit) { return functions.getrlimit(resource, jnaRlimit); } + @Override + public int setrlimit(int resource, RLimit rlimit) { + assert rlimit instanceof JnaRLimit; + var jnaRlimit = (JnaRLimit) rlimit; + return functions.setrlimit(resource, jnaRlimit); + } + @Override public int mlockall(int flags) { return functions.mlockall(flags); } + @Override + public FStore newFStore() { + return new JnaFStore(); + } + + @Override + public int fcntl(int fd, int cmd, FStore fst) { + assert fst instanceof JnaFStore; + var jnaFst = (JnaFStore) fst; + return functions.fcntl(fd, cmd, jnaFst); + } + @Override public String strerror(int errno) { return functions.strerror(errno); diff --git a/libs/native/src/main/java/org/elasticsearch/nativeaccess/AbstractNativeAccess.java b/libs/native/src/main/java/org/elasticsearch/nativeaccess/AbstractNativeAccess.java index 80a18a2bc8aa0..c10f57a900ff7 100644 --- a/libs/native/src/main/java/org/elasticsearch/nativeaccess/AbstractNativeAccess.java +++ b/libs/native/src/main/java/org/elasticsearch/nativeaccess/AbstractNativeAccess.java @@ -22,6 +22,7 @@ abstract class AbstractNativeAccess implements NativeAccess { private final JavaLibrary javaLib; private final Zstd zstd; protected boolean isMemoryLocked = false; + protected ExecSandboxState execSandboxState = ExecSandboxState.NONE; protected AbstractNativeAccess(String name, NativeLibraryProvider libraryProvider) { this.name = name; @@ -53,4 +54,9 @@ public CloseableByteBuffer newBuffer(int len) { public boolean isMemoryLocked() { return isMemoryLocked; } + + @Override + public ExecSandboxState getExecSandboxState() { + return execSandboxState; + } } diff --git a/libs/native/src/main/java/org/elasticsearch/nativeaccess/LinuxNativeAccess.java b/libs/native/src/main/java/org/elasticsearch/nativeaccess/LinuxNativeAccess.java index 7948dad1df4ad..c50e639c94d27 100644 --- a/libs/native/src/main/java/org/elasticsearch/nativeaccess/LinuxNativeAccess.java +++ b/libs/native/src/main/java/org/elasticsearch/nativeaccess/LinuxNativeAccess.java @@ -8,15 +8,88 @@ package org.elasticsearch.nativeaccess; +import org.elasticsearch.nativeaccess.lib.LinuxCLibrary; +import org.elasticsearch.nativeaccess.lib.LinuxCLibrary.SockFProg; +import org.elasticsearch.nativeaccess.lib.LinuxCLibrary.SockFilter; import org.elasticsearch.nativeaccess.lib.NativeLibraryProvider; import org.elasticsearch.nativeaccess.lib.SystemdLibrary; +import java.util.Map; + class LinuxNativeAccess extends PosixNativeAccess { - Systemd systemd; + /** the preferred method is seccomp(2), since we can apply to all threads of the process */ + static final int SECCOMP_SET_MODE_FILTER = 1; // since Linux 3.17 + static final int SECCOMP_FILTER_FLAG_TSYNC = 1; // since Linux 3.17 + + /** otherwise, we can use prctl(2), which will at least protect ES application threads */ + static final int PR_GET_NO_NEW_PRIVS = 39; // since Linux 3.5 + static final int PR_SET_NO_NEW_PRIVS = 38; // since Linux 3.5 + static final int PR_GET_SECCOMP = 21; // since Linux 2.6.23 + static final int PR_SET_SECCOMP = 22; // since Linux 2.6.23 + static final long SECCOMP_MODE_FILTER = 2; // since Linux Linux 3.5 + + // BPF "macros" and constants + static final int BPF_LD = 0x00; + static final int BPF_W = 0x00; + static final int BPF_ABS = 0x20; + static final int BPF_JMP = 0x05; + static final int BPF_JEQ = 0x10; + static final int BPF_JGE = 0x30; + static final int BPF_JGT = 0x20; + static final int BPF_RET = 0x06; + static final int BPF_K = 0x00; + + static SockFilter BPF_STMT(int code, int k) { + return new SockFilter((short) code, (byte) 0, (byte) 0, k); + } + + static SockFilter BPF_JUMP(int code, int k, int jt, int jf) { + return new SockFilter((short) code, (byte) jt, (byte) jf, k); + } + + static final int SECCOMP_RET_ERRNO = 0x00050000; + static final int SECCOMP_RET_DATA = 0x0000FFFF; + static final int SECCOMP_RET_ALLOW = 0x7FFF0000; + + // some errno constants for error checking/handling + static final int EACCES = 0x0D; + static final int EFAULT = 0x0E; + static final int EINVAL = 0x16; + static final int ENOSYS = 0x26; + + // offsets that our BPF checks + // check with offsetof() when adding a new arch, move to Arch if different. + static final int SECCOMP_DATA_NR_OFFSET = 0x00; + static final int SECCOMP_DATA_ARCH_OFFSET = 0x04; + + record Arch( + int audit, // AUDIT_ARCH_XXX constant from linux/audit.h + int limit, // syscall limit (necessary for blacklisting on amd64, to ban 32-bit syscalls) + int fork, // __NR_fork + int vfork, // __NR_vfork + int execve, // __NR_execve + int execveat, // __NR_execveat + int seccomp // __NR_seccomp + ) {} + + /** supported architectures for seccomp keyed by os.arch */ + private static final Map ARCHITECTURES; + static { + ARCHITECTURES = Map.of( + "amd64", + new Arch(0xC000003E, 0x3FFFFFFF, 57, 58, 59, 322, 317), + "aarch64", + new Arch(0xC00000B7, 0xFFFFFFFF, 1079, 1071, 221, 281, 277) + ); + } + + private final LinuxCLibrary linuxLibc; + private final Systemd systemd; LinuxNativeAccess(NativeLibraryProvider libraryProvider) { super("Linux", libraryProvider, new PosixConstants(-1L, 9, 1, 8)); + this.linuxLibc = libraryProvider.getLibrary(LinuxCLibrary.class); this.systemd = new Systemd(libraryProvider.getLibrary(SystemdLibrary.class)); } @@ -46,4 +119,197 @@ protected void logMemoryLimitInstructions() { \t{} hard memlock unlimited""", user, user, user); logger.warn("If you are logged in interactively, you will have to re-login for the new limits to take effect."); } + + /** + * Installs exec system call filtering for Linux. + *

+ * On Linux exec system call filtering currently supports amd64 and aarch64 architectures. + * It requires Linux kernel 3.5 or above, and {@code CONFIG_SECCOMP} and {@code CONFIG_SECCOMP_FILTER} + * compiled into the kernel. + *

+ * On Linux BPF Filters are installed using either {@code seccomp(2)} (3.17+) or {@code prctl(2)} (3.5+). {@code seccomp(2)} + * is preferred, as it allows filters to be applied to any existing threads in the process, and one motivation + * here is to protect against bugs in the JVM. Otherwise, code will fall back to the {@code prctl(2)} method + * which will at least protect elasticsearch application threads. + *

+ * Linux BPF filters will return {@code EACCES} (Access Denied) for the following system calls: + *

    + *
  • {@code execve}
  • + *
  • {@code fork}
  • + *
  • {@code vfork}
  • + *
  • {@code execveat}
  • + *
+ * @see + * * http://www.kernel.org/doc/Documentation/prctl/seccomp_filter.txt + */ + @Override + public void tryInstallExecSandbox() { + // first be defensive: we can give nice errors this way, at the very least. + // also, some of these security features get backported to old versions, checking kernel version here is a big no-no! + String archId = System.getProperty("os.arch"); + final Arch arch = ARCHITECTURES.get(archId); + if (arch == null) { + throw new UnsupportedOperationException("seccomp unavailable: '" + archId + "' architecture unsupported"); + } + + // try to check system calls really are who they claim + // you never know (e.g. https://chromium.googlesource.com/chromium/src.git/+/master/sandbox/linux/seccomp-bpf/sandbox_bpf.cc#57) + final int bogusArg = 0xf7a46a5c; + + // test seccomp(BOGUS) + long ret = linuxLibc.syscall(arch.seccomp, bogusArg, 0, 0); + if (ret != -1) { + throw new UnsupportedOperationException("seccomp unavailable: seccomp(BOGUS_OPERATION) returned " + ret); + } else { + int errno = libc.errno(); + switch (errno) { + case ENOSYS: + break; // ok + case EINVAL: + break; // ok + default: + throw new UnsupportedOperationException("seccomp(BOGUS_OPERATION): " + libc.strerror(errno)); + } + } + + // test seccomp(VALID, BOGUS) + ret = linuxLibc.syscall(arch.seccomp, SECCOMP_SET_MODE_FILTER, bogusArg, 0); + if (ret != -1) { + throw new UnsupportedOperationException("seccomp unavailable: seccomp(SECCOMP_SET_MODE_FILTER, BOGUS_FLAG) returned " + ret); + } else { + int errno = libc.errno(); + switch (errno) { + case ENOSYS: + break; // ok + case EINVAL: + break; // ok + default: + throw new UnsupportedOperationException("seccomp(SECCOMP_SET_MODE_FILTER, BOGUS_FLAG): " + libc.strerror(errno)); + } + } + + // test prctl(BOGUS) + ret = linuxLibc.prctl(bogusArg, 0, 0, 0, 0); + if (ret != -1) { + throw new UnsupportedOperationException("seccomp unavailable: prctl(BOGUS_OPTION) returned " + ret); + } else { + int errno = libc.errno(); + switch (errno) { + case ENOSYS: + break; // ok + case EINVAL: + break; // ok + default: + throw new UnsupportedOperationException("prctl(BOGUS_OPTION): " + libc.strerror(errno)); + } + } + + // now just normal defensive checks + + // check for GET_NO_NEW_PRIVS + switch (linuxLibc.prctl(PR_GET_NO_NEW_PRIVS, 0, 0, 0, 0)) { + case 0: + break; // not yet set + case 1: + break; // already set by caller + default: + int errno = libc.errno(); + if (errno == EINVAL) { + // friendly error, this will be the typical case for an old kernel + throw new UnsupportedOperationException( + "seccomp unavailable: requires kernel 3.5+ with" + " CONFIG_SECCOMP and CONFIG_SECCOMP_FILTER compiled in" + ); + } else { + throw new UnsupportedOperationException("prctl(PR_GET_NO_NEW_PRIVS): " + libc.strerror(errno)); + } + } + // check for SECCOMP + switch (linuxLibc.prctl(PR_GET_SECCOMP, 0, 0, 0, 0)) { + case 0: + break; // not yet set + case 2: + break; // already in filter mode by caller + default: + int errno = libc.errno(); + if (errno == EINVAL) { + throw new UnsupportedOperationException( + "seccomp unavailable: CONFIG_SECCOMP not compiled into kernel," + + " CONFIG_SECCOMP and CONFIG_SECCOMP_FILTER are needed" + ); + } else { + throw new UnsupportedOperationException("prctl(PR_GET_SECCOMP): " + libc.strerror(errno)); + } + } + // check for SECCOMP_MODE_FILTER + if (linuxLibc.prctl(PR_SET_SECCOMP, SECCOMP_MODE_FILTER, 0, 0, 0) != 0) { + int errno = libc.errno(); + switch (errno) { + case EFAULT: + break; // available + case EINVAL: + throw new UnsupportedOperationException( + "seccomp unavailable: CONFIG_SECCOMP_FILTER not" + + " compiled into kernel, CONFIG_SECCOMP and CONFIG_SECCOMP_FILTER are needed" + ); + default: + throw new UnsupportedOperationException("prctl(PR_SET_SECCOMP): " + libc.strerror(errno)); + } + } + + // ok, now set PR_SET_NO_NEW_PRIVS, needed to be able to set a seccomp filter as ordinary user + if (linuxLibc.prctl(PR_SET_NO_NEW_PRIVS, 1, 0, 0, 0) != 0) { + throw new UnsupportedOperationException("prctl(PR_SET_NO_NEW_PRIVS): " + libc.strerror(libc.errno())); + } + + // check it worked + if (linuxLibc.prctl(PR_GET_NO_NEW_PRIVS, 0, 0, 0, 0) != 1) { + throw new UnsupportedOperationException( + "seccomp filter did not really succeed: prctl(PR_GET_NO_NEW_PRIVS): " + libc.strerror(libc.errno()) + ); + } + + // BPF installed to check arch, limit, then syscall. + // See https://www.kernel.org/doc/Documentation/prctl/seccomp_filter.txt for details. + SockFilter insns[] = { + /* 1 */ BPF_STMT(BPF_LD + BPF_W + BPF_ABS, SECCOMP_DATA_ARCH_OFFSET), // + /* 2 */ BPF_JUMP(BPF_JMP + BPF_JEQ + BPF_K, arch.audit, 0, 7), // if (arch != audit) goto fail; + /* 3 */ BPF_STMT(BPF_LD + BPF_W + BPF_ABS, SECCOMP_DATA_NR_OFFSET), // + /* 4 */ BPF_JUMP(BPF_JMP + BPF_JGT + BPF_K, arch.limit, 5, 0), // if (syscall > LIMIT) goto fail; + /* 5 */ BPF_JUMP(BPF_JMP + BPF_JEQ + BPF_K, arch.fork, 4, 0), // if (syscall == FORK) goto fail; + /* 6 */ BPF_JUMP(BPF_JMP + BPF_JEQ + BPF_K, arch.vfork, 3, 0), // if (syscall == VFORK) goto fail; + /* 7 */ BPF_JUMP(BPF_JMP + BPF_JEQ + BPF_K, arch.execve, 2, 0), // if (syscall == EXECVE) goto fail; + /* 8 */ BPF_JUMP(BPF_JMP + BPF_JEQ + BPF_K, arch.execveat, 1, 0), // if (syscall == EXECVEAT) goto fail; + /* 9 */ BPF_STMT(BPF_RET + BPF_K, SECCOMP_RET_ALLOW), // pass: return OK; + /* 10 */ BPF_STMT(BPF_RET + BPF_K, SECCOMP_RET_ERRNO | (EACCES & SECCOMP_RET_DATA)), // fail: return EACCES; + }; + // seccomp takes a long, so we pass it one explicitly to keep the JNA simple + SockFProg prog = linuxLibc.newSockFProg(insns); + + int method = 1; + // install filter, if this works, after this there is no going back! + // first try it with seccomp(SECCOMP_SET_MODE_FILTER), falling back to prctl() + if (linuxLibc.syscall(arch.seccomp, SECCOMP_SET_MODE_FILTER, SECCOMP_FILTER_FLAG_TSYNC, prog.address()) != 0) { + method = 0; + int errno1 = libc.errno(); + if (logger.isDebugEnabled()) { + logger.debug("seccomp(SECCOMP_SET_MODE_FILTER): {}, falling back to prctl(PR_SET_SECCOMP)...", libc.strerror(errno1)); + } + if (linuxLibc.prctl(PR_SET_SECCOMP, SECCOMP_MODE_FILTER, prog.address(), 0, 0) != 0) { + int errno2 = libc.errno(); + throw new UnsupportedOperationException( + "seccomp(SECCOMP_SET_MODE_FILTER): " + libc.strerror(errno1) + ", prctl(PR_SET_SECCOMP): " + libc.strerror(errno2) + ); + } + } + + // now check that the filter was really installed, we should be in filter mode. + if (linuxLibc.prctl(PR_GET_SECCOMP, 0, 0, 0, 0) != 2) { + throw new UnsupportedOperationException( + "seccomp filter installation did not really succeed. seccomp(PR_GET_SECCOMP): " + libc.strerror(libc.errno()) + ); + } + + logger.debug("Linux seccomp filter installation successful, threads: [{}]", method == 1 ? "all" : "app"); + execSandboxState = method == 1 ? ExecSandboxState.ALL_THREADS : ExecSandboxState.EXISTING_THREADS; + } } diff --git a/libs/native/src/main/java/org/elasticsearch/nativeaccess/MacNativeAccess.java b/libs/native/src/main/java/org/elasticsearch/nativeaccess/MacNativeAccess.java index 0388c66d3962f..c53b7ba6ac2f0 100644 --- a/libs/native/src/main/java/org/elasticsearch/nativeaccess/MacNativeAccess.java +++ b/libs/native/src/main/java/org/elasticsearch/nativeaccess/MacNativeAccess.java @@ -8,12 +8,30 @@ package org.elasticsearch.nativeaccess; +import org.elasticsearch.core.IOUtils; +import org.elasticsearch.core.SuppressForbidden; +import org.elasticsearch.nativeaccess.lib.MacCLibrary; import org.elasticsearch.nativeaccess.lib.NativeLibraryProvider; +import org.elasticsearch.nativeaccess.lib.PosixCLibrary.RLimit; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.Collections; class MacNativeAccess extends PosixNativeAccess { + /** The only supported flag... */ + static final int SANDBOX_NAMED = 1; + /** Allow everything except process fork and execution */ + static final String SANDBOX_RULES = "(version 1) (allow default) (deny process-fork) (deny process-exec)"; + + private final MacCLibrary macLibc; + MacNativeAccess(NativeLibraryProvider libraryProvider) { super("MacOS", libraryProvider, new PosixConstants(9223372036854775807L, 5, 1, 6)); + this.macLibc = libraryProvider.getLibrary(MacCLibrary.class); } @Override @@ -25,4 +43,69 @@ protected long getMaxThreads() { protected void logMemoryLimitInstructions() { // we don't have instructions for macos } + + /** + * Installs exec system call filtering on MacOS. + *

+ * Two different methods of filtering are used. Since MacOS is BSD based, process creation + * is first restricted with {@code setrlimit(RLIMIT_NPROC)}. + *

+ * Additionally, on Mac OS X Leopard or above, a custom {@code sandbox(7)} ("Seatbelt") profile is installed that + * denies the following rules: + *

    + *
  • {@code process-fork}
  • + *
  • {@code process-exec}
  • + *
+ * @see + * * https://reverse.put.as/wp-content/uploads/2011/06/The-Apple-Sandbox-BHDC2011-Paper.pdf + */ + @Override + public void tryInstallExecSandbox() { + initBsdSandbox(); + initMacSandbox(); + execSandboxState = ExecSandboxState.ALL_THREADS; + } + + @SuppressForbidden(reason = "Java tmp dir is ok") + private static Path createTempRulesFile() throws IOException { + return Files.createTempFile("es", "sb"); + } + + private void initMacSandbox() { + // write rules to a temporary file, which will be passed to sandbox_init() + Path rules; + try { + rules = createTempRulesFile(); + Files.write(rules, Collections.singleton(SANDBOX_RULES)); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + + try { + var errorRef = macLibc.newErrorReference(); + int ret = macLibc.sandbox_init(rules.toAbsolutePath().toString(), SANDBOX_NAMED, errorRef); + // if sandbox_init() fails, add the message from the OS (e.g. syntax error) and free the buffer + if (ret != 0) { + RuntimeException e = new UnsupportedOperationException("sandbox_init(): " + errorRef.toString()); + macLibc.sandbox_free_error(errorRef); + throw e; + } + logger.debug("OS X seatbelt initialization successful"); + } finally { + IOUtils.deleteFilesIgnoringExceptions(rules); + } + } + + private void initBsdSandbox() { + RLimit limit = libc.newRLimit(); + limit.rlim_cur(0); + limit.rlim_max(0); + // not a standard limit, means something different on linux, etc! + final int RLIMIT_NPROC = 7; + if (libc.setrlimit(RLIMIT_NPROC, limit) != 0) { + throw new UnsupportedOperationException("RLIMIT_NPROC unavailable: " + libc.strerror(libc.errno())); + } + + logger.debug("BSD RLIMIT_NPROC initialization successful"); + } } diff --git a/libs/native/src/main/java/org/elasticsearch/nativeaccess/NativeAccess.java b/libs/native/src/main/java/org/elasticsearch/nativeaccess/NativeAccess.java index 7f91d0425af47..61935ac93c5a3 100644 --- a/libs/native/src/main/java/org/elasticsearch/nativeaccess/NativeAccess.java +++ b/libs/native/src/main/java/org/elasticsearch/nativeaccess/NativeAccess.java @@ -44,6 +44,16 @@ static NativeAccess instance() { */ boolean isMemoryLocked(); + /** + * Attempts to install a system call filter to block process execution. + */ + void tryInstallExecSandbox(); + + /** + * Return whether installing the exec system call filters was successful, and to what degree. + */ + ExecSandboxState getExecSandboxState(); + Systemd systemd(); /** @@ -71,4 +81,16 @@ default WindowsFunctions getWindowsFunctions() { * @return the buffer */ CloseableByteBuffer newBuffer(int len); + + /** + * Possible stats for execution filtering. + */ + enum ExecSandboxState { + /** No execution filtering */ + NONE, + /** Exec is blocked for threads that were already created */ + EXISTING_THREADS, + /** Exec is blocked for all current and future threads */ + ALL_THREADS + } } diff --git a/libs/native/src/main/java/org/elasticsearch/nativeaccess/NoopNativeAccess.java b/libs/native/src/main/java/org/elasticsearch/nativeaccess/NoopNativeAccess.java index c0eed4a9ce09b..fc186cb03b0d9 100644 --- a/libs/native/src/main/java/org/elasticsearch/nativeaccess/NoopNativeAccess.java +++ b/libs/native/src/main/java/org/elasticsearch/nativeaccess/NoopNativeAccess.java @@ -41,6 +41,16 @@ public boolean isMemoryLocked() { return false; } + @Override + public void tryInstallExecSandbox() { + logger.warn("Cannot install system call filter because native access is not available"); + } + + @Override + public ExecSandboxState getExecSandboxState() { + return ExecSandboxState.NONE; + } + @Override public Systemd systemd() { logger.warn("Cannot get systemd access because native access is not available"); diff --git a/libs/native/src/main/java/org/elasticsearch/nativeaccess/WindowsNativeAccess.java b/libs/native/src/main/java/org/elasticsearch/nativeaccess/WindowsNativeAccess.java index 843cc73fbed02..a9ccd15330595 100644 --- a/libs/native/src/main/java/org/elasticsearch/nativeaccess/WindowsNativeAccess.java +++ b/libs/native/src/main/java/org/elasticsearch/nativeaccess/WindowsNativeAccess.java @@ -27,6 +27,16 @@ class WindowsNativeAccess extends AbstractNativeAccess { public static final int PAGE_GUARD = 0x0100; public static final int MEM_COMMIT = 0x1000; + /** + * Constant for JOBOBJECT_BASIC_LIMIT_INFORMATION in Query/Set InformationJobObject + */ + private static final int JOBOBJECT_BASIC_LIMIT_INFORMATION_CLASS = 2; + + /** + * Constant for LimitFlags, indicating a process limit has been set + */ + private static final int JOB_OBJECT_LIMIT_ACTIVE_PROCESS = 8; + private final Kernel32Library kernel; private final WindowsFunctions windowsFunctions; @@ -68,6 +78,47 @@ public void tryLockMemory() { // note: no need to close the process handle because GetCurrentProcess returns a pseudo handle } + /** + * Install exec system call filtering on Windows. + *

+ * Process creation is restricted with {@code SetInformationJobObject/ActiveProcessLimit}. + *

+ * Note: This is not intended as a real sandbox. It is another level of security, mostly intended to annoy + * security researchers and make their lives more difficult in achieving "remote execution" exploits. + */ + @Override + public void tryInstallExecSandbox() { + // create a new Job + Handle job = kernel.CreateJobObjectW(); + if (job == null) { + throw new UnsupportedOperationException("CreateJobObject: " + kernel.GetLastError()); + } + + try { + // retrieve the current basic limits of the job + int clazz = JOBOBJECT_BASIC_LIMIT_INFORMATION_CLASS; + var info = kernel.newJobObjectBasicLimitInformation(); + if (kernel.QueryInformationJobObject(job, clazz, info) == false) { + throw new UnsupportedOperationException("QueryInformationJobObject: " + kernel.GetLastError()); + } + // modify the number of active processes to be 1 (exactly the one process we will add to the job). + info.setActiveProcessLimit(1); + info.setLimitFlags(JOB_OBJECT_LIMIT_ACTIVE_PROCESS); + if (kernel.SetInformationJobObject(job, clazz, info) == false) { + throw new UnsupportedOperationException("SetInformationJobObject: " + kernel.GetLastError()); + } + // assign ourselves to the job + if (kernel.AssignProcessToJobObject(job, kernel.GetCurrentProcess()) == false) { + throw new UnsupportedOperationException("AssignProcessToJobObject: " + kernel.GetLastError()); + } + } finally { + kernel.CloseHandle(job); + } + + execSandboxState = ExecSandboxState.ALL_THREADS; + logger.debug("Windows ActiveProcessLimit initialization successful"); + } + @Override public ProcessLimits getProcessLimits() { return new ProcessLimits(ProcessLimits.UNKNOWN, ProcessLimits.UNKNOWN, ProcessLimits.UNKNOWN); diff --git a/libs/native/src/main/java/org/elasticsearch/nativeaccess/lib/Kernel32Library.java b/libs/native/src/main/java/org/elasticsearch/nativeaccess/lib/Kernel32Library.java index 43337f4532bed..dd786b56087e2 100644 --- a/libs/native/src/main/java/org/elasticsearch/nativeaccess/lib/Kernel32Library.java +++ b/libs/native/src/main/java/org/elasticsearch/nativeaccess/lib/Kernel32Library.java @@ -101,4 +101,65 @@ interface MemoryBasicInformation { * @see SetConsoleCtrlHandler docs */ boolean SetConsoleCtrlHandler(ConsoleCtrlHandler handler, boolean add); + + /** + * Creates or opens a new job object + * + * https://msdn.microsoft.com/en-us/library/windows/desktop/ms682409%28v=vs.85%29.aspx + * Note: the two params to this are omitted because all implementations pass null for them both + * + * @return job handle if the function succeeds + */ + Handle CreateJobObjectW(); + + /** + * Associates a process with an existing job + * + * https://msdn.microsoft.com/en-us/library/windows/desktop/ms681949%28v=vs.85%29.aspx + * + * @param job job handle + * @param process process handle + * @return true if the function succeeds + */ + boolean AssignProcessToJobObject(Handle job, Handle process); + + /** + * Basic limit information for a job object + * + * https://msdn.microsoft.com/en-us/library/windows/desktop/ms684147%28v=vs.85%29.aspx + */ + interface JobObjectBasicLimitInformation { + void setLimitFlags(int v); + + void setActiveProcessLimit(int v); + } + + JobObjectBasicLimitInformation newJobObjectBasicLimitInformation(); + + /** + * Get job limit and state information + * + * https://msdn.microsoft.com/en-us/library/windows/desktop/ms684925%28v=vs.85%29.aspx + * Note: The infoLength parameter is omitted because implementions handle passing it + * Note: The returnLength parameter is omitted because all implementations pass null + * + * @param job job handle + * @param infoClass information class constant + * @param info pointer to information structure + * @return true if the function succeeds + */ + boolean QueryInformationJobObject(Handle job, int infoClass, JobObjectBasicLimitInformation info); + + /** + * Set job limit and state information + * + * https://msdn.microsoft.com/en-us/library/windows/desktop/ms686216%28v=vs.85%29.aspx + * Note: The infoLength parameter is omitted because implementions handle passing it + * + * @param job job handle + * @param infoClass information class constant + * @param info pointer to information structure + * @return true if the function succeeds + */ + boolean SetInformationJobObject(Handle job, int infoClass, JobObjectBasicLimitInformation info); } diff --git a/libs/native/src/main/java/org/elasticsearch/nativeaccess/lib/LinuxCLibrary.java b/libs/native/src/main/java/org/elasticsearch/nativeaccess/lib/LinuxCLibrary.java new file mode 100644 index 0000000000000..2a7b10ff3588f --- /dev/null +++ b/libs/native/src/main/java/org/elasticsearch/nativeaccess/lib/LinuxCLibrary.java @@ -0,0 +1,38 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.nativeaccess.lib; + +public non-sealed interface LinuxCLibrary extends NativeLibrary { + + /** + * Corresponds to struct sock_filter + * @param code insn + * @param jt number of insn to jump (skip) if true + * @param jf number of insn to jump (skip) if false + * @param k additional data + */ + record SockFilter(short code, byte jt, byte jf, int k) {} + + interface SockFProg { + long address(); + } + + SockFProg newSockFProg(SockFilter filters[]); + + /** + * maps to prctl(2) + */ + int prctl(int option, long arg2, long arg3, long arg4, long arg5); + + /** + * used to call seccomp(2), its too new... + * this is the only way, DON'T use it on some other architecture unless you know wtf you are doing + */ + long syscall(long number, int operation, int flags, long address); +} diff --git a/libs/native/src/main/java/org/elasticsearch/nativeaccess/lib/MacCLibrary.java b/libs/native/src/main/java/org/elasticsearch/nativeaccess/lib/MacCLibrary.java new file mode 100644 index 0000000000000..b2b2db9c71c90 --- /dev/null +++ b/libs/native/src/main/java/org/elasticsearch/nativeaccess/lib/MacCLibrary.java @@ -0,0 +1,25 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.nativeaccess.lib; + +public non-sealed interface MacCLibrary extends NativeLibrary { + interface ErrorReference {} + + ErrorReference newErrorReference(); + + /** + * maps to sandbox_init(3), since Leopard + */ + int sandbox_init(String profile, long flags, ErrorReference errorbuf); + + /** + * releases memory when an error occurs during initialization (e.g. syntax bug) + */ + void sandbox_free_error(ErrorReference errorbuf); +} diff --git a/libs/native/src/main/java/org/elasticsearch/nativeaccess/lib/NativeLibrary.java b/libs/native/src/main/java/org/elasticsearch/nativeaccess/lib/NativeLibrary.java index d8098a78935b8..faa0e861dc63f 100644 --- a/libs/native/src/main/java/org/elasticsearch/nativeaccess/lib/NativeLibrary.java +++ b/libs/native/src/main/java/org/elasticsearch/nativeaccess/lib/NativeLibrary.java @@ -9,4 +9,5 @@ package org.elasticsearch.nativeaccess.lib; /** A marker interface for libraries that can be loaded by {@link org.elasticsearch.nativeaccess.lib.NativeLibraryProvider} */ -public sealed interface NativeLibrary permits JavaLibrary, PosixCLibrary, Kernel32Library, SystemdLibrary, VectorLibrary, ZstdLibrary {} +public sealed interface NativeLibrary permits JavaLibrary, PosixCLibrary, LinuxCLibrary, MacCLibrary, Kernel32Library, SystemdLibrary, + VectorLibrary, ZstdLibrary {} diff --git a/libs/native/src/main/java/org/elasticsearch/nativeaccess/lib/PosixCLibrary.java b/libs/native/src/main/java/org/elasticsearch/nativeaccess/lib/PosixCLibrary.java index 96e2a0d0e1cdf..d8db5fa070126 100644 --- a/libs/native/src/main/java/org/elasticsearch/nativeaccess/lib/PosixCLibrary.java +++ b/libs/native/src/main/java/org/elasticsearch/nativeaccess/lib/PosixCLibrary.java @@ -26,6 +26,10 @@ interface RLimit { long rlim_cur(); long rlim_max(); + + void rlim_cur(long v); + + void rlim_max(long v); } /** @@ -41,6 +45,8 @@ interface RLimit { */ int getrlimit(int resource, RLimit rlimit); + int setrlimit(int resource, RLimit rlimit); + /** * Lock all the current process's virtual address space into RAM. * @param flags flags determining how memory will be locked @@ -49,6 +55,22 @@ interface RLimit { */ int mlockall(int flags); + interface FStore { + void set_flags(int flags); /* IN: flags word */ + + void set_posmode(int posmode); /* IN: indicates offset field */ + + void set_offset(long offset); /* IN: start of the region */ + + void set_length(long length); /* IN: size of the region */ + + long bytesalloc(); /* OUT: number of bytes allocated */ + } + + FStore newFStore(); + + int fcntl(int fd, int cmd, FStore fst); + /** * Return a string description for an error. * diff --git a/libs/native/src/main21/java/org/elasticsearch/nativeaccess/jdk/JdkKernel32Library.java b/libs/native/src/main21/java/org/elasticsearch/nativeaccess/jdk/JdkKernel32Library.java index bbfd26bd061d0..f5eb5238dad93 100644 --- a/libs/native/src/main21/java/org/elasticsearch/nativeaccess/jdk/JdkKernel32Library.java +++ b/libs/native/src/main21/java/org/elasticsearch/nativeaccess/jdk/JdkKernel32Library.java @@ -72,6 +72,22 @@ class JdkKernel32Library implements Kernel32Library { "handle", ConsoleCtrlHandler_handle$fd ); + private static final MethodHandle CreateJobObjectW$mh = downcallHandleWithError( + "CreateJobObjectW", + FunctionDescriptor.of(ADDRESS, ADDRESS, ADDRESS) + ); + private static final MethodHandle AssignProcessToJobObject$mh = downcallHandleWithError( + "AssignProcessToJobObject", + FunctionDescriptor.of(JAVA_BOOLEAN, ADDRESS, ADDRESS) + ); + private static final MethodHandle QueryInformationJobObject$mh = downcallHandleWithError( + "QueryInformationJobObject", + FunctionDescriptor.of(JAVA_BOOLEAN, ADDRESS, JAVA_INT, ADDRESS, JAVA_INT, ADDRESS) + ); + private static final MethodHandle SetInformationJobObject$mh = downcallHandleWithError( + "SetInformationJobObject", + FunctionDescriptor.of(JAVA_BOOLEAN, ADDRESS, JAVA_INT, ADDRESS, JAVA_INT) + ); private static MethodHandle downcallHandleWithError(String function, FunctionDescriptor functionDescriptor) { return downcallHandle(function, functionDescriptor, CAPTURE_GETLASTERROR_OPTION); @@ -146,6 +162,37 @@ public long Type() { } } + static class JdkJobObjectBasicLimitInformation implements JobObjectBasicLimitInformation { + private static final MemoryLayout layout = MemoryLayout.structLayout( + paddingLayout(16), + JAVA_INT, + paddingLayout(20), + JAVA_INT, + paddingLayout(20) + ).withByteAlignment(8); + + private static final VarHandle LimitFlags$vh = varHandleWithoutOffset(layout, groupElement(1)); + private static final VarHandle ActiveProcessLimit$vh = varHandleWithoutOffset(layout, groupElement(3)); + + private final MemorySegment segment; + + JdkJobObjectBasicLimitInformation() { + var arena = Arena.ofAuto(); + this.segment = arena.allocate(layout); + segment.fill((byte) 0); + } + + @Override + public void setLimitFlags(int v) { + LimitFlags$vh.set(segment, v); + } + + @Override + public void setActiveProcessLimit(int v) { + ActiveProcessLimit$vh.set(segment, v); + } + } + private final MemorySegment lastErrorState; JdkKernel32Library() { @@ -262,4 +309,73 @@ public boolean SetConsoleCtrlHandler(ConsoleCtrlHandler handler, boolean add) { throw new AssertionError(t); } } + + @Override + public Handle CreateJobObjectW() { + try { + return new JdkHandle((MemorySegment) CreateJobObjectW$mh.invokeExact(lastErrorState, MemorySegment.NULL, MemorySegment.NULL)); + } catch (Throwable t) { + throw new AssertionError(t); + } + } + + @Override + public boolean AssignProcessToJobObject(Handle job, Handle process) { + assert job instanceof JdkHandle; + assert process instanceof JdkHandle; + var jdkJob = (JdkHandle) job; + var jdkProcess = (JdkHandle) process; + + try { + return (boolean) AssignProcessToJobObject$mh.invokeExact(lastErrorState, jdkJob.address, jdkProcess.address); + } catch (Throwable t) { + throw new AssertionError(t); + } + } + + @Override + public JobObjectBasicLimitInformation newJobObjectBasicLimitInformation() { + return new JdkJobObjectBasicLimitInformation(); + } + + @Override + public boolean QueryInformationJobObject(Handle job, int infoClass, JobObjectBasicLimitInformation info) { + assert job instanceof JdkHandle; + assert info instanceof JdkJobObjectBasicLimitInformation; + var jdkJob = (JdkHandle) job; + var jdkInfo = (JdkJobObjectBasicLimitInformation) info; + + try { + return (boolean) QueryInformationJobObject$mh.invokeExact( + lastErrorState, + jdkJob.address, + infoClass, + jdkInfo.segment, + (int) jdkInfo.segment.byteSize(), + MemorySegment.NULL + ); + } catch (Throwable t) { + throw new AssertionError(t); + } + } + + @Override + public boolean SetInformationJobObject(Handle job, int infoClass, JobObjectBasicLimitInformation info) { + assert job instanceof JdkHandle; + assert info instanceof JdkJobObjectBasicLimitInformation; + var jdkJob = (JdkHandle) job; + var jdkInfo = (JdkJobObjectBasicLimitInformation) info; + + try { + return (boolean) SetInformationJobObject$mh.invokeExact( + lastErrorState, + jdkJob.address, + infoClass, + jdkInfo.segment, + (int) jdkInfo.segment.byteSize() + ); + } catch (Throwable t) { + throw new AssertionError(t); + } + } } diff --git a/libs/native/src/main21/java/org/elasticsearch/nativeaccess/jdk/JdkLinuxCLibrary.java b/libs/native/src/main21/java/org/elasticsearch/nativeaccess/jdk/JdkLinuxCLibrary.java new file mode 100644 index 0000000000000..700941e7e1db0 --- /dev/null +++ b/libs/native/src/main21/java/org/elasticsearch/nativeaccess/jdk/JdkLinuxCLibrary.java @@ -0,0 +1,103 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.nativeaccess.jdk; + +import org.elasticsearch.nativeaccess.lib.LinuxCLibrary; + +import java.lang.foreign.Arena; +import java.lang.foreign.FunctionDescriptor; +import java.lang.foreign.Linker; +import java.lang.foreign.MemoryLayout; +import java.lang.foreign.MemorySegment; +import java.lang.invoke.MethodHandle; + +import static java.lang.foreign.MemoryLayout.paddingLayout; +import static java.lang.foreign.ValueLayout.ADDRESS; +import static java.lang.foreign.ValueLayout.JAVA_BYTE; +import static java.lang.foreign.ValueLayout.JAVA_INT; +import static java.lang.foreign.ValueLayout.JAVA_LONG; +import static java.lang.foreign.ValueLayout.JAVA_SHORT; +import static org.elasticsearch.nativeaccess.jdk.JdkPosixCLibrary.CAPTURE_ERRNO_OPTION; +import static org.elasticsearch.nativeaccess.jdk.JdkPosixCLibrary.downcallHandleWithErrno; +import static org.elasticsearch.nativeaccess.jdk.JdkPosixCLibrary.errnoState; +import static org.elasticsearch.nativeaccess.jdk.LinkerHelper.downcallHandle; + +class JdkLinuxCLibrary implements LinuxCLibrary { + private static final MethodHandle prctl$mh; + static { + try { + prctl$mh = downcallHandleWithErrno( + "prctl", + FunctionDescriptor.of(JAVA_INT, JAVA_INT, JAVA_LONG, JAVA_LONG, JAVA_LONG, JAVA_LONG) + ); + } catch (UnsatisfiedLinkError e) { + throw new UnsupportedOperationException( + "seccomp unavailable: could not link methods. requires kernel 3.5+ " + + "with CONFIG_SECCOMP and CONFIG_SECCOMP_FILTER compiled in" + ); + } + } + private static final MethodHandle syscall$mh = downcallHandle( + "syscall", + FunctionDescriptor.of(JAVA_LONG, JAVA_LONG, JAVA_INT, JAVA_INT, JAVA_LONG), + CAPTURE_ERRNO_OPTION, + Linker.Option.firstVariadicArg(1) + ); + + private static class JdkSockFProg implements SockFProg { + private static final MemoryLayout layout = MemoryLayout.structLayout(JAVA_SHORT, paddingLayout(6), ADDRESS); + + private final MemorySegment segment; + + JdkSockFProg(SockFilter filters[]) { + Arena arena = Arena.ofAuto(); + this.segment = arena.allocate(layout); + var instSegment = arena.allocate(filters.length * 8L); + segment.set(JAVA_SHORT, 0, (short) filters.length); + segment.set(ADDRESS, 8, instSegment); + + int offset = 0; + for (SockFilter f : filters) { + instSegment.set(JAVA_SHORT, offset, f.code()); + instSegment.set(JAVA_BYTE, offset + 2, f.jt()); + instSegment.set(JAVA_BYTE, offset + 3, f.jf()); + instSegment.set(JAVA_INT, offset + 4, f.k()); + offset += 8; + } + } + + @Override + public long address() { + return segment.address(); + } + } + + @Override + public SockFProg newSockFProg(SockFilter[] filters) { + return new JdkSockFProg(filters); + } + + @Override + public int prctl(int option, long arg2, long arg3, long arg4, long arg5) { + try { + return (int) prctl$mh.invokeExact(errnoState, option, arg2, arg3, arg4, arg5); + } catch (Throwable t) { + throw new AssertionError(t); + } + } + + @Override + public long syscall(long number, int operation, int flags, long address) { + try { + return (long) syscall$mh.invokeExact(errnoState, number, operation, flags, address); + } catch (Throwable t) { + throw new AssertionError(t); + } + } +} diff --git a/libs/native/src/main21/java/org/elasticsearch/nativeaccess/jdk/JdkMacCLibrary.java b/libs/native/src/main21/java/org/elasticsearch/nativeaccess/jdk/JdkMacCLibrary.java new file mode 100644 index 0000000000000..b946ca3ca4353 --- /dev/null +++ b/libs/native/src/main21/java/org/elasticsearch/nativeaccess/jdk/JdkMacCLibrary.java @@ -0,0 +1,73 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.nativeaccess.jdk; + +import org.elasticsearch.nativeaccess.lib.MacCLibrary; + +import java.lang.foreign.Arena; +import java.lang.foreign.FunctionDescriptor; +import java.lang.foreign.MemorySegment; +import java.lang.foreign.ValueLayout; +import java.lang.invoke.MethodHandle; + +import static java.lang.foreign.ValueLayout.ADDRESS; +import static java.lang.foreign.ValueLayout.JAVA_INT; +import static java.lang.foreign.ValueLayout.JAVA_LONG; +import static org.elasticsearch.nativeaccess.jdk.LinkerHelper.downcallHandle; + +class JdkMacCLibrary implements MacCLibrary { + + private static final MethodHandle sandbox_init$mh = downcallHandle( + "sandbox_init", + FunctionDescriptor.of(JAVA_INT, ADDRESS, JAVA_LONG, ADDRESS) + ); + private static final MethodHandle sandbox_free_error$mh = downcallHandle("sandbox_free_error", FunctionDescriptor.ofVoid(ADDRESS)); + + private static class JdkErrorReference implements ErrorReference { + final Arena arena = Arena.ofConfined(); + final MemorySegment segment = arena.allocate(ValueLayout.ADDRESS); + + MemorySegment deref() { + return segment.get(ADDRESS, 0); + } + + @Override + public String toString() { + return deref().reinterpret(Long.MAX_VALUE).getUtf8String(0); + } + } + + @Override + public ErrorReference newErrorReference() { + return new JdkErrorReference(); + } + + @Override + public int sandbox_init(String profile, long flags, ErrorReference errorbuf) { + assert errorbuf instanceof JdkErrorReference; + var jdkErrorbuf = (JdkErrorReference) errorbuf; + try (Arena arena = Arena.ofConfined()) { + MemorySegment nativeProfile = MemorySegmentUtil.allocateString(arena, profile); + return (int) sandbox_init$mh.invokeExact(nativeProfile, flags, jdkErrorbuf.segment); + } catch (Throwable t) { + throw new AssertionError(t); + } + } + + @Override + public void sandbox_free_error(ErrorReference errorbuf) { + assert errorbuf instanceof JdkErrorReference; + var jdkErrorbuf = (JdkErrorReference) errorbuf; + try { + sandbox_free_error$mh.invokeExact(jdkErrorbuf.deref()); + } catch (Throwable t) { + throw new AssertionError(t); + } + } +} diff --git a/libs/native/src/main21/java/org/elasticsearch/nativeaccess/jdk/JdkNativeLibraryProvider.java b/libs/native/src/main21/java/org/elasticsearch/nativeaccess/jdk/JdkNativeLibraryProvider.java index d76170a55284c..cbd43a394379b 100644 --- a/libs/native/src/main21/java/org/elasticsearch/nativeaccess/jdk/JdkNativeLibraryProvider.java +++ b/libs/native/src/main21/java/org/elasticsearch/nativeaccess/jdk/JdkNativeLibraryProvider.java @@ -10,6 +10,8 @@ import org.elasticsearch.nativeaccess.lib.JavaLibrary; import org.elasticsearch.nativeaccess.lib.Kernel32Library; +import org.elasticsearch.nativeaccess.lib.LinuxCLibrary; +import org.elasticsearch.nativeaccess.lib.MacCLibrary; import org.elasticsearch.nativeaccess.lib.NativeLibraryProvider; import org.elasticsearch.nativeaccess.lib.PosixCLibrary; import org.elasticsearch.nativeaccess.lib.SystemdLibrary; @@ -28,6 +30,10 @@ public JdkNativeLibraryProvider() { JdkJavaLibrary::new, PosixCLibrary.class, JdkPosixCLibrary::new, + LinuxCLibrary.class, + JdkLinuxCLibrary::new, + MacCLibrary.class, + JdkMacCLibrary::new, Kernel32Library.class, JdkKernel32Library::new, SystemdLibrary.class, diff --git a/libs/native/src/main21/java/org/elasticsearch/nativeaccess/jdk/JdkPosixCLibrary.java b/libs/native/src/main21/java/org/elasticsearch/nativeaccess/jdk/JdkPosixCLibrary.java index 43ec9425ccfaa..1a65225873c1d 100644 --- a/libs/native/src/main21/java/org/elasticsearch/nativeaccess/jdk/JdkPosixCLibrary.java +++ b/libs/native/src/main21/java/org/elasticsearch/nativeaccess/jdk/JdkPosixCLibrary.java @@ -43,7 +43,12 @@ class JdkPosixCLibrary implements PosixCLibrary { "getrlimit", FunctionDescriptor.of(JAVA_INT, JAVA_INT, ADDRESS) ); + private static final MethodHandle setrlimit$mh = downcallHandleWithErrno( + "setrlimit", + FunctionDescriptor.of(JAVA_INT, JAVA_INT, ADDRESS) + ); private static final MethodHandle mlockall$mh = downcallHandleWithErrno("mlockall", FunctionDescriptor.of(JAVA_INT, JAVA_INT)); + private static final MethodHandle fcntl$mh = downcallHandle("fcntl", FunctionDescriptor.of(JAVA_INT, JAVA_INT, JAVA_INT, ADDRESS)); static final MemorySegment errnoState = Arena.ofAuto().allocate(CAPTURE_ERRNO_LAYOUT); @@ -91,6 +96,17 @@ public int getrlimit(int resource, RLimit rlimit) { } } + @Override + public int setrlimit(int resource, RLimit rlimit) { + assert rlimit instanceof JdkRLimit; + var jdkRlimit = (JdkRLimit) rlimit; + try { + return (int) setrlimit$mh.invokeExact(errnoState, resource, jdkRlimit.segment); + } catch (Throwable t) { + throw new AssertionError(t); + } + } + @Override public int mlockall(int flags) { try { @@ -100,6 +116,22 @@ public int mlockall(int flags) { } } + @Override + public FStore newFStore() { + return new JdkFStore(); + } + + @Override + public int fcntl(int fd, int cmd, FStore fst) { + assert fst instanceof JdkFStore; + var jdkFst = (JdkFStore) fst; + try { + return (int) fcntl$mh.invokeExact(errnoState, fd, cmd, jdkFst.segment); + } catch (Throwable t) { + throw new AssertionError(t); + } + } + static class JdkRLimit implements RLimit { private static final MemoryLayout layout = MemoryLayout.structLayout(JAVA_LONG, JAVA_LONG); private static final VarHandle rlim_cur$vh = varHandleWithoutOffset(layout, groupElement(0)); @@ -122,9 +154,60 @@ public long rlim_max() { return (long) rlim_max$vh.get(segment); } + @Override + public void rlim_cur(long v) { + rlim_cur$vh.set(segment, v); + } + + @Override + public void rlim_max(long v) { + rlim_max$vh.set(segment, v); + } + @Override public String toString() { return "JdkRLimit[rlim_cur=" + rlim_cur() + ", rlim_max=" + rlim_max(); } } + + private static class JdkFStore implements FStore { + private static final MemoryLayout layout = MemoryLayout.structLayout(JAVA_INT, JAVA_INT, JAVA_LONG, JAVA_LONG, JAVA_LONG); + private static final VarHandle st_flags$vh = layout.varHandle(groupElement(0)); + private static final VarHandle st_posmode$vh = layout.varHandle(groupElement(1)); + private static final VarHandle st_offset$vh = layout.varHandle(groupElement(2)); + private static final VarHandle st_length$vh = layout.varHandle(groupElement(3)); + private static final VarHandle st_bytesalloc$vh = layout.varHandle(groupElement(4)); + + private final MemorySegment segment; + + JdkFStore() { + var arena = Arena.ofAuto(); + this.segment = arena.allocate(layout); + } + + @Override + public void set_flags(int flags) { + st_flags$vh.set(segment, flags); + } + + @Override + public void set_posmode(int posmode) { + st_posmode$vh.set(segment, posmode); + } + + @Override + public void set_offset(long offset) { + st_offset$vh.get(segment, offset); + } + + @Override + public void set_length(long length) { + st_length$vh.set(segment, length); + } + + @Override + public long bytesalloc() { + return (long) st_bytesalloc$vh.get(segment); + } + } } diff --git a/qa/evil-tests/src/test/java/org/elasticsearch/bootstrap/SystemCallFilterTests.java b/libs/native/src/test/java/org/elasticsearch/nativeaccess/SystemCallFilterTests.java similarity index 84% rename from qa/evil-tests/src/test/java/org/elasticsearch/bootstrap/SystemCallFilterTests.java rename to libs/native/src/test/java/org/elasticsearch/nativeaccess/SystemCallFilterTests.java index c62522880869b..d4bac13990898 100644 --- a/qa/evil-tests/src/test/java/org/elasticsearch/bootstrap/SystemCallFilterTests.java +++ b/libs/native/src/test/java/org/elasticsearch/nativeaccess/SystemCallFilterTests.java @@ -6,12 +6,16 @@ * Side Public License, v 1. */ -package org.elasticsearch.bootstrap; +package org.elasticsearch.nativeaccess; import org.apache.lucene.util.Constants; import org.elasticsearch.test.ESTestCase; +import static org.apache.lucene.tests.util.LuceneTestCase.assumeTrue; +import static org.junit.Assert.fail; + /** Simple tests system call filter is working. */ +@ESTestCase.WithoutSecurityManager public class SystemCallFilterTests extends ESTestCase { /** command to try to run in tests */ @@ -20,15 +24,18 @@ public class SystemCallFilterTests extends ESTestCase { @Override public void setUp() throws Exception { super.setUp(); - assumeTrue("requires system call filter installation", Natives.isSystemCallFilterInstalled()); + assumeTrue( + "requires system call filter installation", + NativeAccess.instance().getExecSandboxState() != NativeAccess.ExecSandboxState.NONE + ); // otherwise security manager will block the execution, no fun assumeTrue("cannot test with security manager enabled", System.getSecurityManager() == null); // otherwise, since we don't have TSYNC support, rules are not applied to the test thread // (randomizedrunner class initialization happens in its own thread, after the test thread is created) // instead we just forcefully run it for the test thread here. - if (JNANatives.LOCAL_SYSTEM_CALL_FILTER_ALL == false) { + if (NativeAccess.instance().getExecSandboxState() != NativeAccess.ExecSandboxState.ALL_THREADS) { try { - SystemCallFilter.init(createTempDir()); + NativeAccess.instance().tryInstallExecSandbox(); } catch (Exception e) { throw new RuntimeException("unable to forcefully apply system call filter to test thread", e); } diff --git a/modules/data-streams/src/javaRestTest/java/org/elasticsearch/datastreams/AbstractDataStreamIT.java b/modules/data-streams/src/javaRestTest/java/org/elasticsearch/datastreams/AbstractDataStreamIT.java index ca33f08324539..027ac7c736c8a 100644 --- a/modules/data-streams/src/javaRestTest/java/org/elasticsearch/datastreams/AbstractDataStreamIT.java +++ b/modules/data-streams/src/javaRestTest/java/org/elasticsearch/datastreams/AbstractDataStreamIT.java @@ -25,6 +25,7 @@ import java.io.IOException; import java.util.List; import java.util.Map; +import java.util.concurrent.TimeUnit; /** * This base class provides the boilerplate to simplify the development of integration tests. @@ -53,7 +54,7 @@ static void waitForIndexTemplate(RestClient client, String indexTemplate) throws } catch (ResponseException e) { fail(e.getMessage()); } - }); + }, 15, TimeUnit.SECONDS); } static void createDataStream(RestClient client, String name) throws IOException { diff --git a/modules/data-streams/src/javaRestTest/java/org/elasticsearch/datastreams/DataStreamUpgradeRestIT.java b/modules/data-streams/src/javaRestTest/java/org/elasticsearch/datastreams/DataStreamUpgradeRestIT.java index f447e5b80f8c8..39cdf77d04810 100644 --- a/modules/data-streams/src/javaRestTest/java/org/elasticsearch/datastreams/DataStreamUpgradeRestIT.java +++ b/modules/data-streams/src/javaRestTest/java/org/elasticsearch/datastreams/DataStreamUpgradeRestIT.java @@ -22,6 +22,7 @@ import java.util.List; import java.util.Map; import java.util.Set; +import java.util.concurrent.TimeUnit; import java.util.stream.Collectors; import static org.elasticsearch.rest.action.search.RestSearchAction.TOTAL_HITS_AS_INT_PARAM; @@ -306,6 +307,6 @@ private void waitForLogsComponentTemplateInitialization() throws Exception { // Throw the exception, if it was an error we did not anticipate throw responseException; } - }); + }, 15, TimeUnit.SECONDS); } } diff --git a/modules/ingest-geoip/src/internalClusterTest/java/org/elasticsearch/ingest/geoip/GeoIpDownloaderIT.java b/modules/ingest-geoip/src/internalClusterTest/java/org/elasticsearch/ingest/geoip/GeoIpDownloaderIT.java index 9dcd8abc7bc57..9eab00fbadf20 100644 --- a/modules/ingest-geoip/src/internalClusterTest/java/org/elasticsearch/ingest/geoip/GeoIpDownloaderIT.java +++ b/modules/ingest-geoip/src/internalClusterTest/java/org/elasticsearch/ingest/geoip/GeoIpDownloaderIT.java @@ -242,7 +242,7 @@ public void testGeoIpDatabasesDownload() throws Exception { Set.of("GeoLite2-ASN.mmdb", "GeoLite2-City.mmdb", "GeoLite2-Country.mmdb", "MyCustomGeoLite2-City.mmdb"), state.getDatabases().keySet() ); - GeoIpTaskState.Metadata metadata = state.get(id); + GeoIpTaskState.Metadata metadata = state.getDatabases().get(id); int size = metadata.lastChunk() - metadata.firstChunk() + 1; assertResponse( prepareSearch(GeoIpDownloader.DATABASES_INDEX).setSize(size) diff --git a/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/GeoIpDownloader.java b/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/GeoIpDownloader.java index 895c9315d2325..13394a2a0c7cc 100644 --- a/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/GeoIpDownloader.java +++ b/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/GeoIpDownloader.java @@ -170,23 +170,28 @@ private List fetchDatabasesOverview() throws IOException { } // visible for testing - void processDatabase(Map databaseInfo) { + void processDatabase(final Map databaseInfo) { String name = databaseInfo.get("name").toString().replace(".tgz", "") + ".mmdb"; String md5 = (String) databaseInfo.get("md5_hash"); - if (state.contains(name) && Objects.equals(md5, state.get(name).md5())) { - updateTimestamp(name, state.get(name)); - return; - } - logger.debug("downloading geoip database [{}]", name); String url = databaseInfo.get("url").toString(); if (url.startsWith("http") == false) { // relative url, add it after last slash (i.e. resolve sibling) or at the end if there's no slash after http[s]:// int lastSlash = endpoint.substring(8).lastIndexOf('/'); url = (lastSlash != -1 ? endpoint.substring(0, lastSlash + 8) : endpoint) + "/" + url; } + processDatabase(name, md5, url); + } + + private void processDatabase(final String name, final String md5, final String url) { + Metadata metadata = state.getDatabases().getOrDefault(name, Metadata.EMPTY); + if (Objects.equals(metadata.md5(), md5)) { + updateTimestamp(name, metadata); + return; + } + logger.debug("downloading geoip database [{}]", name); long start = System.currentTimeMillis(); try (InputStream is = httpClient.get(url)) { - int firstChunk = state.contains(name) ? state.get(name).lastChunk() + 1 : 0; + int firstChunk = metadata.lastChunk() + 1; // if there is no metadata, then Metadata.EMPTY.lastChunk() + 1 = 0 int lastChunk = indexChunks(name, is, firstChunk, md5, start); if (lastChunk > firstChunk) { state = state.put(name, new Metadata(start, firstChunk, lastChunk - 1, md5, start)); @@ -313,22 +318,19 @@ public void requestReschedule() { } private void cleanDatabases() { - long expiredDatabases = state.getDatabases() + List> expiredDatabases = state.getDatabases() .entrySet() .stream() .filter(e -> e.getValue().isValid(clusterService.state().metadata().settings()) == false) - .peek(e -> { - String name = e.getKey(); - Metadata meta = e.getValue(); - deleteOldChunks(name, meta.lastChunk() + 1); - state = state.put( - name, - new Metadata(meta.lastUpdate(), meta.firstChunk(), meta.lastChunk(), meta.md5(), meta.lastCheck() - 1) - ); - updateTaskState(); - }) - .count(); - stats = stats.expiredDatabases((int) expiredDatabases); + .toList(); + expiredDatabases.forEach(e -> { + String name = e.getKey(); + Metadata meta = e.getValue(); + deleteOldChunks(name, meta.lastChunk() + 1); + state = state.put(name, new Metadata(meta.lastUpdate(), meta.firstChunk(), meta.lastChunk(), meta.md5(), meta.lastCheck() - 1)); + updateTaskState(); + }); + stats = stats.expiredDatabases(expiredDatabases.size()); } @Override diff --git a/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/GeoIpTaskState.java b/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/GeoIpTaskState.java index d55f517b46e24..a405d90b24dcc 100644 --- a/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/GeoIpTaskState.java +++ b/modules/ingest-geoip/src/main/java/org/elasticsearch/ingest/geoip/GeoIpTaskState.java @@ -84,14 +84,6 @@ public Map getDatabases() { return databases; } - public boolean contains(String name) { - return databases.containsKey(name); - } - - public Metadata get(String name) { - return databases.get(name); - } - @Override public boolean equals(Object o) { if (this == o) return true; @@ -142,7 +134,13 @@ public void writeTo(StreamOutput out) throws IOException { record Metadata(long lastUpdate, int firstChunk, int lastChunk, String md5, long lastCheck) implements ToXContentObject { - static final String NAME = GEOIP_DOWNLOADER + "-metadata"; + /** + * An empty Metadata object useful for getOrDefault -type calls. Crucially, the 'lastChunk' is -1, so it's safe to use + * with logic that says the new firstChunk is the old lastChunk + 1. + */ + static Metadata EMPTY = new Metadata(-1, -1, -1, "", -1); + + private static final String NAME = GEOIP_DOWNLOADER + "-metadata"; private static final ParseField LAST_CHECK = new ParseField("last_check"); private static final ParseField LAST_UPDATE = new ParseField("last_update"); private static final ParseField FIRST_CHUNK = new ParseField("first_chunk"); diff --git a/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/GeoIpDownloaderTests.java b/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/GeoIpDownloaderTests.java index 9cc5405c1b617..6a83fe69473f7 100644 --- a/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/GeoIpDownloaderTests.java +++ b/modules/ingest-geoip/src/test/java/org/elasticsearch/ingest/geoip/GeoIpDownloaderTests.java @@ -30,11 +30,17 @@ import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.settings.ClusterSettings; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.index.reindex.BulkByScrollResponse; +import org.elasticsearch.index.reindex.DeleteByQueryAction; +import org.elasticsearch.index.reindex.DeleteByQueryRequest; import org.elasticsearch.ingest.geoip.stats.GeoIpDownloaderStats; import org.elasticsearch.node.Node; +import org.elasticsearch.persistent.PersistentTaskResponse; import org.elasticsearch.persistent.PersistentTaskState; import org.elasticsearch.persistent.PersistentTasksCustomMetadata; import org.elasticsearch.persistent.PersistentTasksCustomMetadata.PersistentTask; +import org.elasticsearch.persistent.PersistentTasksService; +import org.elasticsearch.persistent.UpdatePersistentTaskStatusAction; import org.elasticsearch.telemetry.metric.MeterRegistry; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.client.NoOpClient; @@ -49,6 +55,9 @@ import java.io.ByteArrayOutputStream; import java.io.IOException; import java.io.InputStream; +import java.nio.charset.StandardCharsets; +import java.time.Instant; +import java.time.temporal.ChronoUnit; import java.util.HashMap; import java.util.Iterator; import java.util.List; @@ -63,6 +72,8 @@ import static org.elasticsearch.ingest.geoip.GeoIpDownloader.MAX_CHUNK_SIZE; import static org.elasticsearch.tasks.TaskId.EMPTY_TASK_ID; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.greaterThanOrEqualTo; +import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verifyNoInteractions; import static org.mockito.Mockito.when; @@ -76,8 +87,9 @@ public class GeoIpDownloaderTests extends ESTestCase { private GeoIpDownloader geoIpDownloader; @Before - public void setup() { + public void setup() throws IOException { httpClient = mock(HttpClient.class); + when(httpClient.getBytes(anyString())).thenReturn("[]".getBytes(StandardCharsets.UTF_8)); clusterService = mock(ClusterService.class); threadPool = new ThreadPool(Settings.builder().put(Node.NODE_NAME_SETTING.getKey(), "test").build(), MeterRegistry.NOOP); when(clusterService.getClusterSettings()).thenReturn( @@ -109,7 +121,13 @@ public void setup() { () -> GeoIpDownloaderTaskExecutor.POLL_INTERVAL_SETTING.getDefault(Settings.EMPTY), () -> GeoIpDownloaderTaskExecutor.EAGER_DOWNLOAD_SETTING.getDefault(Settings.EMPTY), () -> true - ); + ) { + { + GeoIpTaskParams geoIpTaskParams = mock(GeoIpTaskParams.class); + when(geoIpTaskParams.getWriteableName()).thenReturn(GeoIpDownloader.GEOIP_DOWNLOADER); + init(new PersistentTasksService(clusterService, threadPool, client), null, null, 0); + } + }; } @After @@ -290,8 +308,8 @@ int indexChunks(String name, InputStream is, int chunk, String expectedMd5, long @Override void updateTaskState() { - assertEquals(0, state.get("test.mmdb").firstChunk()); - assertEquals(10, state.get("test.mmdb").lastChunk()); + assertEquals(0, state.getDatabases().get("test.mmdb").firstChunk()); + assertEquals(10, state.getDatabases().get("test.mmdb").lastChunk()); } @Override @@ -341,8 +359,8 @@ int indexChunks(String name, InputStream is, int chunk, String expectedMd5, long @Override void updateTaskState() { - assertEquals(9, state.get("test.mmdb").firstChunk()); - assertEquals(10, state.get("test.mmdb").lastChunk()); + assertEquals(9, state.getDatabases().get("test.mmdb").firstChunk()); + assertEquals(10, state.getDatabases().get("test.mmdb").lastChunk()); } @Override @@ -541,6 +559,78 @@ public void testUpdateDatabasesIndexNotReady() { verifyNoInteractions(httpClient); } + public void testThatRunDownloaderDeletesExpiredDatabases() { + /* + * This test puts some expired databases and some non-expired ones into the GeoIpTaskState, and then calls runDownloader(), making + * sure that the expired databases have been deleted. + */ + AtomicInteger updatePersistentTaskStateCount = new AtomicInteger(0); + AtomicInteger deleteCount = new AtomicInteger(0); + int expiredDatabasesCount = randomIntBetween(1, 100); + int unexpiredDatabasesCount = randomIntBetween(0, 100); + Map databases = new HashMap<>(); + for (int i = 0; i < expiredDatabasesCount; i++) { + databases.put("expiredDatabase" + i, newGeoIpTaskStateMetadata(true)); + } + for (int i = 0; i < unexpiredDatabasesCount; i++) { + databases.put("unexpiredDatabase" + i, newGeoIpTaskStateMetadata(false)); + } + GeoIpTaskState geoIpTaskState = new GeoIpTaskState(databases); + geoIpDownloader.setState(geoIpTaskState); + client.addHandler( + UpdatePersistentTaskStatusAction.INSTANCE, + (UpdatePersistentTaskStatusAction.Request request, ActionListener taskResponseListener) -> { + PersistentTasksCustomMetadata.Assignment assignment = mock(PersistentTasksCustomMetadata.Assignment.class); + PersistentTasksCustomMetadata.PersistentTask persistentTask = new PersistentTasksCustomMetadata.PersistentTask<>( + GeoIpDownloader.GEOIP_DOWNLOADER, + GeoIpDownloader.GEOIP_DOWNLOADER, + new GeoIpTaskParams(), + request.getAllocationId(), + assignment + ); + updatePersistentTaskStateCount.incrementAndGet(); + taskResponseListener.onResponse(new PersistentTaskResponse(new PersistentTask<>(persistentTask, request.getState()))); + } + ); + client.addHandler( + DeleteByQueryAction.INSTANCE, + (DeleteByQueryRequest request, ActionListener flushResponseActionListener) -> { + deleteCount.incrementAndGet(); + } + ); + geoIpDownloader.runDownloader(); + assertThat(geoIpDownloader.getStatus().getExpiredDatabases(), equalTo(expiredDatabasesCount)); + for (int i = 0; i < expiredDatabasesCount; i++) { + // This currently fails because we subtract one millisecond from the lastChecked time + // assertThat(geoIpDownloader.state.getDatabases().get("expiredDatabase" + i).lastCheck(), equalTo(-1L)); + } + for (int i = 0; i < unexpiredDatabasesCount; i++) { + assertThat( + geoIpDownloader.state.getDatabases().get("unexpiredDatabase" + i).lastCheck(), + greaterThanOrEqualTo(Instant.now().minus(30, ChronoUnit.DAYS).toEpochMilli()) + ); + } + assertThat(deleteCount.get(), equalTo(expiredDatabasesCount)); + assertThat(updatePersistentTaskStateCount.get(), equalTo(expiredDatabasesCount)); + geoIpDownloader.runDownloader(); + /* + * The following two lines assert current behavior that might not be desirable -- we continue to delete expired databases every + * time that runDownloader runs. This seems unnecessary. + */ + assertThat(deleteCount.get(), equalTo(expiredDatabasesCount * 2)); + assertThat(updatePersistentTaskStateCount.get(), equalTo(expiredDatabasesCount * 2)); + } + + private GeoIpTaskState.Metadata newGeoIpTaskStateMetadata(boolean expired) { + Instant lastChecked; + if (expired) { + lastChecked = Instant.now().minus(randomIntBetween(31, 100), ChronoUnit.DAYS); + } else { + lastChecked = Instant.now().minus(randomIntBetween(0, 29), ChronoUnit.DAYS); + } + return new GeoIpTaskState.Metadata(0, 0, 0, randomAlphaOfLength(20), lastChecked.toEpochMilli()); + } + private static class MockClient extends NoOpClient { private final Map, BiConsumer>> handlers = new HashMap<>(); diff --git a/modules/lang-painless/src/yamlRestTest/resources/rest-api-spec/test/painless/146_dense_vector_bit_basic.yml b/modules/lang-painless/src/yamlRestTest/resources/rest-api-spec/test/painless/146_dense_vector_bit_basic.yml index 3eb686bda2174..4c195a0e32623 100644 --- a/modules/lang-painless/src/yamlRestTest/resources/rest-api-spec/test/painless/146_dense_vector_bit_basic.yml +++ b/modules/lang-painless/src/yamlRestTest/resources/rest-api-spec/test/painless/146_dense_vector_bit_basic.yml @@ -8,6 +8,8 @@ setup: indices.create: index: test-index body: + settings: + number_of_shards: 1 mappings: properties: vector: @@ -107,7 +109,6 @@ setup: headers: Content-Type: application/json search: - rest_total_hits_as_int: true body: query: script_score: @@ -138,7 +139,6 @@ setup: headers: Content-Type: application/json search: - rest_total_hits_as_int: true body: query: script_score: @@ -152,7 +152,6 @@ setup: headers: Content-Type: application/json search: - rest_total_hits_as_int: true body: query: script_score: @@ -167,7 +166,6 @@ setup: headers: Content-Type: application/json search: - rest_total_hits_as_int: true body: query: script_score: diff --git a/modules/legacy-geo/src/main/java/org/elasticsearch/legacygeo/mapper/LegacyGeoShapeFieldMapper.java b/modules/legacy-geo/src/main/java/org/elasticsearch/legacygeo/mapper/LegacyGeoShapeFieldMapper.java index e03d7e2fd5384..2808dae31239c 100644 --- a/modules/legacy-geo/src/main/java/org/elasticsearch/legacygeo/mapper/LegacyGeoShapeFieldMapper.java +++ b/modules/legacy-geo/src/main/java/org/elasticsearch/legacygeo/mapper/LegacyGeoShapeFieldMapper.java @@ -81,7 +81,7 @@ *

* "field" : "POLYGON ((100.0 0.0, 101.0 0.0, 101.0 1.0, 100.0 1.0, 100.0 0.0)) * - * @deprecated use {@link org.elasticsearch.index.mapper.GeoShapeFieldMapper} + * @deprecated use the field mapper in the spatial module */ @Deprecated public class LegacyGeoShapeFieldMapper extends AbstractShapeGeometryFieldMapper> { diff --git a/modules/reindex/src/test/java/org/elasticsearch/reindex/DeleteByQueryConcurrentTests.java b/modules/reindex/src/test/java/org/elasticsearch/reindex/DeleteByQueryConcurrentTests.java index 323b829fe93ff..190616b9980f0 100644 --- a/modules/reindex/src/test/java/org/elasticsearch/reindex/DeleteByQueryConcurrentTests.java +++ b/modules/reindex/src/test/java/org/elasticsearch/reindex/DeleteByQueryConcurrentTests.java @@ -11,11 +11,9 @@ import org.elasticsearch.action.index.IndexRequestBuilder; import org.elasticsearch.index.query.MatchQueryBuilder; import org.elasticsearch.index.query.QueryBuilders; -import org.elasticsearch.index.reindex.BulkByScrollResponse; import java.util.ArrayList; import java.util.List; -import java.util.concurrent.CountDownLatch; import java.util.concurrent.atomic.AtomicLong; import static org.elasticsearch.index.query.QueryBuilders.matchQuery; @@ -26,44 +24,29 @@ public class DeleteByQueryConcurrentTests extends ReindexTestCase { public void testConcurrentDeleteByQueriesOnDifferentDocs() throws Throwable { - final Thread[] threads = new Thread[scaledRandomIntBetween(2, 5)]; + final int threadCount = scaledRandomIntBetween(2, 5); final long docs = randomIntBetween(1, 50); List builders = new ArrayList<>(); for (int i = 0; i < docs; i++) { - for (int t = 0; t < threads.length; t++) { + for (int t = 0; t < threadCount; t++) { builders.add(prepareIndex("test").setSource("field", t)); } } indexRandom(true, true, true, builders); - final CountDownLatch start = new CountDownLatch(1); - for (int t = 0; t < threads.length; t++) { - final int threadNum = t; - assertHitCount(prepareSearch("test").setSize(0).setQuery(QueryBuilders.termQuery("field", threadNum)), docs); - - Runnable r = () -> { - try { - start.await(); - - assertThat( - deleteByQuery().source("_all").filter(termQuery("field", threadNum)).refresh(true).get(), - matcher().deleted(docs) - ); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - } - }; - threads[t] = new Thread(r); - threads[t].start(); + for (int t = 0; t < threadCount; t++) { + assertHitCount(prepareSearch("test").setSize(0).setQuery(QueryBuilders.termQuery("field", t)), docs); } - - start.countDown(); - for (Thread thread : threads) { - thread.join(); - } - - for (int t = 0; t < threads.length; t++) { + startInParallel( + threadCount, + threadNum -> assertThat( + deleteByQuery().source("_all").filter(termQuery("field", threadNum)).refresh(true).get(), + matcher().deleted(docs) + ) + ); + + for (int t = 0; t < threadCount; t++) { assertHitCount(prepareSearch("test").setSize(0).setQuery(QueryBuilders.termQuery("field", t)), 0); } } @@ -77,33 +60,12 @@ public void testConcurrentDeleteByQueriesOnSameDocs() throws Throwable { } indexRandom(true, true, true, builders); - final Thread[] threads = new Thread[scaledRandomIntBetween(2, 9)]; + final int threadCount = scaledRandomIntBetween(2, 9); - final CountDownLatch start = new CountDownLatch(1); final MatchQueryBuilder query = matchQuery("foo", "bar"); final AtomicLong deleted = new AtomicLong(0); - - for (int t = 0; t < threads.length; t++) { - Runnable r = () -> { - try { - start.await(); - - BulkByScrollResponse response = deleteByQuery().source("test").filter(query).refresh(true).get(); - // Some deletions might fail due to version conflict, but - // what matters here is the total of successful deletions - deleted.addAndGet(response.getDeleted()); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - } - }; - threads[t] = new Thread(r); - threads[t].start(); - } - - start.countDown(); - for (Thread thread : threads) { - thread.join(); - } + // Some deletions might fail due to version conflict, but what matters here is the total of successful deletions + startInParallel(threadCount, i -> deleted.addAndGet(deleteByQuery().source("test").filter(query).refresh(true).get().getDeleted())); assertHitCount(prepareSearch("test").setSize(0), 0L); assertThat(deleted.get(), equalTo(docs)); diff --git a/modules/repository-s3/src/main/java/org/elasticsearch/repositories/s3/S3Repository.java b/modules/repository-s3/src/main/java/org/elasticsearch/repositories/s3/S3Repository.java index d53c379a37644..72b48c5903629 100644 --- a/modules/repository-s3/src/main/java/org/elasticsearch/repositories/s3/S3Repository.java +++ b/modules/repository-s3/src/main/java/org/elasticsearch/repositories/s3/S3Repository.java @@ -14,6 +14,7 @@ import org.elasticsearch.action.ActionRunnable; import org.elasticsearch.cluster.metadata.RepositoryMetadata; import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.ReferenceDocs; import org.elasticsearch.common.Strings; import org.elasticsearch.common.blobstore.BlobPath; import org.elasticsearch.common.blobstore.BlobStore; @@ -443,4 +444,19 @@ protected void doClose() { } super.doClose(); } + + @Override + public String getAnalysisFailureExtraDetail() { + return Strings.format( + """ + Elasticsearch observed the storage system underneath this repository behaved incorrectly which indicates it is not \ + suitable for use with Elasticsearch snapshots. Typically this happens when using storage other than AWS S3 which \ + incorrectly claims to be S3-compatible. If so, please report this incompatibility to your storage supplier. Do not report \ + Elasticsearch issues involving storage systems which claim to be S3-compatible unless you can demonstrate that the same \ + issue exists when using a genuine AWS S3 repository. See [%s] for further information about repository analysis, and [%s] \ + for further information about support for S3-compatible repository implementations.""", + ReferenceDocs.SNAPSHOT_REPOSITORY_ANALYSIS, + ReferenceDocs.S3_COMPATIBLE_REPOSITORIES + ); + } } diff --git a/modules/repository-s3/src/test/java/org/elasticsearch/repositories/s3/S3RepositoryTests.java b/modules/repository-s3/src/test/java/org/elasticsearch/repositories/s3/S3RepositoryTests.java index fcb0e82505dac..4bbc791e5fe21 100644 --- a/modules/repository-s3/src/test/java/org/elasticsearch/repositories/s3/S3RepositoryTests.java +++ b/modules/repository-s3/src/test/java/org/elasticsearch/repositories/s3/S3RepositoryTests.java @@ -11,6 +11,7 @@ import com.amazonaws.services.s3.AbstractAmazonS3; import org.elasticsearch.cluster.metadata.RepositoryMetadata; +import org.elasticsearch.common.ReferenceDocs; import org.elasticsearch.common.settings.ClusterSettings; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.unit.ByteSizeUnit; @@ -28,6 +29,7 @@ import java.util.Map; +import static org.hamcrest.Matchers.allOf; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.not; @@ -152,4 +154,24 @@ private S3Repository createS3Repo(RepositoryMetadata metadata) { ); } + public void testAnalysisFailureDetail() { + try ( + S3Repository s3repo = createS3Repo( + new RepositoryMetadata("dummy-repo", "mock", Settings.builder().put(S3Repository.BUCKET_SETTING.getKey(), "bucket").build()) + ) + ) { + assertThat( + s3repo.getAnalysisFailureExtraDetail(), + allOf( + containsString("storage system underneath this repository behaved incorrectly"), + containsString("incorrectly claims to be S3-compatible"), + containsString("report this incompatibility to your storage supplier"), + containsString("unless you can demonstrate that the same issue exists when using a genuine AWS S3 repository"), + containsString(ReferenceDocs.SNAPSHOT_REPOSITORY_ANALYSIS.toString()), + containsString(ReferenceDocs.S3_COMPATIBLE_REPOSITORIES.toString()) + ) + ); + } + } + } diff --git a/muted-tests.yml b/muted-tests.yml index d8eba8ad2dba6..ac0b03cc4b4fa 100644 --- a/muted-tests.yml +++ b/muted-tests.yml @@ -4,7 +4,8 @@ tests: method: "testGuessIsDayFirstFromLocale" - class: "org.elasticsearch.test.rest.ClientYamlTestSuiteIT" issue: "https://github.com/elastic/elasticsearch/issues/108857" - method: "test {yaml=search/180_locale_dependent_mapping/Test Index and Search locale dependent mappings / dates}" + method: "test {yaml=search/180_locale_dependent_mapping/Test Index and Search locale\ + \ dependent mappings / dates}" - class: "org.elasticsearch.upgrades.SearchStatesIT" issue: "https://github.com/elastic/elasticsearch/issues/108991" method: "testCanMatch" @@ -13,7 +14,8 @@ tests: method: "testTrainedModelInference" - class: "org.elasticsearch.xpack.security.CoreWithSecurityClientYamlTestSuiteIT" issue: "https://github.com/elastic/elasticsearch/issues/109188" - method: "test {yaml=search/180_locale_dependent_mapping/Test Index and Search locale dependent mappings / dates}" + method: "test {yaml=search/180_locale_dependent_mapping/Test Index and Search locale\ + \ dependent mappings / dates}" - class: "org.elasticsearch.xpack.esql.qa.mixed.EsqlClientYamlIT" issue: "https://github.com/elastic/elasticsearch/issues/109189" method: "test {p0=esql/70_locale/Date format with Italian locale}" @@ -28,7 +30,8 @@ tests: method: "testTimestampFieldTypeExposedByAllIndicesServices" - class: "org.elasticsearch.analysis.common.CommonAnalysisClientYamlTestSuiteIT" issue: "https://github.com/elastic/elasticsearch/issues/109318" - method: "test {yaml=analysis-common/50_char_filters/pattern_replace error handling (too complex pattern)}" + method: "test {yaml=analysis-common/50_char_filters/pattern_replace error handling\ + \ (too complex pattern)}" - class: "org.elasticsearch.xpack.ml.integration.ClassificationHousePricingIT" issue: "https://github.com/elastic/elasticsearch/issues/101598" method: "testFeatureImportanceValues" @@ -44,21 +47,9 @@ tests: - class: "org.elasticsearch.xpack.test.rest.XPackRestIT" issue: "https://github.com/elastic/elasticsearch/issues/109687" method: "test {p0=sql/translate/Translate SQL}" -- class: org.elasticsearch.action.search.SearchProgressActionListenerIT - method: testSearchProgressWithHits - issue: https://github.com/elastic/elasticsearch/issues/109830 -- class: "org.elasticsearch.xpack.security.ScrollHelperIntegTests" - issue: "https://github.com/elastic/elasticsearch/issues/109905" - method: "testFetchAllEntities" -- class: "org.elasticsearch.xpack.esql.action.AsyncEsqlQueryActionIT" - issue: "https://github.com/elastic/elasticsearch/issues/109944" - method: "testBasicAsyncExecution" - class: "org.elasticsearch.action.admin.indices.rollover.RolloverIT" issue: "https://github.com/elastic/elasticsearch/issues/110034" method: "testRolloverWithClosedWriteIndex" -- class: org.elasticsearch.xpack.transform.transforms.TransformIndexerTests - method: testMaxPageSearchSizeIsResetToConfiguredValue - issue: https://github.com/elastic/elasticsearch/issues/109844 - class: org.elasticsearch.index.store.FsDirectoryFactoryTests method: testStoreDirectory issue: https://github.com/elastic/elasticsearch/issues/110210 @@ -67,24 +58,12 @@ tests: issue: https://github.com/elastic/elasticsearch/issues/110211 - class: "org.elasticsearch.rest.RestControllerIT" issue: "https://github.com/elastic/elasticsearch/issues/110225" -- class: "org.elasticsearch.xpack.security.authz.store.NativePrivilegeStoreCacheTests" - issue: "https://github.com/elastic/elasticsearch/issues/110227" - method: "testGetPrivilegesUsesCache" - class: org.elasticsearch.upgrades.SecurityIndexRolesMetadataMigrationIT method: testMetadataMigratedAfterUpgrade issue: https://github.com/elastic/elasticsearch/issues/110232 - class: org.elasticsearch.compute.lucene.ValueSourceReaderTypeConversionTests method: testLoadAll issue: https://github.com/elastic/elasticsearch/issues/110244 -- class: org.elasticsearch.painless.LangPainlessClientYamlTestSuiteIT - method: test {yaml=painless/146_dense_vector_bit_basic/Cosine Similarity is not supported} - issue: https://github.com/elastic/elasticsearch/issues/110290 -- class: org.elasticsearch.painless.LangPainlessClientYamlTestSuiteIT - method: test {yaml=painless/146_dense_vector_bit_basic/Dot Product is not supported} - issue: https://github.com/elastic/elasticsearch/issues/110291 -- class: org.elasticsearch.action.search.SearchProgressActionListenerIT - method: testSearchProgressWithQuery - issue: https://github.com/elastic/elasticsearch/issues/109867 - class: org.elasticsearch.backwards.SearchWithMinCompatibleSearchNodeIT method: testMinVersionAsNewVersion issue: https://github.com/elastic/elasticsearch/issues/95384 @@ -94,30 +73,15 @@ tests: - class: org.elasticsearch.backwards.SearchWithMinCompatibleSearchNodeIT method: testMinVersionAsOldVersion issue: https://github.com/elastic/elasticsearch/issues/109454 -- class: org.elasticsearch.xpack.esql.tree.EsqlNodeSubclassTests - method: testReplaceChildren {class org.elasticsearch.xpack.esql.expression.function.aggregate.ToPartial} - issue: https://github.com/elastic/elasticsearch/issues/110310 -- class: org.elasticsearch.xpack.esql.tree.EsqlNodeSubclassTests - method: testInfoParameters {class org.elasticsearch.xpack.esql.expression.function.aggregate.ToPartial} - issue: https://github.com/elastic/elasticsearch/issues/110310 -- class: org.elasticsearch.search.vectors.ExactKnnQueryBuilderTests - method: testToQuery - issue: https://github.com/elastic/elasticsearch/issues/110357 -- class: org.elasticsearch.search.aggregations.bucket.terms.RareTermsIT - method: testSingleValuedString - issue: https://github.com/elastic/elasticsearch/issues/110388 - class: "org.elasticsearch.xpack.searchablesnapshots.FrozenSearchableSnapshotsIntegTests" issue: "https://github.com/elastic/elasticsearch/issues/110408" method: "testCreateAndRestorePartialSearchableSnapshot" -- class: "org.elasticsearch.xpack.security.role.RoleWithDescriptionRestIT" - issue: "https://github.com/elastic/elasticsearch/issues/110416" - method: "testCreateOrUpdateRoleWithDescription" -- class: "org.elasticsearch.xpack.security.role.RoleWithDescriptionRestIT" - issue: "https://github.com/elastic/elasticsearch/issues/110417" - method: "testCreateOrUpdateRoleWithDescription" -- class: org.elasticsearch.test.rest.yaml.CcsCommonYamlTestSuiteIT - method: test {p0=search.vectors/41_knn_search_half_byte_quantized/Test create, merge, and search cosine} - issue: https://github.com/elastic/elasticsearch/issues/109978 +- class: "org.elasticsearch.backwards.MixedClusterClientYamlTestSuiteIT" + issue: "https://github.com/elastic/elasticsearch/issues/110719" + method: "test {p0=search.vectors/45_knn_search_byte/Test nonexistent field}" +- class: "org.elasticsearch.backwards.MixedClusterClientYamlTestSuiteIT" + issue: "https://github.com/elastic/elasticsearch/issues/110720" + method: "test {p0=search.vectors/40_knn_search/Test nonexistent field}" # Examples: # diff --git a/qa/ccs-rolling-upgrade-remote-cluster/build.gradle b/qa/ccs-rolling-upgrade-remote-cluster/build.gradle index c48674831c422..b63522daa4b4c 100644 --- a/qa/ccs-rolling-upgrade-remote-cluster/build.gradle +++ b/qa/ccs-rolling-upgrade-remote-cluster/build.gradle @@ -58,7 +58,11 @@ BuildParams.bwcVersions.withWireCompatible { bwcVersion, baseName -> dependsOn "processTestResources" mustRunAfter("precommit") doFirst { - localCluster.get().nextNodeToNextVersion() + def cluster = localCluster.get() + cluster.nodes.forEach { node -> + node.getAllTransportPortURI() + } + cluster.nextNodeToNextVersion() } } diff --git a/qa/mixed-cluster/src/test/java/org/elasticsearch/backwards/IndexingIT.java b/qa/mixed-cluster/src/test/java/org/elasticsearch/backwards/IndexingIT.java index aac4b6a020d4b..6c924fe8e429a 100644 --- a/qa/mixed-cluster/src/test/java/org/elasticsearch/backwards/IndexingIT.java +++ b/qa/mixed-cluster/src/test/java/org/elasticsearch/backwards/IndexingIT.java @@ -59,20 +59,13 @@ private int indexDocs(String index, final int idStart, final int numDocs) throws */ private int indexDocWithConcurrentUpdates(String index, final int docId, int nUpdates) throws IOException, InterruptedException { indexDocs(index, docId, 1); - Thread[] indexThreads = new Thread[nUpdates]; - for (int i = 0; i < nUpdates; i++) { - indexThreads[i] = new Thread(() -> { - try { - indexDocs(index, docId, 1); - } catch (IOException e) { - throw new AssertionError("failed while indexing [" + e.getMessage() + "]"); - } - }); - indexThreads[i].start(); - } - for (Thread indexThread : indexThreads) { - indexThread.join(); - } + runInParallel(nUpdates, i -> { + try { + indexDocs(index, docId, 1); + } catch (IOException e) { + throw new AssertionError("failed while indexing [" + e.getMessage() + "]"); + } + }); return nUpdates + 1; } diff --git a/qa/packaging/build.gradle b/qa/packaging/build.gradle index 758dfe6661766..02bc30ecd6b39 100644 --- a/qa/packaging/build.gradle +++ b/qa/packaging/build.gradle @@ -36,3 +36,13 @@ tasks.named("test").configure { enabled = false } tasks.register('destructivePackagingTest') { dependsOn 'destructiveDistroTest' } + +tasks.named('resolveAllDependencies') { + // Don't try and resolve all distros but only the latest patch versions of each minor + def latestBugfixVersions = org.elasticsearch.gradle.internal.info.BuildParams.getBwcVersions().getIndexCompatible() + .groupBy { [it.major, it.minor] } + .collectEntries { key, value -> [key, value.max()] } + .values() + + configs = configurations.matching { configName -> latestBugfixVersions.any { v -> configName.name.endsWith(v.toString()) } } +} diff --git a/qa/rolling-upgrade/src/javaRestTest/java/org/elasticsearch/upgrades/FileSettingsUpgradeIT.java b/qa/rolling-upgrade/src/javaRestTest/java/org/elasticsearch/upgrades/FileSettingsUpgradeIT.java new file mode 100644 index 0000000000000..c80911fe5fbcf --- /dev/null +++ b/qa/rolling-upgrade/src/javaRestTest/java/org/elasticsearch/upgrades/FileSettingsUpgradeIT.java @@ -0,0 +1,90 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.upgrades; + +import com.carrotsearch.randomizedtesting.annotations.Name; + +import org.elasticsearch.client.Request; +import org.elasticsearch.common.xcontent.support.XContentMapValues; +import org.elasticsearch.core.SuppressForbidden; +import org.elasticsearch.test.cluster.ElasticsearchCluster; +import org.elasticsearch.test.cluster.FeatureFlag; +import org.elasticsearch.test.cluster.local.DefaultLocalClusterSpecBuilder; +import org.elasticsearch.test.cluster.local.distribution.DistributionType; +import org.elasticsearch.test.cluster.util.Version; +import org.elasticsearch.test.cluster.util.resource.Resource; +import org.junit.BeforeClass; +import org.junit.ClassRule; +import org.junit.rules.RuleChain; +import org.junit.rules.TemporaryFolder; +import org.junit.rules.TestRule; + +import java.io.IOException; +import java.util.Map; +import java.util.function.Supplier; + +import static org.hamcrest.Matchers.equalTo; + +public class FileSettingsUpgradeIT extends ParameterizedRollingUpgradeTestCase { + + @BeforeClass + public static void checkVersion() { + assumeTrue("Only valid when upgrading from pre-file settings", getOldClusterTestVersion().before(new Version(8, 4, 0))); + } + + private static final String settingsJSON = """ + { + "metadata": { + "version": "1", + "compatibility": "8.4.0" + }, + "state": { + "cluster_settings": { + "indices.recovery.max_bytes_per_sec": "50mb" + } + } + }"""; + + private static final TemporaryFolder repoDirectory = new TemporaryFolder(); + + private static final ElasticsearchCluster cluster = new DefaultLocalClusterSpecBuilder().distribution(DistributionType.DEFAULT) + .version(getOldClusterTestVersion()) + .nodes(NODE_NUM) + .setting("path.repo", new Supplier<>() { + @Override + @SuppressForbidden(reason = "TemporaryFolder only has io.File methods, not nio.File") + public String get() { + return repoDirectory.getRoot().getPath(); + } + }) + .setting("xpack.security.enabled", "false") + .feature(FeatureFlag.TIME_SERIES_MODE) + .configFile("operator/settings.json", Resource.fromString(settingsJSON)) + .build(); + + @ClassRule + public static TestRule ruleChain = RuleChain.outerRule(repoDirectory).around(cluster); + + public FileSettingsUpgradeIT(@Name("upgradedNodes") int upgradedNodes) { + super(upgradedNodes); + } + + @Override + protected ElasticsearchCluster getUpgradeCluster() { + return cluster; + } + + public void testFileSettingsApplied() throws IOException { + if (isUpgradedCluster()) { + // the nodes have all been upgraded. Check they read the file settings ok + Map response = responseAsMap(adminClient().performRequest(new Request("GET", "/_cluster/settings"))); + assertThat(XContentMapValues.extractValue(response, "persistent", "indices", "recovery", "max_bytes_per_sec"), equalTo("50mb")); + } + } +} diff --git a/rest-api-spec/src/main/resources/rest-api-spec/api/capabilities.json b/rest-api-spec/src/main/resources/rest-api-spec/api/capabilities.json index 28c341d9983cc..a96be0d63834e 100644 --- a/rest-api-spec/src/main/resources/rest-api-spec/api/capabilities.json +++ b/rest-api-spec/src/main/resources/rest-api-spec/api/capabilities.json @@ -1,7 +1,7 @@ { "capabilities": { "documentation": { - "url": "https://www.elastic.co/guide/en/elasticsearch/reference/master/capabilities.html", + "url": "https://github.com/elastic/elasticsearch/blob/main/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/README.asciidoc#require-or-skip-api-capabilities", "description": "Checks if the specified combination of method, API, parameters, and arbitrary capabilities are supported" }, "stability": "experimental", diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/40_knn_search.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/40_knn_search.yml index 7f0c24e217d14..825bcecf33fce 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/40_knn_search.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/40_knn_search.yml @@ -287,6 +287,9 @@ setup: - requires: cluster_features: "gte_v8.4.0" reason: 'kNN added to search endpoint in 8.4' + - skip: + cluster_features: "gte_v8.16.0" + reason: 'non-existent field handling improved in 8.16' - do: catch: bad_request search: @@ -298,9 +301,28 @@ setup: query_vector: [ -0.5, 90.0, -10, 14.8, -156.0 ] k: 2 num_candidates: 3 + - match: { error.root_cause.0.type: "query_shard_exception" } - match: { error.root_cause.0.reason: "failed to create query: field [nonexistent] does not exist in the mapping" } +--- +"Test nonexistent field is match none": + - requires: + cluster_features: "gte_v8.16.0" + reason: 'non-existent field handling improved in 8.16' + - do: + search: + index: test + body: + fields: [ "name" ] + knn: + field: nonexistent + query_vector: [ -0.5, 90.0, -10, 14.8, -156.0 ] + k: 2 + num_candidates: 3 + + - length: {hits.hits: 0} + --- "KNN Vector similarity search only": - requires: diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_half_byte_quantized.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_half_byte_quantized.yml index cb5aae482507a..5f1af2ca5c52f 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_half_byte_quantized.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/41_knn_search_half_byte_quantized.yml @@ -428,7 +428,7 @@ setup: index: hnsw_byte_quantized_merge_cosine id: "1" body: - embedding: [1.0, 1.0, 1.0, 1.0] + embedding: [0.5, 0.5, 0.5, 0.5, 0.5, 1.0] # Flush in order to provoke a merge later - do: @@ -439,7 +439,7 @@ setup: index: hnsw_byte_quantized_merge_cosine id: "2" body: - embedding: [1.0, 1.0, 1.0, 2.0] + embedding: [0.0, 0.0, 0.0, 1.0, 1.0, 0.5] # Flush in order to provoke a merge later - do: @@ -450,7 +450,7 @@ setup: index: hnsw_byte_quantized_merge_cosine id: "3" body: - embedding: [1.0, 1.0, 1.0, 3.0] + embedding: [0.0, 0.0, 0.0, 0.0, 0.0, 10.5] - do: indices.forcemerge: @@ -468,7 +468,7 @@ setup: query: knn: field: embedding - query_vector: [1.0, 1.0, 1.0, 1.0] + query_vector: [1.0, 1.0, 1.0, 1.0, 1.0, 1.0] num_candidates: 10 - length: { hits.hits: 3 } diff --git a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/45_knn_search_byte.yml b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/45_knn_search_byte.yml index 983ac2719e71b..806e5ff73b355 100644 --- a/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/45_knn_search_byte.yml +++ b/rest-api-spec/src/yamlRestTest/resources/rest-api-spec/test/search.vectors/45_knn_search_byte.yml @@ -148,6 +148,9 @@ setup: --- "Test nonexistent field": + - skip: + cluster_features: 'gte_v8.16.0' + reason: 'non-existent field handling improved in 8.16' - do: catch: bad_request search: @@ -159,8 +162,26 @@ setup: query_vector: [ 1, 0, 0, 0, -1 ] k: 2 num_candidates: 3 + - match: { error.root_cause.0.type: "query_shard_exception" } - match: { error.root_cause.0.reason: "failed to create query: field [nonexistent] does not exist in the mapping" } +--- +"Test nonexistent field is match none": + - requires: + cluster_features: 'gte_v8.16.0' + reason: 'non-existent field handling improved in 8.16' + - do: + search: + index: test + body: + fields: [ "name" ] + knn: + field: nonexistent + query_vector: [ 1, 0, 0, 0, -1 ] + k: 2 + num_candidates: 3 + + - length: {hits.hits: 0} --- "Vector similarity search only": diff --git a/server/src/internalClusterTest/java/org/elasticsearch/action/admin/cluster/node/tasks/TasksIT.java b/server/src/internalClusterTest/java/org/elasticsearch/action/admin/cluster/node/tasks/TasksIT.java index 4ad2a56d2e979..32d8be475dbbe 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/action/admin/cluster/node/tasks/TasksIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/action/admin/cluster/node/tasks/TasksIT.java @@ -546,13 +546,7 @@ public void testListTasksWaitForCompletion() throws Exception { // This ensures that a task has progressed to the point of listing all running tasks and subscribing to their updates for (var threadPool : internalCluster().getInstances(ThreadPool.class)) { - var max = threadPool.info(ThreadPool.Names.MANAGEMENT).getMax(); - var executor = threadPool.executor(ThreadPool.Names.MANAGEMENT); - var waitForManagementToCompleteAllTasks = new CyclicBarrier(max + 1); - for (int i = 0; i < max; i++) { - executor.submit(() -> safeAwait(waitForManagementToCompleteAllTasks)); - } - safeAwait(waitForManagementToCompleteAllTasks); + flushThreadPoolExecutor(threadPool, ThreadPool.Names.MANAGEMENT); } return future; diff --git a/server/src/internalClusterTest/java/org/elasticsearch/action/admin/indices/rollover/RolloverIT.java b/server/src/internalClusterTest/java/org/elasticsearch/action/admin/indices/rollover/RolloverIT.java index 48f1ecb072314..4d52383bfc4e1 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/action/admin/indices/rollover/RolloverIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/action/admin/indices/rollover/RolloverIT.java @@ -832,30 +832,22 @@ public void testRolloverConcurrently() throws Exception { assertAcked(client().execute(TransportPutComposableIndexTemplateAction.TYPE, putTemplateRequest).actionGet()); final CyclicBarrier barrier = new CyclicBarrier(numOfThreads); - final Thread[] threads = new Thread[numOfThreads]; - for (int i = 0; i < numOfThreads; i++) { + runInParallel(numOfThreads, i -> { var aliasName = "test-" + i; - threads[i] = new Thread(() -> { - assertAcked(prepareCreate(aliasName + "-000001").addAlias(new Alias(aliasName).writeIndex(true)).get()); - for (int j = 1; j <= numberOfRolloversPerThread; j++) { - try { - barrier.await(); - } catch (Exception e) { - throw new RuntimeException(e); - } - var response = indicesAdmin().prepareRolloverIndex(aliasName).waitForActiveShards(ActiveShardCount.NONE).get(); - assertThat(response.getOldIndex(), equalTo(aliasName + Strings.format("-%06d", j))); - assertThat(response.getNewIndex(), equalTo(aliasName + Strings.format("-%06d", j + 1))); - assertThat(response.isDryRun(), equalTo(false)); - assertThat(response.isRolledOver(), equalTo(true)); + assertAcked(prepareCreate(aliasName + "-000001").addAlias(new Alias(aliasName).writeIndex(true)).get()); + for (int j = 1; j <= numberOfRolloversPerThread; j++) { + try { + barrier.await(); + } catch (Exception e) { + throw new RuntimeException(e); } - }); - threads[i].start(); - } - - for (Thread thread : threads) { - thread.join(); - } + var response = indicesAdmin().prepareRolloverIndex(aliasName).waitForActiveShards(ActiveShardCount.NONE).get(); + assertThat(response.getOldIndex(), equalTo(aliasName + Strings.format("-%06d", j))); + assertThat(response.getNewIndex(), equalTo(aliasName + Strings.format("-%06d", j + 1))); + assertThat(response.isDryRun(), equalTo(false)); + assertThat(response.isRolledOver(), equalTo(true)); + } + }); for (int i = 0; i < numOfThreads; i++) { var aliasName = "test-" + i; diff --git a/server/src/internalClusterTest/java/org/elasticsearch/action/bulk/BulkAfterWriteFsyncFailureIT.java b/server/src/internalClusterTest/java/org/elasticsearch/action/bulk/BulkAfterWriteFsyncFailureIT.java index 5adc0b090ed37..6a4e973d8fcc5 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/action/bulk/BulkAfterWriteFsyncFailureIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/action/bulk/BulkAfterWriteFsyncFailureIT.java @@ -48,6 +48,7 @@ public static void removeDisruptFSyncFS() { PathUtilsForTesting.teardown(); } + @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/110551") public void testFsyncFailureDoesNotAdvanceLocalCheckpoints() { String indexName = randomIdentifier(); client().admin() diff --git a/server/src/internalClusterTest/java/org/elasticsearch/action/bulk/BulkWithUpdatesIT.java b/server/src/internalClusterTest/java/org/elasticsearch/action/bulk/BulkWithUpdatesIT.java index 00bd6ee7ee891..5251f171150b7 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/action/bulk/BulkWithUpdatesIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/action/bulk/BulkWithUpdatesIT.java @@ -39,7 +39,6 @@ import java.util.Collections; import java.util.HashMap; import java.util.Map; -import java.util.concurrent.CyclicBarrier; import java.util.function.Function; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked; @@ -518,34 +517,17 @@ public void testFailingVersionedUpdatedOnBulk() throws Exception { createIndex("test"); indexDoc("test", "1", "field", "1"); final BulkResponse[] responses = new BulkResponse[30]; - final CyclicBarrier cyclicBarrier = new CyclicBarrier(responses.length); - Thread[] threads = new Thread[responses.length]; - - for (int i = 0; i < responses.length; i++) { - final int threadID = i; - threads[threadID] = new Thread(() -> { - try { - cyclicBarrier.await(); - } catch (Exception e) { - return; - } - BulkRequestBuilder requestBuilder = client().prepareBulk(); - requestBuilder.add( - client().prepareUpdate("test", "1") - .setIfSeqNo(0L) - .setIfPrimaryTerm(1) - .setDoc(Requests.INDEX_CONTENT_TYPE, "field", threadID) - ); - responses[threadID] = requestBuilder.get(); - }); - threads[threadID].start(); - - } - - for (int i = 0; i < threads.length; i++) { - threads[i].join(); - } + startInParallel(responses.length, threadID -> { + BulkRequestBuilder requestBuilder = client().prepareBulk(); + requestBuilder.add( + client().prepareUpdate("test", "1") + .setIfSeqNo(0L) + .setIfPrimaryTerm(1) + .setDoc(Requests.INDEX_CONTENT_TYPE, "field", threadID) + ); + responses[threadID] = requestBuilder.get(); + }); int successes = 0; for (BulkResponse response : responses) { diff --git a/server/src/internalClusterTest/java/org/elasticsearch/blocks/SimpleBlocksIT.java b/server/src/internalClusterTest/java/org/elasticsearch/blocks/SimpleBlocksIT.java index 136db24767d22..c5c3e441363da 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/blocks/SimpleBlocksIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/blocks/SimpleBlocksIT.java @@ -32,6 +32,7 @@ import java.util.List; import java.util.Locale; import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; import java.util.function.Consumer; import java.util.stream.IntStream; @@ -310,7 +311,7 @@ public void testAddBlockToUnassignedIndex() throws Exception { } } - public void testConcurrentAddBlock() throws InterruptedException { + public void testConcurrentAddBlock() throws InterruptedException, ExecutionException { final String indexName = randomAlphaOfLength(10).toLowerCase(Locale.ROOT); createIndex(indexName); @@ -322,31 +323,19 @@ public void testConcurrentAddBlock() throws InterruptedException { IntStream.range(0, nbDocs).mapToObj(i -> prepareIndex(indexName).setId(String.valueOf(i)).setSource("num", i)).collect(toList()) ); ensureYellowAndNoInitializingShards(indexName); - - final CountDownLatch startClosing = new CountDownLatch(1); - final Thread[] threads = new Thread[randomIntBetween(2, 5)]; - final APIBlock block = randomAddableBlock(); + final int threadCount = randomIntBetween(2, 5); try { - for (int i = 0; i < threads.length; i++) { - threads[i] = new Thread(() -> { - safeAwait(startClosing); - try { - indicesAdmin().prepareAddBlock(block, indexName).get(); - assertIndexHasBlock(block, indexName); - } catch (final ClusterBlockException e) { - assertThat(e.blocks(), hasSize(1)); - assertTrue(e.blocks().stream().allMatch(b -> b.id() == block.getBlock().id())); - } - }); - threads[i].start(); - } - - startClosing.countDown(); - for (Thread thread : threads) { - thread.join(); - } + startInParallel(threadCount, i -> { + try { + indicesAdmin().prepareAddBlock(block, indexName).get(); + assertIndexHasBlock(block, indexName); + } catch (final ClusterBlockException e) { + assertThat(e.blocks(), hasSize(1)); + assertTrue(e.blocks().stream().allMatch(b -> b.id() == block.getBlock().id())); + } + }); assertIndexHasBlock(block, indexName); } finally { disableIndexBlock(indexName, block); @@ -422,34 +411,17 @@ public void testAddBlockWhileDeletingIndices() throws Exception { }; try { - for (final String indexToDelete : indices) { - threads.add(new Thread(() -> { - safeAwait(latch); - try { - assertAcked(indicesAdmin().prepareDelete(indexToDelete)); - } catch (final Exception e) { - exceptionConsumer.accept(e); - } - })); - } - for (final String indexToBlock : indices) { - threads.add(new Thread(() -> { - safeAwait(latch); - try { - indicesAdmin().prepareAddBlock(block, indexToBlock).get(); - } catch (final Exception e) { - exceptionConsumer.accept(e); + startInParallel(indices.length * 2, i -> { + try { + if (i < indices.length) { + assertAcked(indicesAdmin().prepareDelete(indices[i])); + } else { + indicesAdmin().prepareAddBlock(block, indices[i - indices.length]).get(); } - })); - } - - for (Thread thread : threads) { - thread.start(); - } - latch.countDown(); - for (Thread thread : threads) { - thread.join(); - } + } catch (final Exception e) { + exceptionConsumer.accept(e); + } + }); } finally { for (final String indexToBlock : indices) { try { diff --git a/server/src/internalClusterTest/java/org/elasticsearch/gateway/ReplicaShardAllocatorSyncIdIT.java b/server/src/internalClusterTest/java/org/elasticsearch/gateway/ReplicaShardAllocatorSyncIdIT.java index 27e63e5614744..13886cba9084c 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/gateway/ReplicaShardAllocatorSyncIdIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/gateway/ReplicaShardAllocatorSyncIdIT.java @@ -100,8 +100,8 @@ void syncFlush(String syncId) throws IOException { assertThat(getTranslogStats().getUncommittedOperations(), equalTo(0)); Map userData = new HashMap<>(getLastCommittedSegmentInfos().userData); SequenceNumbers.CommitInfo commitInfo = SequenceNumbers.loadSeqNoInfoFromLuceneCommit(userData.entrySet()); - assertThat(commitInfo.localCheckpoint, equalTo(getLastSyncedGlobalCheckpoint())); - assertThat(commitInfo.maxSeqNo, equalTo(getLastSyncedGlobalCheckpoint())); + assertThat(commitInfo.localCheckpoint(), equalTo(getLastSyncedGlobalCheckpoint())); + assertThat(commitInfo.maxSeqNo(), equalTo(getLastSyncedGlobalCheckpoint())); userData.put(Engine.SYNC_COMMIT_ID, syncId); indexWriter.setLiveCommitData(userData.entrySet()); indexWriter.commit(); diff --git a/server/src/internalClusterTest/java/org/elasticsearch/index/engine/MaxDocsLimitIT.java b/server/src/internalClusterTest/java/org/elasticsearch/index/engine/MaxDocsLimitIT.java index acfc38ca12f89..be7610e55b8e6 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/index/engine/MaxDocsLimitIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/index/engine/MaxDocsLimitIT.java @@ -26,7 +26,6 @@ import java.util.Collection; import java.util.Optional; -import java.util.concurrent.Phaser; import java.util.concurrent.atomic.AtomicInteger; import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertAcked; @@ -103,7 +102,7 @@ public void testMaxDocsLimit() throws Exception { assertThat(indexingResult.numFailures, equalTo(rejectedRequests)); assertThat(indexingResult.numSuccess, equalTo(0)); final IllegalArgumentException deleteError = expectThrows(IllegalArgumentException.class, client().prepareDelete("test", "any-id")); - assertThat(deleteError.getMessage(), containsString("Number of documents in the index can't exceed [" + maxDocs.get() + "]")); + assertThat(deleteError.getMessage(), containsString("Number of documents in the shard cannot exceed [" + maxDocs.get() + "]")); indicesAdmin().prepareRefresh("test").get(); assertNoFailuresAndResponse( prepareSearch("test").setQuery(new MatchAllQueryBuilder()).setTrackTotalHitsUpTo(Integer.MAX_VALUE).setSize(0), @@ -155,27 +154,18 @@ static IndexingResult indexDocs(int numRequests, int numThreads) throws Exceptio final AtomicInteger completedRequests = new AtomicInteger(); final AtomicInteger numSuccess = new AtomicInteger(); final AtomicInteger numFailure = new AtomicInteger(); - Thread[] indexers = new Thread[numThreads]; - Phaser phaser = new Phaser(indexers.length); - for (int i = 0; i < indexers.length; i++) { - indexers[i] = new Thread(() -> { - phaser.arriveAndAwaitAdvance(); - while (completedRequests.incrementAndGet() <= numRequests) { - try { - final DocWriteResponse resp = prepareIndex("test").setSource("{}", XContentType.JSON).get(); - numSuccess.incrementAndGet(); - assertThat(resp.status(), equalTo(RestStatus.CREATED)); - } catch (IllegalArgumentException e) { - numFailure.incrementAndGet(); - assertThat(e.getMessage(), containsString("Number of documents in the index can't exceed [" + maxDocs.get() + "]")); - } + startInParallel(numThreads, i -> { + while (completedRequests.incrementAndGet() <= numRequests) { + try { + final DocWriteResponse resp = prepareIndex("test").setSource("{}", XContentType.JSON).get(); + numSuccess.incrementAndGet(); + assertThat(resp.status(), equalTo(RestStatus.CREATED)); + } catch (IllegalArgumentException e) { + numFailure.incrementAndGet(); + assertThat(e.getMessage(), containsString("Number of documents in the shard cannot exceed [" + maxDocs.get() + "]")); } - }); - indexers[i].start(); - } - for (Thread indexer : indexers) { - indexer.join(); - } + } + }); internalCluster().assertNoInFlightDocsInEngine(); return new IndexingResult(numSuccess.get(), numFailure.get()); } diff --git a/server/src/internalClusterTest/java/org/elasticsearch/index/mapper/DynamicMappingIT.java b/server/src/internalClusterTest/java/org/elasticsearch/index/mapper/DynamicMappingIT.java index 76d305ce8ea4b..3f79d7723beb3 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/index/mapper/DynamicMappingIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/index/mapper/DynamicMappingIT.java @@ -161,31 +161,18 @@ public void testConcurrentDynamicIgnoreBeyondLimitUpdates() throws Throwable { private Map indexConcurrently(int numberOfFieldsToCreate, Settings.Builder settings) throws Throwable { indicesAdmin().prepareCreate("index").setSettings(settings).get(); ensureGreen("index"); - final Thread[] indexThreads = new Thread[numberOfFieldsToCreate]; - final CountDownLatch startLatch = new CountDownLatch(1); final AtomicReference error = new AtomicReference<>(); - for (int i = 0; i < indexThreads.length; ++i) { + startInParallel(numberOfFieldsToCreate, i -> { final String id = Integer.toString(i); - indexThreads[i] = new Thread(new Runnable() { - @Override - public void run() { - try { - startLatch.await(); - assertEquals( - DocWriteResponse.Result.CREATED, - prepareIndex("index").setId(id).setSource("field" + id, "bar").get().getResult() - ); - } catch (Exception e) { - error.compareAndSet(null, e); - } - } - }); - indexThreads[i].start(); - } - startLatch.countDown(); - for (Thread thread : indexThreads) { - thread.join(); - } + try { + assertEquals( + DocWriteResponse.Result.CREATED, + prepareIndex("index").setId(id).setSource("field" + id, "bar").get().getResult() + ); + } catch (Exception e) { + error.compareAndSet(null, e); + } + }); if (error.get() != null) { throw error.get(); } diff --git a/server/src/internalClusterTest/java/org/elasticsearch/index/seqno/GlobalCheckpointSyncIT.java b/server/src/internalClusterTest/java/org/elasticsearch/index/seqno/GlobalCheckpointSyncIT.java index c60b6bb72e8ed..53f632f6ba8d5 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/index/seqno/GlobalCheckpointSyncIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/index/seqno/GlobalCheckpointSyncIT.java @@ -25,11 +25,7 @@ import org.elasticsearch.test.transport.MockTransportService; import org.elasticsearch.xcontent.XContentType; -import java.util.ArrayList; import java.util.Collection; -import java.util.List; -import java.util.concurrent.BrokenBarrierException; -import java.util.concurrent.CyclicBarrier; import java.util.concurrent.TimeUnit; import java.util.function.Consumer; import java.util.stream.Stream; @@ -143,37 +139,14 @@ private void runGlobalCheckpointSyncTest( final int numberOfDocuments = randomIntBetween(0, 256); final int numberOfThreads = randomIntBetween(1, 4); - final CyclicBarrier barrier = new CyclicBarrier(1 + numberOfThreads); // start concurrent indexing threads - final List threads = new ArrayList<>(numberOfThreads); - for (int i = 0; i < numberOfThreads; i++) { - final int index = i; - final Thread thread = new Thread(() -> { - try { - barrier.await(); - } catch (BrokenBarrierException | InterruptedException e) { - throw new RuntimeException(e); - } - for (int j = 0; j < numberOfDocuments; j++) { - final String id = Integer.toString(index * numberOfDocuments + j); - prepareIndex("test").setId(id).setSource("{\"foo\": " + id + "}", XContentType.JSON).get(); - } - try { - barrier.await(); - } catch (BrokenBarrierException | InterruptedException e) { - throw new RuntimeException(e); - } - }); - threads.add(thread); - thread.start(); - } - - // synchronize the start of the threads - barrier.await(); - - // wait for the threads to finish - barrier.await(); + startInParallel(numberOfThreads, index -> { + for (int j = 0; j < numberOfDocuments; j++) { + final String id = Integer.toString(index * numberOfDocuments + j); + prepareIndex("test").setId(id).setSource("{\"foo\": " + id + "}", XContentType.JSON).get(); + } + }); afterIndexing.accept(client()); @@ -203,9 +176,6 @@ private void runGlobalCheckpointSyncTest( } }, 60, TimeUnit.SECONDS); ensureGreen("test"); - for (final Thread thread : threads) { - thread.join(); - } } public void testPersistGlobalCheckpoint() throws Exception { diff --git a/server/src/internalClusterTest/java/org/elasticsearch/index/shard/IndexShardIT.java b/server/src/internalClusterTest/java/org/elasticsearch/index/shard/IndexShardIT.java index b9850bc95275c..5d996e44c6868 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/index/shard/IndexShardIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/index/shard/IndexShardIT.java @@ -164,7 +164,7 @@ public void testDurableFlagHasEffect() { try { // the lastWriteLocaltion has a Integer.MAX_VALUE size so we have to create a new one return tlog.ensureSynced( - new Translog.Location(lastWriteLocation.generation, lastWriteLocation.translogLocation, 0), + new Translog.Location(lastWriteLocation.generation(), lastWriteLocation.translogLocation(), 0), SequenceNumbers.UNASSIGNED_SEQ_NO ); } catch (IOException e) { @@ -389,7 +389,7 @@ public void testMaybeFlush() throws Exception { logger.info( "--> translog stats [{}] gen [{}] commit_stats [{}] flush_stats [{}/{}]", Strings.toString(translogStats), - translog.getGeneration().translogFileGeneration, + translog.getGeneration().translogFileGeneration(), commitStats.getUserData(), flushStats.getPeriodic(), flushStats.getTotal() @@ -428,7 +428,7 @@ public void testMaybeRollTranslogGeneration() throws Exception { ); final Translog.Location location = result.getTranslogLocation(); shard.afterWriteOperation(); - if (location.translogLocation + location.size > generationThreshold) { + if (location.translogLocation() + location.size() > generationThreshold) { // wait until the roll completes assertBusy(() -> assertFalse(shard.shouldRollTranslogGeneration())); rolls++; diff --git a/server/src/internalClusterTest/java/org/elasticsearch/indices/IndicesLifecycleListenerIT.java b/server/src/internalClusterTest/java/org/elasticsearch/indices/IndicesLifecycleListenerIT.java index b224d70eed8f8..e9e88a2d6b76c 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/indices/IndicesLifecycleListenerIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/indices/IndicesLifecycleListenerIT.java @@ -13,6 +13,7 @@ import org.elasticsearch.cluster.routing.ShardRoutingState; import org.elasticsearch.cluster.routing.allocation.command.MoveAllocationCommand; import org.elasticsearch.cluster.routing.allocation.decider.EnableAllocationDecider; +import org.elasticsearch.cluster.routing.allocation.decider.MaxRetryAllocationDecider; import org.elasticsearch.common.Strings; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.CheckedRunnable; @@ -127,7 +128,7 @@ public void beforeIndexCreated(Index index, Settings indexSettings) { assertThat(state.nodes().get(shard.currentNodeId()).getName(), equalTo(node1)); } - public void testRelocationFailureNotRetriedForever() { + public void testRelocationFailureNotRetriedForever() throws Exception { String node1 = internalCluster().startNode(); createIndex("index1", 1, 0); ensureGreen("index1"); @@ -143,6 +144,16 @@ public void beforeIndexCreated(Index index, Settings indexSettings) { updateIndexSettings(Settings.builder().put(INDEX_ROUTING_EXCLUDE_GROUP_PREFIX + "._name", node1), "index1"); ensureGreen("index1"); + var maxAttempts = MaxRetryAllocationDecider.SETTING_ALLOCATION_MAX_RETRY.get(Settings.EMPTY); + + // await all relocation attempts are exhausted + assertBusy(() -> { + var state = clusterAdmin().prepareState().get().getState(); + var shard = state.routingTable().index("index1").shard(0).primaryShard(); + assertThat(shard, notNullValue()); + assertThat(shard.relocationFailureInfo().failedRelocations(), equalTo(maxAttempts)); + }); + // ensure the shard remain started var state = clusterAdmin().prepareState().get().getState(); logger.info("Final routing is {}", state.getRoutingNodes().toString()); var shard = state.routingTable().index("index1").shard(0).primaryShard(); diff --git a/server/src/internalClusterTest/java/org/elasticsearch/indices/mapping/UpdateMappingIntegrationIT.java b/server/src/internalClusterTest/java/org/elasticsearch/indices/mapping/UpdateMappingIntegrationIT.java index 70cd143686dc8..0008ec1f9cbd2 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/indices/mapping/UpdateMappingIntegrationIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/indices/mapping/UpdateMappingIntegrationIT.java @@ -37,7 +37,6 @@ import java.util.List; import java.util.Map; import java.util.Set; -import java.util.concurrent.CyclicBarrier; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; @@ -179,66 +178,53 @@ public void testUpdateMappingConcurrently() throws Throwable { final AtomicReference threadException = new AtomicReference<>(); final AtomicBoolean stop = new AtomicBoolean(false); - Thread[] threads = new Thread[3]; - final CyclicBarrier barrier = new CyclicBarrier(threads.length); final ArrayList clientArray = new ArrayList<>(); for (Client c : clients()) { clientArray.add(c); } - for (int j = 0; j < threads.length; j++) { - threads[j] = new Thread(() -> { - try { - barrier.await(); - - for (int i = 0; i < 100; i++) { - if (stop.get()) { - return; - } - - Client client1 = clientArray.get(i % clientArray.size()); - Client client2 = clientArray.get((i + 1) % clientArray.size()); - String indexName = i % 2 == 0 ? "test2" : "test1"; - String fieldName = Thread.currentThread().getName() + "_" + i; - - AcknowledgedResponse response = client1.admin() - .indices() - .preparePutMapping(indexName) - .setSource( - JsonXContent.contentBuilder() - .startObject() - .startObject("_doc") - .startObject("properties") - .startObject(fieldName) - .field("type", "text") - .endObject() - .endObject() - .endObject() - .endObject() - ) - .setMasterNodeTimeout(TimeValue.timeValueMinutes(5)) - .get(); - - assertThat(response.isAcknowledged(), equalTo(true)); - GetMappingsResponse getMappingResponse = client2.admin().indices().prepareGetMappings(indexName).get(); - MappingMetadata mappings = getMappingResponse.getMappings().get(indexName); - @SuppressWarnings("unchecked") - Map properties = (Map) mappings.getSourceAsMap().get("properties"); - assertThat(properties.keySet(), Matchers.hasItem(fieldName)); + startInParallel(3, j -> { + try { + for (int i = 0; i < 100; i++) { + if (stop.get()) { + return; } - } catch (Exception e) { - threadException.set(e); - stop.set(true); - } - }); - - threads[j].setName("t_" + j); - threads[j].start(); - } - for (Thread t : threads) { - t.join(); - } + Client client1 = clientArray.get(i % clientArray.size()); + Client client2 = clientArray.get((i + 1) % clientArray.size()); + String indexName = i % 2 == 0 ? "test2" : "test1"; + String fieldName = "t_" + j + "_" + i; + + AcknowledgedResponse response = client1.admin() + .indices() + .preparePutMapping(indexName) + .setSource( + JsonXContent.contentBuilder() + .startObject() + .startObject("_doc") + .startObject("properties") + .startObject(fieldName) + .field("type", "text") + .endObject() + .endObject() + .endObject() + .endObject() + ) + .setMasterNodeTimeout(TimeValue.timeValueMinutes(5)) + .get(); + + assertThat(response.isAcknowledged(), equalTo(true)); + GetMappingsResponse getMappingResponse = client2.admin().indices().prepareGetMappings(indexName).get(); + MappingMetadata mappings = getMappingResponse.getMappings().get(indexName); + @SuppressWarnings("unchecked") + Map properties = (Map) mappings.getSourceAsMap().get("properties"); + assertThat(properties.keySet(), Matchers.hasItem(fieldName)); + } + } catch (Exception e) { + threadException.set(e); + stop.set(true); + } + }); if (threadException.get() != null) { throw threadException.get(); diff --git a/server/src/internalClusterTest/java/org/elasticsearch/indices/recovery/IndexRecoveryIT.java b/server/src/internalClusterTest/java/org/elasticsearch/indices/recovery/IndexRecoveryIT.java index 204d7131c44d2..d56e4a372c17c 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/indices/recovery/IndexRecoveryIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/indices/recovery/IndexRecoveryIT.java @@ -1210,8 +1210,8 @@ public void testRecoverLocallyUpToGlobalCheckpoint() throws Exception { SequenceNumbers.CommitInfo commitInfoAfterLocalRecovery = SequenceNumbers.loadSeqNoInfoFromLuceneCommit( startRecoveryRequest.metadataSnapshot().commitUserData().entrySet() ); - assertThat(commitInfoAfterLocalRecovery.localCheckpoint, equalTo(lastSyncedGlobalCheckpoint)); - assertThat(commitInfoAfterLocalRecovery.maxSeqNo, equalTo(lastSyncedGlobalCheckpoint)); + assertThat(commitInfoAfterLocalRecovery.localCheckpoint(), equalTo(lastSyncedGlobalCheckpoint)); + assertThat(commitInfoAfterLocalRecovery.maxSeqNo(), equalTo(lastSyncedGlobalCheckpoint)); assertThat(startRecoveryRequest.startingSeqNo(), equalTo(lastSyncedGlobalCheckpoint + 1)); ensureGreen(indexName); assertThat((long) localRecoveredOps.get(), equalTo(lastSyncedGlobalCheckpoint - localCheckpointOfSafeCommit)); @@ -2011,8 +2011,8 @@ private long getLocalCheckpointOfSafeCommit(IndexCommit safeIndexCommit) throws final SequenceNumbers.CommitInfo commitInfo = SequenceNumbers.loadSeqNoInfoFromLuceneCommit( safeIndexCommit.getUserData().entrySet() ); - final long commitLocalCheckpoint = commitInfo.localCheckpoint; - final long maxSeqNo = commitInfo.maxSeqNo; + final long commitLocalCheckpoint = commitInfo.localCheckpoint(); + final long maxSeqNo = commitInfo.maxSeqNo(); final LocalCheckpointTracker localCheckpointTracker = new LocalCheckpointTracker(maxSeqNo, commitLocalCheckpoint); // In certain scenarios it is possible that the local checkpoint captured during commit lags behind, diff --git a/server/src/internalClusterTest/java/org/elasticsearch/indices/state/CloseIndexIT.java b/server/src/internalClusterTest/java/org/elasticsearch/indices/state/CloseIndexIT.java index 77cdc2e99977d..d52294d7584b8 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/indices/state/CloseIndexIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/indices/state/CloseIndexIT.java @@ -38,12 +38,12 @@ import org.elasticsearch.test.ESIntegTestCase; import org.elasticsearch.test.InternalTestCluster; -import java.util.ArrayList; import java.util.Arrays; import java.util.List; import java.util.Locale; import java.util.Set; -import java.util.concurrent.CountDownLatch; +import java.util.concurrent.CyclicBarrier; +import java.util.concurrent.ExecutionException; import java.util.stream.Collectors; import java.util.stream.IntStream; @@ -170,7 +170,7 @@ public void testCloseUnassignedIndex() throws Exception { assertIndexIsClosed(indexName); } - public void testConcurrentClose() throws InterruptedException { + public void testConcurrentClose() throws InterruptedException, ExecutionException { final String indexName = randomAlphaOfLength(10).toLowerCase(Locale.ROOT); createIndex(indexName); @@ -196,25 +196,14 @@ public void testConcurrentClose() throws InterruptedException { assertThat(healthResponse.isTimedOut(), equalTo(false)); assertThat(healthResponse.getIndices().get(indexName).getStatus().value(), lessThanOrEqualTo(ClusterHealthStatus.YELLOW.value())); - final CountDownLatch startClosing = new CountDownLatch(1); - final Thread[] threads = new Thread[randomIntBetween(2, 5)]; - - for (int i = 0; i < threads.length; i++) { - threads[i] = new Thread(() -> { - safeAwait(startClosing); - try { - indicesAdmin().prepareClose(indexName).get(); - } catch (final Exception e) { - assertException(e, indexName); - } - }); - threads[i].start(); - } - - startClosing.countDown(); - for (Thread thread : threads) { - thread.join(); - } + final int tasks = randomIntBetween(2, 5); + startInParallel(tasks, i -> { + try { + indicesAdmin().prepareClose(indexName).get(); + } catch (final Exception e) { + assertException(e, indexName); + } + }); assertIndexIsClosed(indexName); } @@ -256,37 +245,18 @@ public void testCloseWhileDeletingIndices() throws Exception { } assertThat(clusterAdmin().prepareState().get().getState().metadata().indices().size(), equalTo(indices.length)); - final List threads = new ArrayList<>(); - final CountDownLatch latch = new CountDownLatch(1); - - for (final String indexToDelete : indices) { - threads.add(new Thread(() -> { - safeAwait(latch); - try { - assertAcked(indicesAdmin().prepareDelete(indexToDelete)); - } catch (final Exception e) { - assertException(e, indexToDelete); - } - })); - } - for (final String indexToClose : indices) { - threads.add(new Thread(() -> { - safeAwait(latch); - try { - indicesAdmin().prepareClose(indexToClose).get(); - } catch (final Exception e) { - assertException(e, indexToClose); + startInParallel(indices.length * 2, i -> { + final String index = indices[i % indices.length]; + try { + if (i < indices.length) { + assertAcked(indicesAdmin().prepareDelete(index)); + } else { + indicesAdmin().prepareClose(index).get(); } - })); - } - - for (Thread thread : threads) { - thread.start(); - } - latch.countDown(); - for (Thread thread : threads) { - thread.join(); - } + } catch (final Exception e) { + assertException(e, index); + } + }); } public void testConcurrentClosesAndOpens() throws Exception { @@ -297,37 +267,21 @@ public void testConcurrentClosesAndOpens() throws Exception { indexer.setFailureAssertion(e -> {}); waitForDocs(1, indexer); - final CountDownLatch latch = new CountDownLatch(1); + final int closes = randomIntBetween(1, 3); + final int opens = randomIntBetween(1, 3); + final CyclicBarrier barrier = new CyclicBarrier(opens + closes); - final List threads = new ArrayList<>(); - for (int i = 0; i < randomIntBetween(1, 3); i++) { - threads.add(new Thread(() -> { - try { - safeAwait(latch); + startInParallel(opens + closes, i -> { + try { + if (i < closes) { indicesAdmin().prepareClose(indexName).get(); - } catch (final Exception e) { - throw new AssertionError(e); - } - })); - } - for (int i = 0; i < randomIntBetween(1, 3); i++) { - threads.add(new Thread(() -> { - try { - safeAwait(latch); + } else { assertAcked(indicesAdmin().prepareOpen(indexName).get()); - } catch (final Exception e) { - throw new AssertionError(e); } - })); - } - - for (Thread thread : threads) { - thread.start(); - } - latch.countDown(); - for (Thread thread : threads) { - thread.join(); - } + } catch (final Exception e) { + throw new AssertionError(e); + } + }); indexer.stopAndAwaitStopped(); diff --git a/server/src/internalClusterTest/java/org/elasticsearch/indices/state/CloseWhileRelocatingShardsIT.java b/server/src/internalClusterTest/java/org/elasticsearch/indices/state/CloseWhileRelocatingShardsIT.java index b160834d675d9..6647356f070ae 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/indices/state/CloseWhileRelocatingShardsIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/indices/state/CloseWhileRelocatingShardsIT.java @@ -35,7 +35,6 @@ import java.util.ArrayList; import java.util.Collection; import java.util.HashMap; -import java.util.List; import java.util.Map; import java.util.Set; import java.util.concurrent.CountDownLatch; @@ -187,30 +186,17 @@ public void testCloseWhileRelocatingShards() throws Exception { ClusterRerouteUtils.reroute(client(), commands.toArray(AllocationCommand[]::new)); // start index closing threads - final List threads = new ArrayList<>(); - for (final String indexToClose : indices) { - final Thread thread = new Thread(() -> { - try { - safeAwait(latch); - } finally { - release.countDown(); - } - // Closing is not always acknowledged when shards are relocating: this is the case when the target shard is initializing - // or is catching up operations. In these cases the TransportVerifyShardBeforeCloseAction will detect that the global - // and max sequence number don't match and will not ack the close. - AcknowledgedResponse closeResponse = indicesAdmin().prepareClose(indexToClose).get(); - if (closeResponse.isAcknowledged()) { - assertTrue("Index closing should not be acknowledged twice", acknowledgedCloses.add(indexToClose)); - } - }); - threads.add(thread); - thread.start(); - } - - latch.countDown(); - for (Thread thread : threads) { - thread.join(); - } + startInParallel(indices.length, i -> { + release.countDown(); + // Closing is not always acknowledged when shards are relocating: this is the case when the target shard is initializing + // or is catching up operations. In these cases the TransportVerifyShardBeforeCloseAction will detect that the global + // and max sequence number don't match and will not ack the close. + final String indexToClose = indices[i]; + AcknowledgedResponse closeResponse = indicesAdmin().prepareClose(indexToClose).get(); + if (closeResponse.isAcknowledged()) { + assertTrue("Index closing should not be acknowledged twice", acknowledgedCloses.add(indexToClose)); + } + }); // stop indexers first without waiting for stop to not redundantly index on some while waiting for another one to stop for (BackgroundIndexer indexer : indexers.values()) { diff --git a/server/src/internalClusterTest/java/org/elasticsearch/search/geo/GeoBoundingBoxQueryGeoPointIT.java b/server/src/internalClusterTest/java/org/elasticsearch/search/geo/GeoBoundingBoxQueryGeoPointIT.java index e631d17fc480c..7a3b1699c30e5 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/search/geo/GeoBoundingBoxQueryGeoPointIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/search/geo/GeoBoundingBoxQueryGeoPointIT.java @@ -9,24 +9,14 @@ package org.elasticsearch.search.geo; import org.elasticsearch.index.IndexVersion; -import org.elasticsearch.plugins.Plugin; -import org.elasticsearch.test.TestGeoShapeFieldMapperPlugin; import org.elasticsearch.test.index.IndexVersionUtils; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentFactory; import java.io.IOException; -import java.util.Collection; -import java.util.Collections; public class GeoBoundingBoxQueryGeoPointIT extends GeoBoundingBoxQueryIntegTestCase { - @SuppressWarnings("deprecation") - @Override - protected Collection> nodePlugins() { - return Collections.singleton(TestGeoShapeFieldMapperPlugin.class); - } - @Override public XContentBuilder getMapping() throws IOException { XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() diff --git a/server/src/internalClusterTest/java/org/elasticsearch/search/geo/GeoBoundingBoxQueryGeoShapeIT.java b/server/src/internalClusterTest/java/org/elasticsearch/search/geo/GeoBoundingBoxQueryGeoShapeIT.java deleted file mode 100644 index 2b310f6b0ea3e..0000000000000 --- a/server/src/internalClusterTest/java/org/elasticsearch/search/geo/GeoBoundingBoxQueryGeoShapeIT.java +++ /dev/null @@ -1,46 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0 and the Server Side Public License, v 1; you may not use this file except - * in compliance with, at your election, the Elastic License 2.0 or the Server - * Side Public License, v 1. - */ - -package org.elasticsearch.search.geo; - -import org.elasticsearch.index.IndexVersion; -import org.elasticsearch.plugins.Plugin; -import org.elasticsearch.test.TestGeoShapeFieldMapperPlugin; -import org.elasticsearch.test.index.IndexVersionUtils; -import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xcontent.XContentFactory; - -import java.io.IOException; -import java.util.Collection; -import java.util.Collections; - -public class GeoBoundingBoxQueryGeoShapeIT extends GeoBoundingBoxQueryIntegTestCase { - - @SuppressWarnings("deprecation") - @Override - protected Collection> nodePlugins() { - return Collections.singleton(TestGeoShapeFieldMapperPlugin.class); - } - - @Override - public XContentBuilder getMapping() throws IOException { - XContentBuilder xContentBuilder = XContentFactory.jsonBuilder() - .startObject() - .startObject("_doc") - .startObject("properties") - .startObject("location") - .field("type", "geo_shape"); - xContentBuilder.endObject().endObject().endObject().endObject(); - return xContentBuilder; - } - - @Override - public IndexVersion randomSupportedVersion() { - return IndexVersionUtils.randomCompatibleVersion(random()); - } -} diff --git a/server/src/internalClusterTest/java/org/elasticsearch/search/geo/GeoShapeIT.java b/server/src/internalClusterTest/java/org/elasticsearch/search/geo/GeoShapeIT.java deleted file mode 100644 index c165ed02984e6..0000000000000 --- a/server/src/internalClusterTest/java/org/elasticsearch/search/geo/GeoShapeIT.java +++ /dev/null @@ -1,43 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0 and the Server Side Public License, v 1; you may not use this file except - * in compliance with, at your election, the Elastic License 2.0 or the Server - * Side Public License, v 1. - */ - -package org.elasticsearch.search.geo; - -import org.elasticsearch.index.IndexVersion; -import org.elasticsearch.plugins.Plugin; -import org.elasticsearch.test.TestGeoShapeFieldMapperPlugin; -import org.elasticsearch.test.index.IndexVersionUtils; -import org.elasticsearch.xcontent.XContentBuilder; - -import java.io.IOException; -import java.util.Collection; -import java.util.Collections; - -public class GeoShapeIT extends GeoShapeIntegTestCase { - - @SuppressWarnings("deprecation") - @Override - protected Collection> nodePlugins() { - return Collections.singleton(TestGeoShapeFieldMapperPlugin.class); - } - - @Override - protected void getGeoShapeMapping(XContentBuilder b) throws IOException { - b.field("type", "geo_shape"); - } - - @Override - protected IndexVersion randomSupportedVersion() { - return IndexVersionUtils.randomCompatibleVersion(random()); - } - - @Override - protected boolean allowExpensiveQueries() { - return true; - } -} diff --git a/server/src/internalClusterTest/java/org/elasticsearch/snapshots/ConcurrentSnapshotsIT.java b/server/src/internalClusterTest/java/org/elasticsearch/snapshots/ConcurrentSnapshotsIT.java index e03fafd5646e3..836bd26f08eee 100644 --- a/server/src/internalClusterTest/java/org/elasticsearch/snapshots/ConcurrentSnapshotsIT.java +++ b/server/src/internalClusterTest/java/org/elasticsearch/snapshots/ConcurrentSnapshotsIT.java @@ -47,7 +47,6 @@ import org.elasticsearch.test.ESIntegTestCase; import org.elasticsearch.test.disruption.NetworkDisruption; import org.elasticsearch.test.transport.MockTransportService; -import org.elasticsearch.transport.RemoteTransportException; import java.io.IOException; import java.nio.file.Files; @@ -788,18 +787,7 @@ public void testQueuedOperationsAndBrokenRepoOnMasterFailOver() throws Exception ensureStableCluster(3); awaitNoMoreRunningOperations(); - var innerException = expectThrows(ExecutionException.class, RuntimeException.class, deleteFuture::get); - - // There may be many layers of RTE to unwrap here, see https://github.com/elastic/elasticsearch/issues/102351. - // ExceptionsHelper#unwrapCause gives up at 10 layers of wrapping so we must unwrap more tenaciously by hand here: - while (true) { - if (innerException instanceof RemoteTransportException remoteTransportException) { - innerException = asInstanceOf(RuntimeException.class, remoteTransportException.getCause()); - } else { - assertThat(innerException, instanceOf(RepositoryException.class)); - break; - } - } + expectThrows(RepositoryException.class, deleteFuture::actionGet); } public void testQueuedSnapshotOperationsAndBrokenRepoOnMasterFailOver() throws Exception { diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index 2004c6fda8ce5..65606465b8502 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -208,6 +208,9 @@ static TransportVersion def(int id) { public static final TransportVersion TEXT_SIMILARITY_RERANKER_RETRIEVER = def(8_699_00_0); public static final TransportVersion ML_INFERENCE_GOOGLE_VERTEX_AI_RERANKING_ADDED = def(8_700_00_0); public static final TransportVersion VERSIONED_MASTER_NODE_REQUESTS = def(8_701_00_0); + public static final TransportVersion ML_INFERENCE_AMAZON_BEDROCK_ADDED = def(8_702_00_0); + public static final TransportVersion ML_INFERENCE_DONT_DELETE_WHEN_SEMANTIC_TEXT_EXISTS = def(8_703_00_0); + public static final TransportVersion INFERENCE_ADAPTIVE_ALLOCATIONS = def(8_704_00_0); /* * STOP! READ THIS FIRST! No, really, diff --git a/server/src/main/java/org/elasticsearch/action/admin/cluster/allocation/ClusterAllocationExplanation.java b/server/src/main/java/org/elasticsearch/action/admin/cluster/allocation/ClusterAllocationExplanation.java index 1e5f9d5d613d2..abb4f478cff54 100644 --- a/server/src/main/java/org/elasticsearch/action/admin/cluster/allocation/ClusterAllocationExplanation.java +++ b/server/src/main/java/org/elasticsearch/action/admin/cluster/allocation/ClusterAllocationExplanation.java @@ -16,6 +16,8 @@ import org.elasticsearch.cluster.routing.UnassignedInfo; import org.elasticsearch.cluster.routing.allocation.AllocationDecision; import org.elasticsearch.cluster.routing.allocation.ShardAllocationDecision; +import org.elasticsearch.common.ReferenceDocs; +import org.elasticsearch.common.Strings; import org.elasticsearch.common.collect.Iterators; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; @@ -43,10 +45,14 @@ */ public final class ClusterAllocationExplanation implements ChunkedToXContentObject, Writeable { - static final String NO_SHARD_SPECIFIED_MESSAGE = "No shard was specified in the explain API request, so this response " - + "explains a randomly chosen unassigned shard. There may be other unassigned shards in this cluster which cannot be assigned for " - + "different reasons. It may not be possible to assign this shard until one of the other shards is assigned correctly. To explain " - + "the allocation of other shards (whether assigned or unassigned) you must specify the target shard in the request to this API."; + static final String NO_SHARD_SPECIFIED_MESSAGE = Strings.format( + """ + No shard was specified in the explain API request, so this response explains a randomly chosen unassigned shard. There may be \ + other unassigned shards in this cluster which cannot be assigned for different reasons. It may not be possible to assign this \ + shard until one of the other shards is assigned correctly. To explain the allocation of other shards (whether assigned or \ + unassigned) you must specify the target shard in the request to this API. See %s for more information.""", + ReferenceDocs.ALLOCATION_EXPLAIN_API + ); private final boolean specificShard; private final ShardRouting shardRouting; @@ -206,25 +212,23 @@ private Iterator getShardAllocationDecisionChunked(ToXCont } else { String explanation; if (shardRouting.state() == ShardRoutingState.RELOCATING) { - explanation = "the shard is in the process of relocating from node [" - + currentNode.getName() - + "] " - + "to node [" - + relocationTargetNode.getName() - + "], wait until relocation has completed"; + explanation = Strings.format( + "the shard is in the process of relocating from node [%s] to node [%s], wait until relocation has completed", + currentNode.getName(), + relocationTargetNode.getName() + ); } else { assert shardRouting.state() == ShardRoutingState.INITIALIZING; - explanation = "the shard is in the process of initializing on node [" - + currentNode.getName() - + "], " - + "wait until initialization has completed"; + explanation = Strings.format( + "the shard is in the process of initializing on node [%s], wait until initialization has completed", + currentNode.getName() + ); } return Iterators.single((builder, p) -> builder.field("explanation", explanation)); } } - private static XContentBuilder unassignedInfoToXContent(UnassignedInfo unassignedInfo, XContentBuilder builder) throws IOException { - + private static void unassignedInfoToXContent(UnassignedInfo unassignedInfo, XContentBuilder builder) throws IOException { builder.startObject("unassigned_info"); builder.field("reason", unassignedInfo.reason()); builder.field("at", UnassignedInfo.DATE_TIME_FORMATTER.format(Instant.ofEpochMilli(unassignedInfo.unassignedTimeMillis()))); @@ -237,6 +241,5 @@ private static XContentBuilder unassignedInfoToXContent(UnassignedInfo unassigne } builder.field("last_allocation_status", AllocationDecision.fromAllocationStatus(unassignedInfo.lastAllocationStatus())); builder.endObject(); - return builder; } } diff --git a/server/src/main/java/org/elasticsearch/action/admin/cluster/allocation/TransportClusterAllocationExplainAction.java b/server/src/main/java/org/elasticsearch/action/admin/cluster/allocation/TransportClusterAllocationExplainAction.java index 313ee83669017..8e6f029c71013 100644 --- a/server/src/main/java/org/elasticsearch/action/admin/cluster/allocation/TransportClusterAllocationExplainAction.java +++ b/server/src/main/java/org/elasticsearch/action/admin/cluster/allocation/TransportClusterAllocationExplainAction.java @@ -28,6 +28,8 @@ import org.elasticsearch.cluster.routing.allocation.ShardAllocationDecision; import org.elasticsearch.cluster.routing.allocation.decider.AllocationDeciders; import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.ReferenceDocs; +import org.elasticsearch.common.Strings; import org.elasticsearch.common.inject.Inject; import org.elasticsearch.snapshots.SnapshotsInfoService; import org.elasticsearch.tasks.Task; @@ -160,11 +162,10 @@ public static ShardRouting findShardToExplain(ClusterAllocationExplainRequest re } } if (foundShard == null) { - throw new IllegalArgumentException( - "No shard was specified in the request which means the response should explain a randomly-chosen unassigned shard, " - + "but there are no unassigned shards in this cluster. To explain the allocation of an assigned shard you must " - + "specify the target shard in the request." - ); + throw new IllegalArgumentException(Strings.format(""" + No shard was specified in the request which means the response should explain a randomly-chosen unassigned shard, but \ + there are no unassigned shards in this cluster. To explain the allocation of an assigned shard you must specify the \ + target shard in the request. See %s for more information.""", ReferenceDocs.ALLOCATION_EXPLAIN_API)); } } else { String index = request.getIndex(); diff --git a/server/src/main/java/org/elasticsearch/action/index/IndexRequest.java b/server/src/main/java/org/elasticsearch/action/index/IndexRequest.java index 794a3f38b56bb..efe43fdff4efd 100644 --- a/server/src/main/java/org/elasticsearch/action/index/IndexRequest.java +++ b/server/src/main/java/org/elasticsearch/action/index/IndexRequest.java @@ -23,7 +23,6 @@ import org.elasticsearch.cluster.metadata.IndexAbstraction; import org.elasticsearch.cluster.metadata.Metadata; import org.elasticsearch.cluster.routing.IndexRouting; -import org.elasticsearch.common.Strings; import org.elasticsearch.common.UUIDs; import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.bytes.BytesReference; @@ -165,10 +164,8 @@ public IndexRequest(@Nullable ShardId shardId, StreamInput in) throws IOExceptio version = in.readLong(); versionType = VersionType.fromValue(in.readByte()); pipeline = readPipelineName(in); - if (in.getTransportVersion().onOrAfter(TransportVersions.V_7_5_0)) { - finalPipeline = readPipelineName(in); - isPipelineResolved = in.readBoolean(); - } + finalPipeline = readPipelineName(in); + isPipelineResolved = in.readBoolean(); isRetry = in.readBoolean(); autoGeneratedTimestamp = in.readLong(); if (in.readBoolean()) { @@ -179,14 +176,8 @@ public IndexRequest(@Nullable ShardId shardId, StreamInput in) throws IOExceptio } ifSeqNo = in.readZLong(); ifPrimaryTerm = in.readVLong(); - if (in.getTransportVersion().onOrAfter(TransportVersions.V_7_10_0)) { - requireAlias = in.readBoolean(); - } else { - requireAlias = false; - } - if (in.getTransportVersion().onOrAfter(TransportVersions.V_7_13_0)) { - dynamicTemplates = in.readMap(StreamInput::readString); - } + requireAlias = in.readBoolean(); + dynamicTemplates = in.readMap(StreamInput::readString); if (in.getTransportVersion().onOrAfter(PIPELINES_HAVE_RUN_FIELD_ADDED) && in.getTransportVersion().before(TransportVersions.V_8_13_0)) { in.readBoolean(); @@ -737,12 +728,8 @@ private void writeBody(StreamOutput out) throws IOException { out.writeLong(version); out.writeByte(versionType.getValue()); out.writeOptionalString(pipeline); - if (out.getTransportVersion().onOrAfter(TransportVersions.V_7_5_0)) { - out.writeOptionalString(finalPipeline); - } - if (out.getTransportVersion().onOrAfter(TransportVersions.V_7_5_0)) { - out.writeBoolean(isPipelineResolved); - } + out.writeOptionalString(finalPipeline); + out.writeBoolean(isPipelineResolved); out.writeBoolean(isRetry); out.writeLong(autoGeneratedTimestamp); if (contentType != null) { @@ -753,21 +740,8 @@ private void writeBody(StreamOutput out) throws IOException { } out.writeZLong(ifSeqNo); out.writeVLong(ifPrimaryTerm); - if (out.getTransportVersion().onOrAfter(TransportVersions.V_7_10_0)) { - out.writeBoolean(requireAlias); - } - if (out.getTransportVersion().onOrAfter(TransportVersions.V_7_13_0)) { - out.writeMap(dynamicTemplates, StreamOutput::writeString); - } else { - if (dynamicTemplates.isEmpty() == false) { - throw new IllegalArgumentException( - Strings.format( - "[dynamic_templates] parameter requires all nodes on %s or later", - TransportVersions.V_7_13_0.toReleaseVersion() - ) - ); - } - } + out.writeBoolean(requireAlias); + out.writeMap(dynamicTemplates, StreamOutput::writeString); if (out.getTransportVersion().onOrAfter(PIPELINES_HAVE_RUN_FIELD_ADDED) && out.getTransportVersion().before(TransportVersions.V_8_13_0)) { out.writeBoolean(normalisedBytesParsed != -1L); diff --git a/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java b/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java index 3fb63591bf3a4..eaddfc6d6592e 100644 --- a/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java +++ b/server/src/main/java/org/elasticsearch/action/search/TransportSearchAction.java @@ -49,6 +49,7 @@ import org.elasticsearch.common.settings.Setting; import org.elasticsearch.common.settings.Setting.Property; import org.elasticsearch.common.util.CollectionUtils; +import org.elasticsearch.common.util.FeatureFlag; import org.elasticsearch.common.util.Maps; import org.elasticsearch.common.util.concurrent.CountDown; import org.elasticsearch.common.util.concurrent.EsExecutors; @@ -122,6 +123,8 @@ public class TransportSearchAction extends HandledTransportAction SHARD_COUNT_LIMIT_SETTING = Setting.longSetting( "action.search.shard_count.limit", diff --git a/server/src/main/java/org/elasticsearch/action/update/UpdateRequest.java b/server/src/main/java/org/elasticsearch/action/update/UpdateRequest.java index 2cd5258bf4376..211daf2369d99 100644 --- a/server/src/main/java/org/elasticsearch/action/update/UpdateRequest.java +++ b/server/src/main/java/org/elasticsearch/action/update/UpdateRequest.java @@ -157,11 +157,7 @@ public UpdateRequest(@Nullable ShardId shardId, StreamInput in) throws IOExcepti ifPrimaryTerm = in.readVLong(); detectNoop = in.readBoolean(); scriptedUpsert = in.readBoolean(); - if (in.getTransportVersion().onOrAfter(TransportVersions.V_7_10_0)) { - requireAlias = in.readBoolean(); - } else { - requireAlias = false; - } + requireAlias = in.readBoolean(); } public UpdateRequest(String index, String id) { @@ -728,20 +724,18 @@ private void doWrite(StreamOutput out, boolean thin) throws IOException { } out.writeVInt(retryOnConflict); refreshPolicy.writeTo(out); - if (doc == null) { - out.writeBoolean(false); - } else { - out.writeBoolean(true); - // make sure the basics are set - doc.index(index); - doc.id(id); - if (thin) { - doc.writeThin(out); - } else { - doc.writeTo(out); - } - } + writeIndexRequest(out, thin, doc); out.writeOptionalWriteable(fetchSourceContext); + writeIndexRequest(out, thin, upsertRequest); + out.writeBoolean(docAsUpsert); + out.writeZLong(ifSeqNo); + out.writeVLong(ifPrimaryTerm); + out.writeBoolean(detectNoop); + out.writeBoolean(scriptedUpsert); + out.writeBoolean(requireAlias); + } + + private void writeIndexRequest(StreamOutput out, boolean thin, IndexRequest upsertRequest) throws IOException { if (upsertRequest == null) { out.writeBoolean(false); } else { @@ -755,14 +749,6 @@ private void doWrite(StreamOutput out, boolean thin) throws IOException { upsertRequest.writeTo(out); } } - out.writeBoolean(docAsUpsert); - out.writeZLong(ifSeqNo); - out.writeVLong(ifPrimaryTerm); - out.writeBoolean(detectNoop); - out.writeBoolean(scriptedUpsert); - if (out.getTransportVersion().onOrAfter(TransportVersions.V_7_10_0)) { - out.writeBoolean(requireAlias); - } } @Override diff --git a/server/src/main/java/org/elasticsearch/bootstrap/BootstrapChecks.java b/server/src/main/java/org/elasticsearch/bootstrap/BootstrapChecks.java index a60262ff4a097..84811362c08e6 100644 --- a/server/src/main/java/org/elasticsearch/bootstrap/BootstrapChecks.java +++ b/server/src/main/java/org/elasticsearch/bootstrap/BootstrapChecks.java @@ -584,7 +584,7 @@ public BootstrapCheckResult check(BootstrapContext context) { // visible for testing boolean isSystemCallFilterInstalled() { - return Natives.isSystemCallFilterInstalled(); + return NativeAccess.instance().getExecSandboxState() != NativeAccess.ExecSandboxState.NONE; } @Override @@ -608,7 +608,7 @@ public BootstrapCheckResult check(BootstrapContext context) { // visible for testing boolean isSystemCallFilterInstalled() { - return Natives.isSystemCallFilterInstalled(); + return NativeAccess.instance().getExecSandboxState() != NativeAccess.ExecSandboxState.NONE; } // visible for testing diff --git a/server/src/main/java/org/elasticsearch/bootstrap/BootstrapInfo.java b/server/src/main/java/org/elasticsearch/bootstrap/BootstrapInfo.java index f8ad9dd59650c..005375bf38540 100644 --- a/server/src/main/java/org/elasticsearch/bootstrap/BootstrapInfo.java +++ b/server/src/main/java/org/elasticsearch/bootstrap/BootstrapInfo.java @@ -27,16 +27,6 @@ public final class BootstrapInfo { /** no instantiation */ private BootstrapInfo() {} - /** - * Returns true if we successfully loaded native libraries. - *

- * If this returns false, then native operations such as locking - * memory did not work. - */ - public static boolean isNativesAvailable() { - return Natives.JNA_AVAILABLE; - } - /** * Returns true if we were able to lock the process's address space. */ @@ -44,13 +34,6 @@ public static boolean isMemoryLocked() { return NativeAccess.instance().isMemoryLocked(); } - /** - * Returns true if system call filter is installed (supported systems only) - */ - public static boolean isSystemCallFilterInstalled() { - return Natives.isSystemCallFilterInstalled(); - } - /** * Returns information about the console (tty) attached to the server process, or {@code null} * if no console is attached. diff --git a/server/src/main/java/org/elasticsearch/bootstrap/Elasticsearch.java b/server/src/main/java/org/elasticsearch/bootstrap/Elasticsearch.java index 082e1dd9257e0..3fc659cb8065d 100644 --- a/server/src/main/java/org/elasticsearch/bootstrap/Elasticsearch.java +++ b/server/src/main/java/org/elasticsearch/bootstrap/Elasticsearch.java @@ -293,7 +293,7 @@ static void initializeNatives(final Path tmpFile, final boolean mlockAll, final * * TODO: should we fail hard here if system call filters fail to install, or remain lenient in non-production environments? */ - Natives.tryInstallSystemCallFilter(tmpFile); + nativeAccess.tryInstallExecSandbox(); } // mlockall if requested @@ -316,13 +316,6 @@ static void initializeNatives(final Path tmpFile, final boolean mlockAll, final } } - // force remainder of JNA to be loaded (if available). - try { - JNAKernel32Library.getInstance(); - } catch (Exception ignored) { - // we've already logged this. - } - // init lucene random seed. it will use /dev/urandom where available: StringHelper.randomId(); diff --git a/server/src/main/java/org/elasticsearch/bootstrap/JNAKernel32Library.java b/server/src/main/java/org/elasticsearch/bootstrap/JNAKernel32Library.java deleted file mode 100644 index 01d9a122138f1..0000000000000 --- a/server/src/main/java/org/elasticsearch/bootstrap/JNAKernel32Library.java +++ /dev/null @@ -1,255 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0 and the Server Side Public License, v 1; you may not use this file except - * in compliance with, at your election, the Elastic License 2.0 or the Server - * Side Public License, v 1. - */ - -package org.elasticsearch.bootstrap; - -import com.sun.jna.IntegerType; -import com.sun.jna.Native; -import com.sun.jna.NativeLong; -import com.sun.jna.Pointer; -import com.sun.jna.Structure; -import com.sun.jna.WString; - -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.apache.lucene.util.Constants; - -import java.util.Arrays; -import java.util.List; - -/** - * Library for Windows/Kernel32 - */ -final class JNAKernel32Library { - - private static final Logger logger = LogManager.getLogger(JNAKernel32Library.class); - - // Native library instance must be kept around for the same reason. - private static final class Holder { - private static final JNAKernel32Library instance = new JNAKernel32Library(); - } - - private JNAKernel32Library() { - if (Constants.WINDOWS) { - try { - Native.register("kernel32"); - logger.debug("windows/Kernel32 library loaded"); - } catch (NoClassDefFoundError e) { - logger.warn("JNA not found. native methods and handlers will be disabled."); - } catch (UnsatisfiedLinkError e) { - logger.warn("unable to link Windows/Kernel32 library. native methods and handlers will be disabled."); - } - } - } - - static JNAKernel32Library getInstance() { - return Holder.instance; - } - - /** - * Memory protection constraints - * - * https://msdn.microsoft.com/en-us/library/windows/desktop/aa366786%28v=vs.85%29.aspx - */ - public static final int PAGE_NOACCESS = 0x0001; - public static final int PAGE_GUARD = 0x0100; - public static final int MEM_COMMIT = 0x1000; - - /** - * Contains information about a range of pages in the virtual address space of a process. - * The VirtualQuery and VirtualQueryEx functions use this structure. - * - * https://msdn.microsoft.com/en-us/library/windows/desktop/aa366775%28v=vs.85%29.aspx - */ - public static class MemoryBasicInformation extends Structure { - public Pointer BaseAddress; - public Pointer AllocationBase; - public NativeLong AllocationProtect; - public SizeT RegionSize; - public NativeLong State; - public NativeLong Protect; - public NativeLong Type; - - @Override - protected List getFieldOrder() { - return Arrays.asList("BaseAddress", "AllocationBase", "AllocationProtect", "RegionSize", "State", "Protect", "Type"); - } - } - - public static class SizeT extends IntegerType { - - // JNA requires this no-arg constructor to be public, - // otherwise it fails to register kernel32 library - public SizeT() { - this(0); - } - - SizeT(long value) { - super(Native.SIZE_T_SIZE, value); - } - - } - - /** - * Locks the specified region of the process's virtual address space into physical - * memory, ensuring that subsequent access to the region will not incur a page fault. - * - * https://msdn.microsoft.com/en-us/library/windows/desktop/aa366895%28v=vs.85%29.aspx - * - * @param address A pointer to the base address of the region of pages to be locked. - * @param size The size of the region to be locked, in bytes. - * @return true if the function succeeds - */ - native boolean VirtualLock(Pointer address, SizeT size); - - /** - * Retrieves information about a range of pages within the virtual address space of a specified process. - * - * https://msdn.microsoft.com/en-us/library/windows/desktop/aa366907%28v=vs.85%29.aspx - * - * @param handle A handle to the process whose memory information is queried. - * @param address A pointer to the base address of the region of pages to be queried. - * @param memoryInfo A pointer to a structure in which information about the specified page range is returned. - * @param length The size of the buffer pointed to by the memoryInfo parameter, in bytes. - * @return the actual number of bytes returned in the information buffer. - */ - native int VirtualQueryEx(Pointer handle, Pointer address, MemoryBasicInformation memoryInfo, int length); - - /** - * Sets the minimum and maximum working set sizes for the specified process. - * - * https://msdn.microsoft.com/en-us/library/windows/desktop/ms686234%28v=vs.85%29.aspx - * - * @param handle A handle to the process whose working set sizes is to be set. - * @param minSize The minimum working set size for the process, in bytes. - * @param maxSize The maximum working set size for the process, in bytes. - * @return true if the function succeeds. - */ - native boolean SetProcessWorkingSetSize(Pointer handle, SizeT minSize, SizeT maxSize); - - /** - * Retrieves a pseudo handle for the current process. - * - * https://msdn.microsoft.com/en-us/library/windows/desktop/ms683179%28v=vs.85%29.aspx - * - * @return a pseudo handle to the current process. - */ - native Pointer GetCurrentProcess(); - - /** - * Closes an open object handle. - * - * https://msdn.microsoft.com/en-us/library/windows/desktop/ms724211%28v=vs.85%29.aspx - * - * @param handle A valid handle to an open object. - * @return true if the function succeeds. - */ - native boolean CloseHandle(Pointer handle); - - /** - * Retrieves the short path form of the specified path. See - * {@code GetShortPathName}. - * - * @param lpszLongPath the path string - * @param lpszShortPath a buffer to receive the short name - * @param cchBuffer the size of the buffer - * @return the length of the string copied into {@code lpszShortPath}, otherwise zero for failure - */ - native int GetShortPathNameW(WString lpszLongPath, char[] lpszShortPath, int cchBuffer); - - /** - * Creates or opens a new job object - * - * https://msdn.microsoft.com/en-us/library/windows/desktop/ms682409%28v=vs.85%29.aspx - * - * @param jobAttributes security attributes - * @param name job name - * @return job handle if the function succeeds - */ - native Pointer CreateJobObjectW(Pointer jobAttributes, String name); - - /** - * Associates a process with an existing job - * - * https://msdn.microsoft.com/en-us/library/windows/desktop/ms681949%28v=vs.85%29.aspx - * - * @param job job handle - * @param process process handle - * @return true if the function succeeds - */ - native boolean AssignProcessToJobObject(Pointer job, Pointer process); - - /** - * Basic limit information for a job object - * - * https://msdn.microsoft.com/en-us/library/windows/desktop/ms684147%28v=vs.85%29.aspx - */ - public static class JOBOBJECT_BASIC_LIMIT_INFORMATION extends Structure implements Structure.ByReference { - public long PerProcessUserTimeLimit; - public long PerJobUserTimeLimit; - public int LimitFlags; - public SizeT MinimumWorkingSetSize; - public SizeT MaximumWorkingSetSize; - public int ActiveProcessLimit; - public Pointer Affinity; - public int PriorityClass; - public int SchedulingClass; - - @Override - protected List getFieldOrder() { - return Arrays.asList( - "PerProcessUserTimeLimit", - "PerJobUserTimeLimit", - "LimitFlags", - "MinimumWorkingSetSize", - "MaximumWorkingSetSize", - "ActiveProcessLimit", - "Affinity", - "PriorityClass", - "SchedulingClass" - ); - } - } - - /** - * Constant for JOBOBJECT_BASIC_LIMIT_INFORMATION in Query/Set InformationJobObject - */ - static final int JOBOBJECT_BASIC_LIMIT_INFORMATION_CLASS = 2; - - /** - * Constant for LimitFlags, indicating a process limit has been set - */ - static final int JOB_OBJECT_LIMIT_ACTIVE_PROCESS = 8; - - /** - * Get job limit and state information - * - * https://msdn.microsoft.com/en-us/library/windows/desktop/ms684925%28v=vs.85%29.aspx - * - * @param job job handle - * @param infoClass information class constant - * @param info pointer to information structure - * @param infoLength size of information structure - * @param returnLength length of data written back to structure (or null if not wanted) - * @return true if the function succeeds - */ - native boolean QueryInformationJobObject(Pointer job, int infoClass, Pointer info, int infoLength, Pointer returnLength); - - /** - * Set job limit and state information - * - * https://msdn.microsoft.com/en-us/library/windows/desktop/ms686216%28v=vs.85%29.aspx - * - * @param job job handle - * @param infoClass information class constant - * @param info pointer to information structure - * @param infoLength size of information structure - * @return true if the function succeeds - */ - native boolean SetInformationJobObject(Pointer job, int infoClass, Pointer info, int infoLength); -} diff --git a/server/src/main/java/org/elasticsearch/bootstrap/JNANatives.java b/server/src/main/java/org/elasticsearch/bootstrap/JNANatives.java deleted file mode 100644 index ba4e90ee2c6c1..0000000000000 --- a/server/src/main/java/org/elasticsearch/bootstrap/JNANatives.java +++ /dev/null @@ -1,50 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0 and the Server Side Public License, v 1; you may not use this file except - * in compliance with, at your election, the Elastic License 2.0 or the Server - * Side Public License, v 1. - */ - -package org.elasticsearch.bootstrap; - -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; - -import java.nio.file.Path; - -/** - * This class performs the actual work with JNA and library bindings to call native methods. It should only be used after - * we are sure that the JNA classes are available to the JVM - */ -class JNANatives { - - /** no instantiation */ - private JNANatives() {} - - private static final Logger logger = LogManager.getLogger(JNANatives.class); - - // Set to true, in case native system call filter install was successful - static boolean LOCAL_SYSTEM_CALL_FILTER = false; - // Set to true, in case policy can be applied to all threads of the process (even existing ones) - // otherwise they are only inherited for new threads (ES app threads) - static boolean LOCAL_SYSTEM_CALL_FILTER_ALL = false; - - static void tryInstallSystemCallFilter(Path tmpFile) { - try { - int ret = SystemCallFilter.init(tmpFile); - LOCAL_SYSTEM_CALL_FILTER = true; - if (ret == 1) { - LOCAL_SYSTEM_CALL_FILTER_ALL = true; - } - } catch (Exception e) { - // this is likely to happen unless the kernel is newish, its a best effort at the moment - // so we log stacktrace at debug for now... - if (logger.isDebugEnabled()) { - logger.debug("unable to install syscall filter", e); - } - logger.warn("unable to install syscall filter: ", e); - } - } - -} diff --git a/server/src/main/java/org/elasticsearch/bootstrap/Natives.java b/server/src/main/java/org/elasticsearch/bootstrap/Natives.java deleted file mode 100644 index c792d1e0bfad0..0000000000000 --- a/server/src/main/java/org/elasticsearch/bootstrap/Natives.java +++ /dev/null @@ -1,69 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0 and the Server Side Public License, v 1; you may not use this file except - * in compliance with, at your election, the Elastic License 2.0 or the Server - * Side Public License, v 1. - */ - -package org.elasticsearch.bootstrap; - -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.elasticsearch.common.ReferenceDocs; - -import java.lang.invoke.MethodHandles; -import java.nio.file.Path; -import java.util.Locale; - -/** - * The Natives class is a wrapper class that checks if the classes necessary for calling native methods are available on - * startup. If they are not available, this class will avoid calling code that loads these classes. - */ -final class Natives { - /** no instantiation */ - private Natives() {} - - private static final Logger logger = LogManager.getLogger(Natives.class); - - // marker to determine if the JNA class files are available to the JVM - static final boolean JNA_AVAILABLE; - - static { - boolean v = false; - try { - // load one of the main JNA classes to see if the classes are available. this does not ensure that all native - // libraries are available, only the ones necessary by JNA to function - MethodHandles.publicLookup().ensureInitialized(com.sun.jna.Native.class); - v = true; - } catch (IllegalAccessException e) { - throw new AssertionError(e); - } catch (UnsatisfiedLinkError e) { - logger.warn( - String.format( - Locale.ROOT, - "unable to load JNA native support library, native methods will be disabled. See %s", - ReferenceDocs.EXECUTABLE_JNA_TMPDIR - ), - e - ); - } - JNA_AVAILABLE = v; - } - - static void tryInstallSystemCallFilter(Path tmpFile) { - if (JNA_AVAILABLE == false) { - logger.warn("cannot install system call filter because JNA is not available"); - return; - } - JNANatives.tryInstallSystemCallFilter(tmpFile); - } - - static boolean isSystemCallFilterInstalled() { - if (JNA_AVAILABLE == false) { - return false; - } - return JNANatives.LOCAL_SYSTEM_CALL_FILTER; - } - -} diff --git a/server/src/main/java/org/elasticsearch/bootstrap/SystemCallFilter.java b/server/src/main/java/org/elasticsearch/bootstrap/SystemCallFilter.java deleted file mode 100644 index 0ab855d1d5f3a..0000000000000 --- a/server/src/main/java/org/elasticsearch/bootstrap/SystemCallFilter.java +++ /dev/null @@ -1,641 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0 and the Server Side Public License, v 1; you may not use this file except - * in compliance with, at your election, the Elastic License 2.0 or the Server - * Side Public License, v 1. - */ - -package org.elasticsearch.bootstrap; - -import com.sun.jna.Library; -import com.sun.jna.Memory; -import com.sun.jna.Native; -import com.sun.jna.NativeLong; -import com.sun.jna.Pointer; -import com.sun.jna.Structure; -import com.sun.jna.ptr.PointerByReference; - -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; -import org.apache.lucene.util.Constants; -import org.elasticsearch.core.IOUtils; - -import java.io.IOException; -import java.nio.ByteBuffer; -import java.nio.ByteOrder; -import java.nio.file.Files; -import java.nio.file.Path; -import java.util.Arrays; -import java.util.Collections; -import java.util.List; -import java.util.Map; - -/** - * Installs a system call filter to block process execution. - *

- * This is supported on Linux, Solaris, FreeBSD, OpenBSD, Mac OS X, and Windows. - *

- * On Linux it currently supports amd64 and i386 architectures, requires Linux kernel 3.5 or above, and requires - * {@code CONFIG_SECCOMP} and {@code CONFIG_SECCOMP_FILTER} compiled into the kernel. - *

- * On Linux BPF Filters are installed using either {@code seccomp(2)} (3.17+) or {@code prctl(2)} (3.5+). {@code seccomp(2)} - * is preferred, as it allows filters to be applied to any existing threads in the process, and one motivation - * here is to protect against bugs in the JVM. Otherwise, code will fall back to the {@code prctl(2)} method - * which will at least protect elasticsearch application threads. - *

- * Linux BPF filters will return {@code EACCES} (Access Denied) for the following system calls: - *

    - *
  • {@code execve}
  • - *
  • {@code fork}
  • - *
  • {@code vfork}
  • - *
  • {@code execveat}
  • - *
- *

- * On Solaris 10 or higher, the following privileges are dropped with {@code priv_set(3C)}: - *

    - *
  • {@code PRIV_PROC_FORK}
  • - *
  • {@code PRIV_PROC_EXEC}
  • - *
- *

- * On BSD systems, process creation is restricted with {@code setrlimit(RLIMIT_NPROC)}. - *

- * On Mac OS X Leopard or above, a custom {@code sandbox(7)} ("Seatbelt") profile is installed that - * denies the following rules: - *

    - *
  • {@code process-fork}
  • - *
  • {@code process-exec}
  • - *
- *

- * On Windows, process creation is restricted with {@code SetInformationJobObject/ActiveProcessLimit}. - *

- * This is not intended as a sandbox. It is another level of security, mostly intended to annoy - * security researchers and make their lives more difficult in achieving "remote execution" exploits. - * @see - * http://www.kernel.org/doc/Documentation/prctl/seccomp_filter.txt - * @see - * https://reverse.put.as/wp-content/uploads/2011/06/The-Apple-Sandbox-BHDC2011-Paper.pdf - * @see - * https://docs.oracle.com/cd/E23824_01/html/821-1456/prbac-2.html - */ -// not an example of how to write code!!! -final class SystemCallFilter { - private static final Logger logger = LogManager.getLogger(SystemCallFilter.class); - - // Linux implementation, based on seccomp(2) or prctl(2) with bpf filtering - - /** Access to non-standard Linux libc methods */ - interface LinuxLibrary extends Library { - /** - * maps to prctl(2) - */ - int prctl(int option, NativeLong arg2, NativeLong arg3, NativeLong arg4, NativeLong arg5); - - /** - * used to call seccomp(2), its too new... - * this is the only way, DON'T use it on some other architecture unless you know wtf you are doing - */ - NativeLong syscall(NativeLong number, Object... args); - } - - // null if unavailable or something goes wrong. - private static final LinuxLibrary linux_libc; - - static { - LinuxLibrary lib = null; - if (Constants.LINUX) { - try { - lib = Native.loadLibrary("c", LinuxLibrary.class); - } catch (UnsatisfiedLinkError e) { - logger.warn("unable to link C library. native methods (seccomp) will be disabled.", e); - } - } - linux_libc = lib; - } - - /** the preferred method is seccomp(2), since we can apply to all threads of the process */ - static final int SECCOMP_SET_MODE_FILTER = 1; // since Linux 3.17 - static final int SECCOMP_FILTER_FLAG_TSYNC = 1; // since Linux 3.17 - - /** otherwise, we can use prctl(2), which will at least protect ES application threads */ - static final int PR_GET_NO_NEW_PRIVS = 39; // since Linux 3.5 - static final int PR_SET_NO_NEW_PRIVS = 38; // since Linux 3.5 - static final int PR_GET_SECCOMP = 21; // since Linux 2.6.23 - static final int PR_SET_SECCOMP = 22; // since Linux 2.6.23 - static final long SECCOMP_MODE_FILTER = 2; // since Linux Linux 3.5 - - /** corresponds to struct sock_filter */ - static final class SockFilter { - short code; // insn - byte jt; // number of insn to jump (skip) if true - byte jf; // number of insn to jump (skip) if false - int k; // additional data - - SockFilter(short code, byte jt, byte jf, int k) { - this.code = code; - this.jt = jt; - this.jf = jf; - this.k = k; - } - } - - /** corresponds to struct sock_fprog */ - public static final class SockFProg extends Structure implements Structure.ByReference { - public short len; // number of filters - public Pointer filter; // filters - - SockFProg(SockFilter filters[]) { - len = (short) filters.length; - // serialize struct sock_filter * explicitly, its less confusing than the JNA magic we would need - Memory filter = new Memory(len * 8); - ByteBuffer bbuf = filter.getByteBuffer(0, len * 8); - bbuf.order(ByteOrder.nativeOrder()); // little endian - for (SockFilter f : filters) { - bbuf.putShort(f.code); - bbuf.put(f.jt); - bbuf.put(f.jf); - bbuf.putInt(f.k); - } - this.filter = filter; - } - - @Override - protected List getFieldOrder() { - return Arrays.asList("len", "filter"); - } - } - - // BPF "macros" and constants - static final int BPF_LD = 0x00; - static final int BPF_W = 0x00; - static final int BPF_ABS = 0x20; - static final int BPF_JMP = 0x05; - static final int BPF_JEQ = 0x10; - static final int BPF_JGE = 0x30; - static final int BPF_JGT = 0x20; - static final int BPF_RET = 0x06; - static final int BPF_K = 0x00; - - static SockFilter BPF_STMT(int code, int k) { - return new SockFilter((short) code, (byte) 0, (byte) 0, k); - } - - static SockFilter BPF_JUMP(int code, int k, int jt, int jf) { - return new SockFilter((short) code, (byte) jt, (byte) jf, k); - } - - static final int SECCOMP_RET_ERRNO = 0x00050000; - static final int SECCOMP_RET_DATA = 0x0000FFFF; - static final int SECCOMP_RET_ALLOW = 0x7FFF0000; - - // some errno constants for error checking/handling - static final int EACCES = 0x0D; - static final int EFAULT = 0x0E; - static final int EINVAL = 0x16; - static final int ENOSYS = 0x26; - - // offsets that our BPF checks - // check with offsetof() when adding a new arch, move to Arch if different. - static final int SECCOMP_DATA_NR_OFFSET = 0x00; - static final int SECCOMP_DATA_ARCH_OFFSET = 0x04; - - record Arch( - int audit, // AUDIT_ARCH_XXX constant from linux/audit.h - int limit, // syscall limit (necessary for blacklisting on amd64, to ban 32-bit syscalls) - int fork, // __NR_fork - int vfork, // __NR_vfork - int execve, // __NR_execve - int execveat, // __NR_execveat - int seccomp // __NR_seccomp - ) {} - - /** supported architectures map keyed by os.arch */ - private static final Map ARCHITECTURES; - static { - ARCHITECTURES = Map.of( - "amd64", - new Arch(0xC000003E, 0x3FFFFFFF, 57, 58, 59, 322, 317), - "aarch64", - new Arch(0xC00000B7, 0xFFFFFFFF, 1079, 1071, 221, 281, 277) - ); - } - - /** invokes prctl() from linux libc library */ - private static int linux_prctl(int option, long arg2, long arg3, long arg4, long arg5) { - return linux_libc.prctl(option, new NativeLong(arg2), new NativeLong(arg3), new NativeLong(arg4), new NativeLong(arg5)); - } - - /** invokes syscall() from linux libc library */ - private static long linux_syscall(long number, Object... args) { - return linux_libc.syscall(new NativeLong(number), args).longValue(); - } - - /** try to install our BPF filters via seccomp() or prctl() to block execution */ - private static int linuxImpl() { - // first be defensive: we can give nice errors this way, at the very least. - // also, some of these security features get backported to old versions, checking kernel version here is a big no-no! - final Arch arch = ARCHITECTURES.get(Constants.OS_ARCH); - boolean supported = Constants.LINUX && arch != null; - if (supported == false) { - throw new UnsupportedOperationException("seccomp unavailable: '" + Constants.OS_ARCH + "' architecture unsupported"); - } - - // we couldn't link methods, could be some really ancient kernel (e.g. < 2.1.57) or some bug - if (linux_libc == null) { - throw new UnsupportedOperationException( - "seccomp unavailable: could not link methods. requires kernel 3.5+ " - + "with CONFIG_SECCOMP and CONFIG_SECCOMP_FILTER compiled in" - ); - } - - // try to check system calls really are who they claim - // you never know (e.g. https://chromium.googlesource.com/chromium/src.git/+/master/sandbox/linux/seccomp-bpf/sandbox_bpf.cc#57) - final int bogusArg = 0xf7a46a5c; - - // test seccomp(BOGUS) - long ret = linux_syscall(arch.seccomp, bogusArg); - if (ret != -1) { - throw new UnsupportedOperationException("seccomp unavailable: seccomp(BOGUS_OPERATION) returned " + ret); - } else { - int errno = Native.getLastError(); - switch (errno) { - case ENOSYS: - break; // ok - case EINVAL: - break; // ok - default: - throw new UnsupportedOperationException("seccomp(BOGUS_OPERATION): " + JNACLibrary.strerror(errno)); - } - } - - // test seccomp(VALID, BOGUS) - ret = linux_syscall(arch.seccomp, SECCOMP_SET_MODE_FILTER, bogusArg); - if (ret != -1) { - throw new UnsupportedOperationException("seccomp unavailable: seccomp(SECCOMP_SET_MODE_FILTER, BOGUS_FLAG) returned " + ret); - } else { - int errno = Native.getLastError(); - switch (errno) { - case ENOSYS: - break; // ok - case EINVAL: - break; // ok - default: - throw new UnsupportedOperationException("seccomp(SECCOMP_SET_MODE_FILTER, BOGUS_FLAG): " + JNACLibrary.strerror(errno)); - } - } - - // test prctl(BOGUS) - ret = linux_prctl(bogusArg, 0, 0, 0, 0); - if (ret != -1) { - throw new UnsupportedOperationException("seccomp unavailable: prctl(BOGUS_OPTION) returned " + ret); - } else { - int errno = Native.getLastError(); - switch (errno) { - case ENOSYS: - break; // ok - case EINVAL: - break; // ok - default: - throw new UnsupportedOperationException("prctl(BOGUS_OPTION): " + JNACLibrary.strerror(errno)); - } - } - - // now just normal defensive checks - - // check for GET_NO_NEW_PRIVS - switch (linux_prctl(PR_GET_NO_NEW_PRIVS, 0, 0, 0, 0)) { - case 0: - break; // not yet set - case 1: - break; // already set by caller - default: - int errno = Native.getLastError(); - if (errno == EINVAL) { - // friendly error, this will be the typical case for an old kernel - throw new UnsupportedOperationException( - "seccomp unavailable: requires kernel 3.5+ with" + " CONFIG_SECCOMP and CONFIG_SECCOMP_FILTER compiled in" - ); - } else { - throw new UnsupportedOperationException("prctl(PR_GET_NO_NEW_PRIVS): " + JNACLibrary.strerror(errno)); - } - } - // check for SECCOMP - switch (linux_prctl(PR_GET_SECCOMP, 0, 0, 0, 0)) { - case 0: - break; // not yet set - case 2: - break; // already in filter mode by caller - default: - int errno = Native.getLastError(); - if (errno == EINVAL) { - throw new UnsupportedOperationException( - "seccomp unavailable: CONFIG_SECCOMP not compiled into kernel," - + " CONFIG_SECCOMP and CONFIG_SECCOMP_FILTER are needed" - ); - } else { - throw new UnsupportedOperationException("prctl(PR_GET_SECCOMP): " + JNACLibrary.strerror(errno)); - } - } - // check for SECCOMP_MODE_FILTER - if (linux_prctl(PR_SET_SECCOMP, SECCOMP_MODE_FILTER, 0, 0, 0) != 0) { - int errno = Native.getLastError(); - switch (errno) { - case EFAULT: - break; // available - case EINVAL: - throw new UnsupportedOperationException( - "seccomp unavailable: CONFIG_SECCOMP_FILTER not" - + " compiled into kernel, CONFIG_SECCOMP and CONFIG_SECCOMP_FILTER are needed" - ); - default: - throw new UnsupportedOperationException("prctl(PR_SET_SECCOMP): " + JNACLibrary.strerror(errno)); - } - } - - // ok, now set PR_SET_NO_NEW_PRIVS, needed to be able to set a seccomp filter as ordinary user - if (linux_prctl(PR_SET_NO_NEW_PRIVS, 1, 0, 0, 0) != 0) { - throw new UnsupportedOperationException("prctl(PR_SET_NO_NEW_PRIVS): " + JNACLibrary.strerror(Native.getLastError())); - } - - // check it worked - if (linux_prctl(PR_GET_NO_NEW_PRIVS, 0, 0, 0, 0) != 1) { - throw new UnsupportedOperationException( - "seccomp filter did not really succeed: prctl(PR_GET_NO_NEW_PRIVS): " + JNACLibrary.strerror(Native.getLastError()) - ); - } - - // BPF installed to check arch, limit, then syscall. - // See https://www.kernel.org/doc/Documentation/prctl/seccomp_filter.txt for details. - SockFilter insns[] = { - /* 1 */ BPF_STMT(BPF_LD + BPF_W + BPF_ABS, SECCOMP_DATA_ARCH_OFFSET), // - /* 2 */ BPF_JUMP(BPF_JMP + BPF_JEQ + BPF_K, arch.audit, 0, 7), // if (arch != audit) goto fail; - /* 3 */ BPF_STMT(BPF_LD + BPF_W + BPF_ABS, SECCOMP_DATA_NR_OFFSET), // - /* 4 */ BPF_JUMP(BPF_JMP + BPF_JGT + BPF_K, arch.limit, 5, 0), // if (syscall > LIMIT) goto fail; - /* 5 */ BPF_JUMP(BPF_JMP + BPF_JEQ + BPF_K, arch.fork, 4, 0), // if (syscall == FORK) goto fail; - /* 6 */ BPF_JUMP(BPF_JMP + BPF_JEQ + BPF_K, arch.vfork, 3, 0), // if (syscall == VFORK) goto fail; - /* 7 */ BPF_JUMP(BPF_JMP + BPF_JEQ + BPF_K, arch.execve, 2, 0), // if (syscall == EXECVE) goto fail; - /* 8 */ BPF_JUMP(BPF_JMP + BPF_JEQ + BPF_K, arch.execveat, 1, 0), // if (syscall == EXECVEAT) goto fail; - /* 9 */ BPF_STMT(BPF_RET + BPF_K, SECCOMP_RET_ALLOW), // pass: return OK; - /* 10 */ BPF_STMT(BPF_RET + BPF_K, SECCOMP_RET_ERRNO | (EACCES & SECCOMP_RET_DATA)), // fail: return EACCES; - }; - // seccomp takes a long, so we pass it one explicitly to keep the JNA simple - SockFProg prog = new SockFProg(insns); - prog.write(); - long pointer = Pointer.nativeValue(prog.getPointer()); - - int method = 1; - // install filter, if this works, after this there is no going back! - // first try it with seccomp(SECCOMP_SET_MODE_FILTER), falling back to prctl() - if (linux_syscall(arch.seccomp, SECCOMP_SET_MODE_FILTER, SECCOMP_FILTER_FLAG_TSYNC, new NativeLong(pointer)) != 0) { - method = 0; - int errno1 = Native.getLastError(); - if (logger.isDebugEnabled()) { - logger.debug( - "seccomp(SECCOMP_SET_MODE_FILTER): {}, falling back to prctl(PR_SET_SECCOMP)...", - JNACLibrary.strerror(errno1) - ); - } - if (linux_prctl(PR_SET_SECCOMP, SECCOMP_MODE_FILTER, pointer, 0, 0) != 0) { - int errno2 = Native.getLastError(); - throw new UnsupportedOperationException( - "seccomp(SECCOMP_SET_MODE_FILTER): " - + JNACLibrary.strerror(errno1) - + ", prctl(PR_SET_SECCOMP): " - + JNACLibrary.strerror(errno2) - ); - } - } - - // now check that the filter was really installed, we should be in filter mode. - if (linux_prctl(PR_GET_SECCOMP, 0, 0, 0, 0) != 2) { - throw new UnsupportedOperationException( - "seccomp filter installation did not really succeed. seccomp(PR_GET_SECCOMP): " - + JNACLibrary.strerror(Native.getLastError()) - ); - } - - logger.debug("Linux seccomp filter installation successful, threads: [{}]", method == 1 ? "all" : "app"); - return method; - } - - // OS X implementation via sandbox(7) - - /** Access to non-standard OS X libc methods */ - interface MacLibrary extends Library { - /** - * maps to sandbox_init(3), since Leopard - */ - int sandbox_init(String profile, long flags, PointerByReference errorbuf); - - /** - * releases memory when an error occurs during initialization (e.g. syntax bug) - */ - void sandbox_free_error(Pointer errorbuf); - } - - // null if unavailable, or something goes wrong. - private static final MacLibrary libc_mac; - - static { - MacLibrary lib = null; - if (Constants.MAC_OS_X) { - try { - lib = Native.loadLibrary("c", MacLibrary.class); - } catch (UnsatisfiedLinkError e) { - logger.warn("unable to link C library. native methods (seatbelt) will be disabled.", e); - } - } - libc_mac = lib; - } - - /** The only supported flag... */ - static final int SANDBOX_NAMED = 1; - /** Allow everything except process fork and execution */ - static final String SANDBOX_RULES = "(version 1) (allow default) (deny process-fork) (deny process-exec)"; - - /** try to install our custom rule profile into sandbox_init() to block execution */ - private static void macImpl(Path tmpFile) throws IOException { - // first be defensive: we can give nice errors this way, at the very least. - boolean supported = Constants.MAC_OS_X; - if (supported == false) { - throw new IllegalStateException("bug: should not be trying to initialize seatbelt for an unsupported OS"); - } - - // we couldn't link methods, could be some really ancient OS X (< Leopard) or some bug - if (libc_mac == null) { - throw new UnsupportedOperationException("seatbelt unavailable: could not link methods. requires Leopard or above."); - } - - // write rules to a temporary file, which will be passed to sandbox_init() - Path rules = Files.createTempFile(tmpFile, "es", "sb"); - Files.write(rules, Collections.singleton(SANDBOX_RULES)); - - boolean success = false; - try { - PointerByReference errorRef = new PointerByReference(); - int ret = libc_mac.sandbox_init(rules.toAbsolutePath().toString(), SANDBOX_NAMED, errorRef); - // if sandbox_init() fails, add the message from the OS (e.g. syntax error) and free the buffer - if (ret != 0) { - Pointer errorBuf = errorRef.getValue(); - RuntimeException e = new UnsupportedOperationException("sandbox_init(): " + errorBuf.getString(0)); - libc_mac.sandbox_free_error(errorBuf); - throw e; - } - logger.debug("OS X seatbelt initialization successful"); - success = true; - } finally { - if (success) { - Files.delete(rules); - } else { - IOUtils.deleteFilesIgnoringExceptions(rules); - } - } - } - - // Solaris implementation via priv_set(3C) - - /** Access to non-standard Solaris libc methods */ - interface SolarisLibrary extends Library { - /** - * see priv_set(3C), a convenience method for setppriv(2). - */ - int priv_set(int op, String which, String... privs); - } - - // null if unavailable, or something goes wrong. - private static final SolarisLibrary libc_solaris; - - static { - SolarisLibrary lib = null; - if (Constants.SUN_OS) { - try { - lib = Native.loadLibrary("c", SolarisLibrary.class); - } catch (UnsatisfiedLinkError e) { - logger.warn("unable to link C library. native methods (priv_set) will be disabled.", e); - } - } - libc_solaris = lib; - } - - // constants for priv_set(2) - static final int PRIV_OFF = 1; - static final String PRIV_ALLSETS = null; - // see privileges(5) for complete list of these - static final String PRIV_PROC_FORK = "proc_fork"; - static final String PRIV_PROC_EXEC = "proc_exec"; - - static void solarisImpl() { - // first be defensive: we can give nice errors this way, at the very least. - boolean supported = Constants.SUN_OS; - if (supported == false) { - throw new IllegalStateException("bug: should not be trying to initialize priv_set for an unsupported OS"); - } - - // we couldn't link methods, could be some really ancient Solaris or some bug - if (libc_solaris == null) { - throw new UnsupportedOperationException("priv_set unavailable: could not link methods. requires Solaris 10+"); - } - - // drop a null-terminated list of privileges - if (libc_solaris.priv_set(PRIV_OFF, PRIV_ALLSETS, PRIV_PROC_FORK, PRIV_PROC_EXEC, null) != 0) { - throw new UnsupportedOperationException("priv_set unavailable: priv_set(): " + JNACLibrary.strerror(Native.getLastError())); - } - - logger.debug("Solaris priv_set initialization successful"); - } - - // BSD implementation via setrlimit(2) - - // TODO: add OpenBSD to Lucene Constants - // TODO: JNA doesn't have netbsd support, but this mechanism should work there too. - static final boolean OPENBSD = Constants.OS_NAME.startsWith("OpenBSD"); - - // not a standard limit, means something different on linux, etc! - static final int RLIMIT_NPROC = 7; - - static void bsdImpl() { - boolean supported = Constants.FREE_BSD || OPENBSD || Constants.MAC_OS_X; - if (supported == false) { - throw new IllegalStateException("bug: should not be trying to initialize RLIMIT_NPROC for an unsupported OS"); - } - - JNACLibrary.Rlimit limit = new JNACLibrary.Rlimit(); - limit.rlim_cur.setValue(0); - limit.rlim_max.setValue(0); - if (JNACLibrary.setrlimit(RLIMIT_NPROC, limit) != 0) { - throw new UnsupportedOperationException("RLIMIT_NPROC unavailable: " + JNACLibrary.strerror(Native.getLastError())); - } - - logger.debug("BSD RLIMIT_NPROC initialization successful"); - } - - // windows impl via job ActiveProcessLimit - - static void windowsImpl() { - if (Constants.WINDOWS == false) { - throw new IllegalStateException("bug: should not be trying to initialize ActiveProcessLimit for an unsupported OS"); - } - - JNAKernel32Library lib = JNAKernel32Library.getInstance(); - - // create a new Job - Pointer job = lib.CreateJobObjectW(null, null); - if (job == null) { - throw new UnsupportedOperationException("CreateJobObject: " + Native.getLastError()); - } - - try { - // retrieve the current basic limits of the job - int clazz = JNAKernel32Library.JOBOBJECT_BASIC_LIMIT_INFORMATION_CLASS; - JNAKernel32Library.JOBOBJECT_BASIC_LIMIT_INFORMATION limits = new JNAKernel32Library.JOBOBJECT_BASIC_LIMIT_INFORMATION(); - limits.write(); - if (lib.QueryInformationJobObject(job, clazz, limits.getPointer(), limits.size(), null) == false) { - throw new UnsupportedOperationException("QueryInformationJobObject: " + Native.getLastError()); - } - limits.read(); - // modify the number of active processes to be 1 (exactly the one process we will add to the job). - limits.ActiveProcessLimit = 1; - limits.LimitFlags = JNAKernel32Library.JOB_OBJECT_LIMIT_ACTIVE_PROCESS; - limits.write(); - if (lib.SetInformationJobObject(job, clazz, limits.getPointer(), limits.size()) == false) { - throw new UnsupportedOperationException("SetInformationJobObject: " + Native.getLastError()); - } - // assign ourselves to the job - if (lib.AssignProcessToJobObject(job, lib.GetCurrentProcess()) == false) { - throw new UnsupportedOperationException("AssignProcessToJobObject: " + Native.getLastError()); - } - } finally { - lib.CloseHandle(job); - } - - logger.debug("Windows ActiveProcessLimit initialization successful"); - } - - /** - * Attempt to drop the capability to execute for the process. - *

- * This is best effort and OS and architecture dependent. It may throw any Throwable. - * @return 0 if we can do this for application threads, 1 for the entire process - */ - static int init(Path tmpFile) throws Exception { - if (Constants.LINUX) { - return linuxImpl(); - } else if (Constants.MAC_OS_X) { - // try to enable both mechanisms if possible - bsdImpl(); - macImpl(tmpFile); - return 1; - } else if (Constants.SUN_OS) { - solarisImpl(); - return 1; - } else if (Constants.FREE_BSD || OPENBSD) { - bsdImpl(); - return 1; - } else if (Constants.WINDOWS) { - windowsImpl(); - return 1; - } else { - throw new UnsupportedOperationException("syscall filtering not supported for OS: '" + Constants.OS_NAME + "'"); - } - } -} diff --git a/server/src/main/java/org/elasticsearch/cluster/SnapshotsInProgress.java b/server/src/main/java/org/elasticsearch/cluster/SnapshotsInProgress.java index 532a33d07b25d..b6fb370991a93 100644 --- a/server/src/main/java/org/elasticsearch/cluster/SnapshotsInProgress.java +++ b/server/src/main/java/org/elasticsearch/cluster/SnapshotsInProgress.java @@ -27,6 +27,8 @@ import org.elasticsearch.index.Index; import org.elasticsearch.index.IndexVersion; import org.elasticsearch.index.shard.ShardId; +import org.elasticsearch.logging.LogManager; +import org.elasticsearch.logging.Logger; import org.elasticsearch.repositories.IndexId; import org.elasticsearch.repositories.RepositoryOperation; import org.elasticsearch.repositories.RepositoryShardId; @@ -58,6 +60,8 @@ */ public class SnapshotsInProgress extends AbstractNamedDiffable implements Custom { + private static final Logger logger = LogManager.getLogger(SnapshotsInProgress.class); + public static final SnapshotsInProgress EMPTY = new SnapshotsInProgress(Map.of(), Set.of()); public static final String TYPE = "snapshots"; @@ -207,6 +211,17 @@ public Map> obsoleteGenerations(String r // We moved from a non-null generation successful generation to a different non-null successful generation // so the original generation is clearly obsolete because it was in-flight before and is now unreferenced everywhere. obsoleteGenerations.computeIfAbsent(repositoryShardId, ignored -> new HashSet<>()).add(oldStatus.generation()); + logger.debug( + """ + Marking shard generation [{}] file for cleanup. The finalized shard generation is now [{}], for shard \ + snapshot [{}] with shard ID [{}] on node [{}] + """, + oldStatus.generation(), + newStatus.generation(), + entry.snapshot(), + repositoryShardId.shardId(), + oldStatus.nodeId() + ); } } } @@ -441,7 +456,9 @@ public SnapshotsInProgress withUpdatedNodeIdsForRemoval(ClusterState clusterStat updatedNodeIdsForRemoval.addAll(nodeIdsMarkedForRemoval); // remove any nodes which are no longer marked for shutdown if they have no running shard snapshots - updatedNodeIdsForRemoval.removeAll(getObsoleteNodeIdsForRemoval(nodeIdsMarkedForRemoval)); + var restoredNodeIds = getObsoleteNodeIdsForRemoval(nodeIdsMarkedForRemoval); + updatedNodeIdsForRemoval.removeAll(restoredNodeIds); + logger.debug("Resuming shard snapshots on nodes [{}]", restoredNodeIds); if (updatedNodeIdsForRemoval.equals(nodesIdsForRemoval)) { return this; @@ -469,19 +486,26 @@ private static Set getNodesIdsMarkedForRemoval(ClusterState clusterState return result; } + /** + * Identifies any nodes that are no longer marked for removal AND have no running shard snapshots. + * @param latestNodeIdsMarkedForRemoval the current nodes marked for removal in the cluster state. + */ private Set getObsoleteNodeIdsForRemoval(Set latestNodeIdsMarkedForRemoval) { - final var obsoleteNodeIdsForRemoval = new HashSet<>(nodesIdsForRemoval); - obsoleteNodeIdsForRemoval.removeIf(latestNodeIdsMarkedForRemoval::contains); - if (obsoleteNodeIdsForRemoval.isEmpty()) { + // Find any nodes no longer marked for removal. + final var nodeIdsNoLongerMarkedForRemoval = new HashSet<>(nodesIdsForRemoval); + nodeIdsNoLongerMarkedForRemoval.removeIf(latestNodeIdsMarkedForRemoval::contains); + if (nodeIdsNoLongerMarkedForRemoval.isEmpty()) { return Set.of(); } + // If any nodes have INIT state shard snapshots, then the node's snapshots are not concurrency safe to resume yet. All shard + // snapshots on a newly revived node (no longer marked for shutdown) must finish moving to paused before any can resume. for (final var byRepo : entries.values()) { for (final var entry : byRepo.entries()) { if (entry.state() == State.STARTED && entry.hasShardsInInitState()) { for (final var shardSnapshotStatus : entry.shards().values()) { if (shardSnapshotStatus.state() == ShardState.INIT) { - obsoleteNodeIdsForRemoval.remove(shardSnapshotStatus.nodeId()); - if (obsoleteNodeIdsForRemoval.isEmpty()) { + nodeIdsNoLongerMarkedForRemoval.remove(shardSnapshotStatus.nodeId()); + if (nodeIdsNoLongerMarkedForRemoval.isEmpty()) { return Set.of(); } } @@ -489,7 +513,7 @@ private Set getObsoleteNodeIdsForRemoval(Set latestNodeIdsMarked } } } - return obsoleteNodeIdsForRemoval; + return nodeIdsNoLongerMarkedForRemoval; } public boolean nodeIdsForRemovalChanged(SnapshotsInProgress other) { @@ -616,6 +640,9 @@ public record ShardSnapshotStatus( "missing index" ); + /** + * Initializes status with state {@link ShardState#INIT}. + */ public ShardSnapshotStatus(String nodeId, ShardGeneration generation) { this(nodeId, ShardState.INIT, generation); } diff --git a/server/src/main/java/org/elasticsearch/common/ReferenceDocs.java b/server/src/main/java/org/elasticsearch/common/ReferenceDocs.java index 2cac6ddb159bc..770ed4d213c55 100644 --- a/server/src/main/java/org/elasticsearch/common/ReferenceDocs.java +++ b/server/src/main/java/org/elasticsearch/common/ReferenceDocs.java @@ -73,6 +73,10 @@ public enum ReferenceDocs { UNASSIGNED_SHARDS, EXECUTABLE_JNA_TMPDIR, NETWORK_THREADING_MODEL, + ALLOCATION_EXPLAIN_API, + NETWORK_BINDING_AND_PUBLISHING, + SNAPSHOT_REPOSITORY_ANALYSIS, + S3_COMPATIBLE_REPOSITORIES, // this comment keeps the ';' on the next line so every entry above has a trailing ',' which makes the diff for adding new links cleaner ; diff --git a/server/src/main/java/org/elasticsearch/common/compress/Compressor.java b/server/src/main/java/org/elasticsearch/common/compress/Compressor.java index 239f168306a94..400653a69a9be 100644 --- a/server/src/main/java/org/elasticsearch/common/compress/Compressor.java +++ b/server/src/main/java/org/elasticsearch/common/compress/Compressor.java @@ -26,7 +26,16 @@ public interface Compressor { */ default StreamInput threadLocalStreamInput(InputStream in) throws IOException { // wrap stream in buffer since InputStreamStreamInput doesn't do any buffering itself but does a lot of small reads - return new InputStreamStreamInput(new BufferedInputStream(threadLocalInputStream(in), DeflateCompressor.BUFFER_SIZE)); + return new InputStreamStreamInput(new BufferedInputStream(threadLocalInputStream(in), DeflateCompressor.BUFFER_SIZE) { + @Override + public int read() throws IOException { + // override read to avoid synchronized single byte reads now that JEP374 removed biased locking + if (pos >= count) { + return super.read(); + } + return buf[pos++] & 0xFF; + } + }); } /** diff --git a/server/src/main/java/org/elasticsearch/common/util/BitArray.java b/server/src/main/java/org/elasticsearch/common/util/BitArray.java index 53244a0f2888a..041111840056d 100644 --- a/server/src/main/java/org/elasticsearch/common/util/BitArray.java +++ b/server/src/main/java/org/elasticsearch/common/util/BitArray.java @@ -64,6 +64,17 @@ public void writeTo(StreamOutput out) throws IOException { bits.writeTo(out); } + /** + * Set or clear the {@code index}th bit based on the specified value. + */ + public void set(long index, boolean value) { + if (value) { + set(index); + } else { + clear(index); + } + } + /** * Set the {@code index}th bit. */ @@ -158,6 +169,68 @@ public boolean get(long index) { return (bits.get(wordNum) & bitmask) != 0; } + /** + * Set or clear slots between {@code fromIndex} inclusive to {@code toIndex} based on {@code value}. + */ + public void fill(long fromIndex, long toIndex, boolean value) { + if (fromIndex > toIndex) { + throw new IllegalArgumentException("From should be less than or equal to toIndex"); + } + long currentSize = size(); + if (value == false) { + // There's no need to grow the array just to clear bits. + toIndex = Math.min(toIndex, currentSize); + } + if (fromIndex == toIndex) { + return; // Empty range + } + + if (toIndex > currentSize) { + bits = bigArrays.grow(bits, wordNum(toIndex) + 1); + } + + int wordLength = Long.BYTES * Byte.SIZE; + long fullWord = 0xFFFFFFFFFFFFFFFFL; + + long firstWordIndex = fromIndex % wordLength; + long lastWordIndex = toIndex % wordLength; + + long firstWordNum = wordNum(fromIndex); + long lastWordNum = wordNum(toIndex - 1); + + // Mask first word + if (firstWordIndex > 0) { + long mask = fullWord << firstWordIndex; + + if (firstWordNum == lastWordNum) { + mask &= fullWord >>> (wordLength - lastWordIndex); + } + + if (value) { + bits.set(firstWordNum, bits.get(firstWordNum) | mask); + } else { + bits.set(firstWordNum, bits.get(firstWordNum) & ~mask); + } + + firstWordNum++; + } + + // Mask last word + if (firstWordNum <= lastWordNum) { + long mask = fullWord >>> (wordLength - lastWordIndex); + + if (value) { + bits.set(lastWordNum, bits.get(lastWordNum) | mask); + } else { + bits.set(lastWordNum, bits.get(lastWordNum) & ~mask); + } + } + + if (firstWordNum < lastWordNum) { + bits.fill(firstWordNum, lastWordNum, value ? fullWord : 0L); + } + } + public long size() { return bits.size() * (long) Long.BYTES * Byte.SIZE; } diff --git a/server/src/main/java/org/elasticsearch/discovery/HandshakingTransportAddressConnector.java b/server/src/main/java/org/elasticsearch/discovery/HandshakingTransportAddressConnector.java index 209faa7207be1..d234c1797e090 100644 --- a/server/src/main/java/org/elasticsearch/discovery/HandshakingTransportAddressConnector.java +++ b/server/src/main/java/org/elasticsearch/discovery/HandshakingTransportAddressConnector.java @@ -12,9 +12,11 @@ import org.apache.logging.log4j.Logger; import org.elasticsearch.Version; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.SubscribableListener; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.node.VersionInformation; import org.elasticsearch.common.Randomness; +import org.elasticsearch.common.ReferenceDocs; import org.elasticsearch.common.UUIDs; import org.elasticsearch.common.settings.Setting; import org.elasticsearch.common.settings.Settings; @@ -26,6 +28,7 @@ import org.elasticsearch.index.IndexVersions; import org.elasticsearch.transport.ConnectTransportException; import org.elasticsearch.transport.ConnectionProfile; +import org.elasticsearch.transport.Transport; import org.elasticsearch.transport.TransportRequestOptions.Type; import org.elasticsearch.transport.TransportService; @@ -72,11 +75,26 @@ public HandshakingTransportAddressConnector(Settings settings, TransportService @Override public void connectToRemoteMasterNode(TransportAddress transportAddress, ActionListener listener) { - try { + new ConnectionAttempt(transportAddress).run(listener); + } + + private class ConnectionAttempt { + private final TransportAddress transportAddress; + + ConnectionAttempt(TransportAddress transportAddress) { + this.transportAddress = transportAddress; + } + + void run(ActionListener listener) { + SubscribableListener.newForked(this::openProbeConnection) + .andThen(this::handshakeProbeConnection) + .andThen(this::openFullConnection) + .addListener(listener); + } + private void openProbeConnection(ActionListener listener) { // We could skip this if the transportService were already connected to the given address, but the savings would be minimal so // we open a new connection anyway. - logger.trace("[{}] opening probe connection", transportAddress); transportService.openConnection( new DiscoveryNode( @@ -95,98 +113,91 @@ public void connectToRemoteMasterNode(TransportAddress transportAddress, ActionL ) ), handshakeConnectionProfile, - listener.delegateFailure((l, connection) -> { - logger.trace("[{}] opened probe connection", transportAddress); - final var probeHandshakeTimeout = handshakeConnectionProfile.getHandshakeTimeout(); - // use NotifyOnceListener to make sure the following line does not result in onFailure being called when - // the connection is closed in the onResponse handler - transportService.handshake(connection, probeHandshakeTimeout, ActionListener.notifyOnce(new ActionListener<>() { - - @Override - public void onResponse(DiscoveryNode remoteNode) { - try { - // success means (amongst other things) that the cluster names match - logger.trace("[{}] handshake successful: {}", transportAddress, remoteNode); - IOUtils.closeWhileHandlingException(connection); - - if (remoteNode.equals(transportService.getLocalNode())) { - listener.onFailure( - new ConnectTransportException( - remoteNode, - String.format( - Locale.ROOT, - "successfully discovered local node %s at [%s]", - remoteNode.descriptionWithoutAttributes(), - transportAddress - ) - ) - ); - } else if (remoteNode.isMasterNode() == false) { - listener.onFailure( - new ConnectTransportException( - remoteNode, - String.format( - Locale.ROOT, - """ - successfully discovered master-ineligible node %s at [%s]; to suppress this message, \ - remove address [%s] from your discovery configuration or ensure that traffic to this \ - address is routed only to master-eligible nodes""", - remoteNode.descriptionWithoutAttributes(), - transportAddress, - transportAddress - ) - ) - ); - } else { - transportService.connectToNode(remoteNode, new ActionListener<>() { - @Override - public void onResponse(Releasable connectionReleasable) { - logger.trace("[{}] completed full connection with [{}]", transportAddress, remoteNode); - listener.onResponse(new ProbeConnectionResult(remoteNode, connectionReleasable)); - } - - @Override - public void onFailure(Exception e) { - // we opened a connection and successfully performed a handshake, so we're definitely - // talking to a master-eligible node with a matching cluster name and a good version, but - // the attempt to open a full connection to its publish address failed; a common reason is - // that the remote node is listening on 0.0.0.0 but has made an inappropriate choice for its - // publish address. - logger.warn( - () -> format( - "completed handshake with [%s] at [%s] but followup connection to [%s] failed", - remoteNode.descriptionWithoutAttributes(), - transportAddress, - remoteNode.getAddress() - ), - e - ); - listener.onFailure(e); - } - }); - } - } catch (Exception e) { - listener.onFailure(e); - } - } - - @Override - public void onFailure(Exception e) { - // we opened a connection and successfully performed a low-level handshake, so we were definitely - // talking to an Elasticsearch node, but the high-level handshake failed indicating some kind of - // mismatched configurations (e.g. cluster name) that the user should address - logger.warn(() -> "handshake to [" + transportAddress + "] failed", e); - IOUtils.closeWhileHandlingException(connection); - listener.onFailure(e); - } - - })); - - }) + ActionListener.assertOnce(listener) ); + } + + private void handshakeProbeConnection(ActionListener listener, Transport.Connection connection) { + logger.trace("[{}] opened probe connection", transportAddress); + final var probeHandshakeTimeout = handshakeConnectionProfile.getHandshakeTimeout(); + transportService.handshake(connection, probeHandshakeTimeout, ActionListener.assertOnce(new ActionListener<>() { + @Override + public void onResponse(DiscoveryNode remoteNode) { + // success means (amongst other things) that the cluster names match + logger.trace("[{}] handshake successful: {}", transportAddress, remoteNode); + IOUtils.closeWhileHandlingException(connection); + listener.onResponse(remoteNode); + } + + @Override + public void onFailure(Exception e) { + // We opened a connection and successfully performed a low-level handshake, so we were definitely talking to an + // Elasticsearch node, but the high-level handshake failed indicating some kind of mismatched configurations (e.g. + // cluster name) that the user should address. + logger.warn(() -> "handshake to [" + transportAddress + "] failed", e); + IOUtils.closeWhileHandlingException(connection); + listener.onFailure(e); + } + })); + } - } catch (Exception e) { - listener.onFailure(e); + private void openFullConnection(ActionListener listener, DiscoveryNode remoteNode) { + if (remoteNode.equals(transportService.getLocalNode())) { + throw new ConnectTransportException( + remoteNode, + String.format( + Locale.ROOT, + "successfully discovered local node %s at [%s]", + remoteNode.descriptionWithoutAttributes(), + transportAddress + ) + ); + } + + if (remoteNode.isMasterNode() == false) { + throw new ConnectTransportException( + remoteNode, + String.format( + Locale.ROOT, + """ + successfully discovered master-ineligible node %s at [%s]; to suppress this message, remove address [%s] from \ + your discovery configuration or ensure that traffic to this address is routed only to master-eligible nodes""", + remoteNode.descriptionWithoutAttributes(), + transportAddress, + transportAddress + ) + ); + } + + transportService.connectToNode(remoteNode, ActionListener.assertOnce(new ActionListener<>() { + @Override + public void onResponse(Releasable connectionReleasable) { + logger.trace("[{}] completed full connection with [{}]", transportAddress, remoteNode); + listener.onResponse(new ProbeConnectionResult(remoteNode, connectionReleasable)); + } + + @Override + public void onFailure(Exception e) { + // We opened a connection and successfully performed a handshake, so we're definitely talking to a master-eligible node + // with a matching cluster name and a good version, but the attempt to open a full connection to its publish address + // failed; a common reason is that the remote node is listening on 0.0.0.0 but has made an inappropriate choice for its + // publish address. + logger.warn( + () -> format( + """ + Successfully discovered master-eligible node [%s] at address [%s] but could not connect to it at its \ + publish address of [%s]. Each node in a cluster must be accessible at its publish address by all other \ + nodes in the cluster. See %s for more information.""", + remoteNode.descriptionWithoutAttributes(), + transportAddress, + remoteNode.getAddress(), + ReferenceDocs.NETWORK_BINDING_AND_PUBLISHING + ), + e + ); + listener.onFailure(e); + } + })); } } } diff --git a/server/src/main/java/org/elasticsearch/discovery/PeerFinder.java b/server/src/main/java/org/elasticsearch/discovery/PeerFinder.java index 83660cede004e..11f3bbdc13bbf 100644 --- a/server/src/main/java/org/elasticsearch/discovery/PeerFinder.java +++ b/server/src/main/java/org/elasticsearch/discovery/PeerFinder.java @@ -12,6 +12,7 @@ import org.apache.logging.log4j.Logger; import org.apache.lucene.util.SetOnce; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.ThreadedActionListener; import org.elasticsearch.cluster.coordination.ClusterFormationFailureHelper; import org.elasticsearch.cluster.coordination.PeersResponse; import org.elasticsearch.cluster.node.DiscoveryNode; @@ -413,86 +414,90 @@ void establishConnection() { - activatedAtMillis > verbosityIncreaseTimeout.millis(); logger.trace("{} attempting connection", this); - transportAddressConnector.connectToRemoteMasterNode(transportAddress, new ActionListener() { - @Override - public void onResponse(ProbeConnectionResult connectResult) { - assert holdsLock() == false : "PeerFinder mutex is held in error"; - final DiscoveryNode remoteNode = connectResult.getDiscoveryNode(); - assert remoteNode.isMasterNode() : remoteNode + " is not master-eligible"; - assert remoteNode.equals(getLocalNode()) == false : remoteNode + " is the local node"; - boolean retainConnection = false; - try { - synchronized (mutex) { - if (isActive() == false) { - logger.trace("Peer#establishConnection inactive: {}", Peer.this); - return; + transportAddressConnector.connectToRemoteMasterNode( + transportAddress, + // may be completed on the calling thread, and therefore under the mutex, so must always fork + new ThreadedActionListener<>(clusterCoordinationExecutor, new ActionListener<>() { + @Override + public void onResponse(ProbeConnectionResult connectResult) { + assert holdsLock() == false : "PeerFinder mutex is held in error"; + final DiscoveryNode remoteNode = connectResult.getDiscoveryNode(); + assert remoteNode.isMasterNode() : remoteNode + " is not master-eligible"; + assert remoteNode.equals(getLocalNode()) == false : remoteNode + " is the local node"; + boolean retainConnection = false; + try { + synchronized (mutex) { + if (isActive() == false) { + logger.trace("Peer#establishConnection inactive: {}", Peer.this); + return; + } + + assert probeConnectionResult.get() == null + : "connection result unexpectedly already set to " + probeConnectionResult.get(); + probeConnectionResult.set(connectResult); + + requestPeers(); } - assert probeConnectionResult.get() == null - : "connection result unexpectedly already set to " + probeConnectionResult.get(); - probeConnectionResult.set(connectResult); - - requestPeers(); - } - - onFoundPeersUpdated(); + onFoundPeersUpdated(); - retainConnection = true; - } finally { - if (retainConnection == false) { - Releasables.close(connectResult); + retainConnection = true; + } finally { + if (retainConnection == false) { + Releasables.close(connectResult); + } } } - } - @Override - public void onFailure(Exception e) { - if (verboseFailureLogging) { - - final String believedMasterBy; - synchronized (mutex) { - believedMasterBy = peersByAddress.values() - .stream() - .filter(p -> p.lastKnownMasterNode.map(DiscoveryNode::getAddress).equals(Optional.of(transportAddress))) - .findFirst() - .map(p -> " [current master according to " + p.getDiscoveryNode().descriptionWithoutAttributes() + "]") - .orElse(""); - } + @Override + public void onFailure(Exception e) { + if (verboseFailureLogging) { + + final String believedMasterBy; + synchronized (mutex) { + believedMasterBy = peersByAddress.values() + .stream() + .filter(p -> p.lastKnownMasterNode.map(DiscoveryNode::getAddress).equals(Optional.of(transportAddress))) + .findFirst() + .map(p -> " [current master according to " + p.getDiscoveryNode().descriptionWithoutAttributes() + "]") + .orElse(""); + } - if (logger.isDebugEnabled()) { - // log message at level WARN, but since DEBUG logging is enabled we include the full stack trace - logger.warn(() -> format("%s%s discovery result", Peer.this, believedMasterBy), e); - } else { - final StringBuilder messageBuilder = new StringBuilder(); - Throwable cause = e; - while (cause != null && messageBuilder.length() <= 1024) { - messageBuilder.append(": ").append(cause.getMessage()); - cause = cause.getCause(); + if (logger.isDebugEnabled()) { + // log message at level WARN, but since DEBUG logging is enabled we include the full stack trace + logger.warn(() -> format("%s%s discovery result", Peer.this, believedMasterBy), e); + } else { + final StringBuilder messageBuilder = new StringBuilder(); + Throwable cause = e; + while (cause != null && messageBuilder.length() <= 1024) { + messageBuilder.append(": ").append(cause.getMessage()); + cause = cause.getCause(); + } + final String message = messageBuilder.length() < 1024 + ? messageBuilder.toString() + : (messageBuilder.substring(0, 1023) + "..."); + logger.warn( + "{}{} discovery result{}; for summary, see logs from {}; for troubleshooting guidance, see {}", + Peer.this, + believedMasterBy, + message, + ClusterFormationFailureHelper.class.getCanonicalName(), + ReferenceDocs.DISCOVERY_TROUBLESHOOTING + ); } - final String message = messageBuilder.length() < 1024 - ? messageBuilder.toString() - : (messageBuilder.substring(0, 1023) + "..."); - logger.warn( - "{}{} discovery result{}; for summary, see logs from {}; for troubleshooting guidance, see {}", - Peer.this, - believedMasterBy, - message, - ClusterFormationFailureHelper.class.getCanonicalName(), - ReferenceDocs.DISCOVERY_TROUBLESHOOTING - ); + } else { + logger.debug(() -> format("%s discovery result", Peer.this), e); + } + synchronized (mutex) { + assert probeConnectionResult.get() == null + : "discoveryNode unexpectedly already set to " + probeConnectionResult.get(); + if (isActive()) { + peersByAddress.remove(transportAddress); + } // else this Peer has been superseded by a different instance which should be left in place } - } else { - logger.debug(() -> format("%s discovery result", Peer.this), e); - } - synchronized (mutex) { - assert probeConnectionResult.get() == null - : "discoveryNode unexpectedly already set to " + probeConnectionResult.get(); - if (isActive()) { - peersByAddress.remove(transportAddress); - } // else this Peer has been superseded by a different instance which should be left in place } - } - }); + }) + ); } private void requestPeers() { diff --git a/server/src/main/java/org/elasticsearch/index/codec/DeduplicatingFieldInfosFormat.java b/server/src/main/java/org/elasticsearch/index/codec/DeduplicatingFieldInfosFormat.java new file mode 100644 index 0000000000000..75ec265a68391 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/index/codec/DeduplicatingFieldInfosFormat.java @@ -0,0 +1,96 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.index.codec; + +import org.apache.lucene.codecs.FieldInfosFormat; +import org.apache.lucene.index.FieldInfo; +import org.apache.lucene.index.FieldInfos; +import org.apache.lucene.index.SegmentInfo; +import org.apache.lucene.store.Directory; +import org.apache.lucene.store.IOContext; +import org.elasticsearch.common.util.Maps; +import org.elasticsearch.common.util.StringLiteralDeduplicator; +import org.elasticsearch.common.util.concurrent.ConcurrentCollections; +import org.elasticsearch.index.mapper.FieldMapper; + +import java.io.IOException; +import java.util.Map; + +/** + * Wrapper around a {@link FieldInfosFormat} that will deduplicate and intern all field names, attribute-keys and -values, and in most + * cases attribute maps on read. We use this to reduce the per-field overhead for Elasticsearch instances holding a large number of + * segments. + */ +public final class DeduplicatingFieldInfosFormat extends FieldInfosFormat { + + private static final Map, Map> attributeDeduplicator = ConcurrentCollections.newConcurrentMap(); + + private static final StringLiteralDeduplicator attributesDeduplicator = new StringLiteralDeduplicator(); + + private final FieldInfosFormat delegate; + + public DeduplicatingFieldInfosFormat(FieldInfosFormat delegate) { + this.delegate = delegate; + } + + @Override + public FieldInfos read(Directory directory, SegmentInfo segmentInfo, String segmentSuffix, IOContext iocontext) throws IOException { + final FieldInfos fieldInfos = delegate.read(directory, segmentInfo, segmentSuffix, iocontext); + final FieldInfo[] deduplicated = new FieldInfo[fieldInfos.size()]; + int i = 0; + for (FieldInfo fi : fieldInfos) { + deduplicated[i++] = new FieldInfo( + FieldMapper.internFieldName(fi.getName()), + fi.number, + fi.hasVectors(), + fi.omitsNorms(), + fi.hasPayloads(), + fi.getIndexOptions(), + fi.getDocValuesType(), + fi.getDocValuesGen(), + internStringStringMap(fi.attributes()), + fi.getPointDimensionCount(), + fi.getPointIndexDimensionCount(), + fi.getPointNumBytes(), + fi.getVectorDimension(), + fi.getVectorEncoding(), + fi.getVectorSimilarityFunction(), + fi.isSoftDeletesField(), + fi.isParentField() + ); + } + return new FieldInfos(deduplicated); + } + + private static Map internStringStringMap(Map m) { + if (m.size() > 10) { + return m; + } + var res = attributeDeduplicator.get(m); + if (res == null) { + if (attributeDeduplicator.size() > 100) { + // Unexpected edge case to have more than 100 different attribute maps + // Just to be safe, don't retain more than 100 maps to prevent a potential memory leak + attributeDeduplicator.clear(); + } + final Map interned = Maps.newHashMapWithExpectedSize(m.size()); + m.forEach((key, value) -> interned.put(attributesDeduplicator.deduplicate(key), attributesDeduplicator.deduplicate(value))); + res = Map.copyOf(interned); + attributeDeduplicator.put(res, res); + } + return res; + } + + @Override + public void write(Directory directory, SegmentInfo segmentInfo, String segmentSuffix, FieldInfos infos, IOContext context) + throws IOException { + delegate.write(directory, segmentInfo, segmentSuffix, infos, context); + } + +} diff --git a/server/src/main/java/org/elasticsearch/index/codec/Elasticsearch814Codec.java b/server/src/main/java/org/elasticsearch/index/codec/Elasticsearch814Codec.java index e85e05c87b083..dd7a668605e57 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/Elasticsearch814Codec.java +++ b/server/src/main/java/org/elasticsearch/index/codec/Elasticsearch814Codec.java @@ -9,6 +9,7 @@ package org.elasticsearch.index.codec; import org.apache.lucene.codecs.DocValuesFormat; +import org.apache.lucene.codecs.FieldInfosFormat; import org.apache.lucene.codecs.FilterCodec; import org.apache.lucene.codecs.KnnVectorsFormat; import org.apache.lucene.codecs.PostingsFormat; @@ -30,6 +31,8 @@ public class Elasticsearch814Codec extends FilterCodec { private final StoredFieldsFormat storedFieldsFormat; + private final FieldInfosFormat fieldInfosFormat; + private final PostingsFormat defaultPostingsFormat; private final PostingsFormat postingsFormat = new PerFieldPostingsFormat() { @Override @@ -69,6 +72,7 @@ public Elasticsearch814Codec(Zstd814StoredFieldsFormat.Mode mode) { this.defaultPostingsFormat = new Lucene99PostingsFormat(); this.defaultDVFormat = new Lucene90DocValuesFormat(); this.defaultKnnVectorsFormat = new Lucene99HnswVectorsFormat(); + this.fieldInfosFormat = new DeduplicatingFieldInfosFormat(delegate.fieldInfosFormat()); } @Override @@ -127,4 +131,9 @@ public DocValuesFormat getDocValuesFormatForField(String field) { public KnnVectorsFormat getKnnVectorsFormatForField(String field) { return defaultKnnVectorsFormat; } + + @Override + public FieldInfosFormat fieldInfosFormat() { + return fieldInfosFormat; + } } diff --git a/server/src/main/java/org/elasticsearch/index/codec/bloomfilter/ES87BloomFilterPostingsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/bloomfilter/ES87BloomFilterPostingsFormat.java index 191fe8f75b2f0..01d874adec14d 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/bloomfilter/ES87BloomFilterPostingsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/bloomfilter/ES87BloomFilterPostingsFormat.java @@ -128,7 +128,6 @@ final class FieldsWriter extends FieldsConsumer { private final List fieldsGroups = new ArrayList<>(); private final List toCloses = new ArrayList<>(); private boolean closed; - private final int[] hashes = new int[NUM_HASH_FUNCTIONS]; FieldsWriter(SegmentWriteState state) throws IOException { this.state = state; @@ -180,23 +179,24 @@ public Iterator iterator() { } private void writeBloomFilters(Fields fields) throws IOException { - for (String field : fields) { - final Terms terms = fields.terms(field); - if (terms == null) { - continue; - } - final int bloomFilterSize = bloomFilterSize(state.segmentInfo.maxDoc()); - final int numBytes = numBytesForBloomFilter(bloomFilterSize); - try (ByteArray buffer = bigArrays.newByteArray(numBytes)) { + final int bloomFilterSize = bloomFilterSize(state.segmentInfo.maxDoc()); + final int numBytes = numBytesForBloomFilter(bloomFilterSize); + final int[] hashes = new int[NUM_HASH_FUNCTIONS]; + try (ByteArray buffer = bigArrays.newByteArray(numBytes, false)) { + long written = indexOut.getFilePointer(); + for (String field : fields) { + final Terms terms = fields.terms(field); + if (terms == null) { + continue; + } + buffer.fill(0, numBytes, (byte) 0); final TermsEnum termsEnum = terms.iterator(); while (true) { final BytesRef term = termsEnum.next(); if (term == null) { break; } - - hashTerm(term, hashes); - for (int hash : hashes) { + for (int hash : hashTerm(term, hashes)) { hash = hash % bloomFilterSize; final int pos = hash >> 3; final int mask = 1 << (hash & 7); @@ -204,9 +204,13 @@ private void writeBloomFilters(Fields fields) throws IOException { buffer.set(pos, val); } } - bloomFilters.add(new BloomFilter(field, indexOut.getFilePointer(), bloomFilterSize)); - final BytesReference bytes = BytesReference.fromByteArray(buffer, numBytes); - bytes.writeTo(new IndexOutputOutputStream(indexOut)); + bloomFilters.add(new BloomFilter(field, written, bloomFilterSize)); + if (buffer.hasArray()) { + indexOut.writeBytes(buffer.array(), 0, numBytes); + } else { + BytesReference.fromByteArray(buffer, numBytes).writeTo(new IndexOutputOutputStream(indexOut)); + } + written += numBytes; } } } @@ -636,35 +640,10 @@ private MurmurHash3() {} * @param length The length of array * @return The sum of the two 64-bit hashes that make up the hash128 */ - public static long hash64(final byte[] data, final int offset, final int length) { - // We hope that the C2 escape analysis prevents ths allocation from creating GC pressure. - long[] hash128 = { 0, 0 }; - hash128x64Internal(data, offset, length, DEFAULT_SEED, hash128); - return hash128[0]; - } - - /** - * Generates 128-bit hash from the byte array with the given offset, length and seed. - * - *

This is an implementation of the 128-bit hash function {@code MurmurHash3_x64_128} - * from Austin Appleby's original MurmurHash3 {@code c++} code in SMHasher.

- * - * @param data The input byte array - * @param offset The first element of array - * @param length The length of array - * @param seed The initial seed value - * @return The 128-bit hash (2 longs) - */ @SuppressWarnings("fallthrough") - private static long[] hash128x64Internal( - final byte[] data, - final int offset, - final int length, - final long seed, - final long[] result - ) { - long h1 = seed; - long h2 = seed; + public static long hash64(final byte[] data, final int offset, final int length) { + long h1 = MurmurHash3.DEFAULT_SEED; + long h2 = MurmurHash3.DEFAULT_SEED; final int nblocks = length >> 4; // body @@ -749,11 +728,8 @@ private static long[] hash128x64Internal( h2 = fmix64(h2); h1 += h2; - h2 += h1; - result[0] = h1; - result[1] = h2; - return result; + return h1; } /** diff --git a/server/src/main/java/org/elasticsearch/index/codec/vectors/ES815BitFlatVectorsFormat.java b/server/src/main/java/org/elasticsearch/index/codec/vectors/ES815BitFlatVectorsFormat.java index 659cc89bfe46d..de91833c99842 100644 --- a/server/src/main/java/org/elasticsearch/index/codec/vectors/ES815BitFlatVectorsFormat.java +++ b/server/src/main/java/org/elasticsearch/index/codec/vectors/ES815BitFlatVectorsFormat.java @@ -16,11 +16,11 @@ import org.apache.lucene.index.SegmentReadState; import org.apache.lucene.index.SegmentWriteState; import org.apache.lucene.index.VectorSimilarityFunction; -import org.apache.lucene.util.VectorUtil; import org.apache.lucene.util.hnsw.RandomAccessVectorValues; import org.apache.lucene.util.hnsw.RandomVectorScorer; import org.apache.lucene.util.hnsw.RandomVectorScorerSupplier; import org.apache.lucene.util.quantization.RandomAccessQuantizedByteVectorValues; +import org.elasticsearch.script.field.vectors.ESVectorUtil; import java.io.IOException; @@ -100,7 +100,7 @@ public RandomVectorScorer getRandomVectorScorer( } static float hammingScore(byte[] a, byte[] b) { - return ((a.length * Byte.SIZE) - VectorUtil.xorBitCount(a, b)) / (float) (a.length * Byte.SIZE); + return ((a.length * Byte.SIZE) - ESVectorUtil.xorBitCount(a, b)) / (float) (a.length * Byte.SIZE); } static class HammingVectorScorer extends RandomVectorScorer.AbstractRandomVectorScorer { diff --git a/server/src/main/java/org/elasticsearch/index/engine/CombinedDeletionPolicy.java b/server/src/main/java/org/elasticsearch/index/engine/CombinedDeletionPolicy.java index a69cc42163dd2..22bab1742589e 100644 --- a/server/src/main/java/org/elasticsearch/index/engine/CombinedDeletionPolicy.java +++ b/server/src/main/java/org/elasticsearch/index/engine/CombinedDeletionPolicy.java @@ -153,7 +153,7 @@ private SafeCommitInfo getNewSafeCommitInfo(IndexCommit newSafeCommit) { return currentSafeCommitInfo; } - if (currentSafeCommitInfo.localCheckpoint == newSafeCommitLocalCheckpoint) { + if (currentSafeCommitInfo.localCheckpoint() == newSafeCommitLocalCheckpoint) { // the new commit could in principle have the same LCP but a different doc count due to extra operations between its LCP and // MSN, but that is a transient state since we'll eventually advance the LCP. The doc count is only used for heuristics around // expiring excessively-lagging retention leases, so a little inaccuracy is tolerable here. @@ -164,7 +164,7 @@ private SafeCommitInfo getNewSafeCommitInfo(IndexCommit newSafeCommit) { return new SafeCommitInfo(newSafeCommitLocalCheckpoint, getDocCountOfCommit(newSafeCommit)); } catch (IOException ex) { logger.info("failed to get the total docs from the safe commit; use the total docs from the previous safe commit", ex); - return new SafeCommitInfo(newSafeCommitLocalCheckpoint, currentSafeCommitInfo.docCount); + return new SafeCommitInfo(newSafeCommitLocalCheckpoint, currentSafeCommitInfo.docCount()); } } diff --git a/server/src/main/java/org/elasticsearch/index/engine/InternalEngine.java b/server/src/main/java/org/elasticsearch/index/engine/InternalEngine.java index a991c5544a1e1..03d244cd8e4ef 100644 --- a/server/src/main/java/org/elasticsearch/index/engine/InternalEngine.java +++ b/server/src/main/java/org/elasticsearch/index/engine/InternalEngine.java @@ -344,8 +344,8 @@ private LocalCheckpointTracker createLocalCheckpointTracker( final SequenceNumbers.CommitInfo seqNoStats = SequenceNumbers.loadSeqNoInfoFromLuceneCommit( store.readLastCommittedSegmentsInfo().userData.entrySet() ); - maxSeqNo = seqNoStats.maxSeqNo; - localCheckpoint = seqNoStats.localCheckpoint; + maxSeqNo = seqNoStats.maxSeqNo(); + localCheckpoint = seqNoStats.localCheckpoint(); logger.trace("recovered maximum sequence number [{}] and local checkpoint [{}]", maxSeqNo, localCheckpoint); return localCheckpointTrackerSupplier.apply(maxSeqNo, localCheckpoint); } @@ -1688,7 +1688,7 @@ private Exception tryAcquireInFlightDocs(Operation operation, int addingDocs) { final long totalDocs = indexWriter.getPendingNumDocs() + inFlightDocCount.addAndGet(addingDocs); if (totalDocs > maxDocs) { releaseInFlightDocs(addingDocs); - return new IllegalArgumentException("Number of documents in the index can't exceed [" + maxDocs + "]"); + return new IllegalArgumentException("Number of documents in the shard cannot exceed [" + maxDocs + "]"); } else { return null; } @@ -2143,9 +2143,8 @@ private boolean shouldPeriodicallyFlush(long flushThresholdSizeInBytes, long flu final long localCheckpointOfLastCommit = Long.parseLong( lastCommittedSegmentInfos.userData.get(SequenceNumbers.LOCAL_CHECKPOINT_KEY) ); - final long translogGenerationOfLastCommit = translog.getMinGenerationForSeqNo( - localCheckpointOfLastCommit + 1 - ).translogFileGeneration; + final long translogGenerationOfLastCommit = translog.getMinGenerationForSeqNo(localCheckpointOfLastCommit + 1) + .translogFileGeneration(); if (translog.sizeInBytesByMinGen(translogGenerationOfLastCommit) < flushThresholdSizeInBytes && relativeTimeInNanosSupplier.getAsLong() - lastFlushTimestamp < flushThresholdAgeInNanos) { return false; @@ -2165,9 +2164,8 @@ private boolean shouldPeriodicallyFlush(long flushThresholdSizeInBytes, long flu * * This method is to maintain translog only, thus IndexWriter#hasUncommittedChanges condition is not considered. */ - final long translogGenerationOfNewCommit = translog.getMinGenerationForSeqNo( - localCheckpointTracker.getProcessedCheckpoint() + 1 - ).translogFileGeneration; + final long translogGenerationOfNewCommit = translog.getMinGenerationForSeqNo(localCheckpointTracker.getProcessedCheckpoint() + 1) + .translogFileGeneration(); return translogGenerationOfLastCommit < translogGenerationOfNewCommit || localCheckpointTracker.getProcessedCheckpoint() == localCheckpointTracker.getMaxSeqNo(); } diff --git a/server/src/main/java/org/elasticsearch/index/engine/ReadOnlyEngine.java b/server/src/main/java/org/elasticsearch/index/engine/ReadOnlyEngine.java index eda408a9c8fde..c9474b58ef447 100644 --- a/server/src/main/java/org/elasticsearch/index/engine/ReadOnlyEngine.java +++ b/server/src/main/java/org/elasticsearch/index/engine/ReadOnlyEngine.java @@ -244,8 +244,8 @@ protected void closeNoLock(String reason, CountDownLatch closedLatch) { private static SeqNoStats buildSeqNoStats(EngineConfig config, SegmentInfos infos) { final SequenceNumbers.CommitInfo seqNoStats = SequenceNumbers.loadSeqNoInfoFromLuceneCommit(infos.userData.entrySet()); - long maxSeqNo = seqNoStats.maxSeqNo; - long localCheckpoint = seqNoStats.localCheckpoint; + long maxSeqNo = seqNoStats.maxSeqNo(); + long localCheckpoint = seqNoStats.localCheckpoint(); return new SeqNoStats(maxSeqNo, localCheckpoint, config.getGlobalCheckpointSupplier().getAsLong()); } diff --git a/server/src/main/java/org/elasticsearch/index/engine/SafeCommitInfo.java b/server/src/main/java/org/elasticsearch/index/engine/SafeCommitInfo.java index 6858315f5b37f..5b206ecfd90dc 100644 --- a/server/src/main/java/org/elasticsearch/index/engine/SafeCommitInfo.java +++ b/server/src/main/java/org/elasticsearch/index/engine/SafeCommitInfo.java @@ -12,15 +12,6 @@ /** * Information about the safe commit, for making decisions about recoveries. */ -public class SafeCommitInfo { - - public final long localCheckpoint; - public final int docCount; - - public SafeCommitInfo(long localCheckpoint, int docCount) { - this.localCheckpoint = localCheckpoint; - this.docCount = docCount; - } - +public record SafeCommitInfo(long localCheckpoint, int docCount) { public static final SafeCommitInfo EMPTY = new SafeCommitInfo(SequenceNumbers.NO_OPS_PERFORMED, 0); } diff --git a/server/src/main/java/org/elasticsearch/index/mapper/AbstractGeometryFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/AbstractGeometryFieldMapper.java index 831244a3969ef..39f4a3a82c5c4 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/AbstractGeometryFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/AbstractGeometryFieldMapper.java @@ -11,6 +11,7 @@ import org.elasticsearch.common.CheckedBiConsumer; import org.elasticsearch.common.Explicit; import org.elasticsearch.common.geo.GeometryFormatterFactory; +import org.elasticsearch.common.logging.DeprecationLogger; import org.elasticsearch.core.CheckedConsumer; import org.elasticsearch.index.query.SearchExecutionContext; import org.elasticsearch.xcontent.DeprecationHandler; @@ -35,6 +36,12 @@ */ public abstract class AbstractGeometryFieldMapper extends FieldMapper { + // The GeoShapeFieldMapper class does not exist in server any more. + // For backwards compatibility we add the name of the class manually. + protected static final DeprecationLogger DEPRECATION_LOGGER = DeprecationLogger.getLogger( + "org.elasticsearch.index.mapper.GeoShapeFieldMapper" + ); + public static Parameter> ignoreMalformedParam( Function> initializer, boolean ignoreMalformedByDefault diff --git a/server/src/main/java/org/elasticsearch/index/mapper/AbstractShapeGeometryFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/AbstractShapeGeometryFieldMapper.java index 56f1faeb38a5b..619c6c6613d59 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/AbstractShapeGeometryFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/AbstractShapeGeometryFieldMapper.java @@ -14,7 +14,7 @@ import java.util.function.Function; /** - * Base class for {@link GeoShapeFieldMapper} + * Base class for shape field mappers */ public abstract class AbstractShapeGeometryFieldMapper extends AbstractGeometryFieldMapper { @Override diff --git a/server/src/main/java/org/elasticsearch/index/mapper/DocumentParserContext.java b/server/src/main/java/org/elasticsearch/index/mapper/DocumentParserContext.java index d8fa2919b795f..248369b249007 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/DocumentParserContext.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/DocumentParserContext.java @@ -673,7 +673,7 @@ public final MapperBuilderContext createDynamicMapperBuilderContext() { return new MapperBuilderContext( p, mappingLookup.isSourceSynthetic(), - false, + mappingLookup.isDataStreamTimestampFieldEnabled(), containsDimensions, dynamic, MergeReason.MAPPING_UPDATE, diff --git a/server/src/main/java/org/elasticsearch/index/mapper/GeoShapeFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/GeoShapeFieldMapper.java deleted file mode 100644 index 32d734a622eed..0000000000000 --- a/server/src/main/java/org/elasticsearch/index/mapper/GeoShapeFieldMapper.java +++ /dev/null @@ -1,211 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0 and the Server Side Public License, v 1; you may not use this file except - * in compliance with, at your election, the Elastic License 2.0 or the Server - * Side Public License, v 1. - */ -package org.elasticsearch.index.mapper; - -import org.apache.lucene.document.LatLonShape; -import org.apache.lucene.geo.LatLonGeometry; -import org.apache.lucene.search.Query; -import org.elasticsearch.common.Explicit; -import org.elasticsearch.common.geo.GeometryFormatterFactory; -import org.elasticsearch.common.geo.GeometryParser; -import org.elasticsearch.common.geo.Orientation; -import org.elasticsearch.common.geo.ShapeRelation; -import org.elasticsearch.common.logging.DeprecationCategory; -import org.elasticsearch.common.logging.DeprecationLogger; -import org.elasticsearch.geometry.Geometry; -import org.elasticsearch.index.IndexVersions; -import org.elasticsearch.index.query.QueryShardException; -import org.elasticsearch.index.query.SearchExecutionContext; - -import java.util.List; -import java.util.Map; -import java.util.function.Function; - -/** - * FieldMapper for indexing {@link LatLonShape}s. - *

- * Currently Shapes can only be indexed and can only be queried using - * {@link org.elasticsearch.index.query.GeoShapeQueryBuilder}, consequently - * a lot of behavior in this Mapper is disabled. - *

- * Format supported: - *

- * "field" : { - * "type" : "polygon", - * "coordinates" : [ - * [ [100.0, 0.0], [101.0, 0.0], [101.0, 1.0], [100.0, 1.0], [100.0, 0.0] ] - * ] - * } - *

- * or: - *

- * "field" : "POLYGON ((100.0 0.0, 101.0 0.0, 101.0 1.0, 100.0 1.0, 100.0 0.0)) - */ -public class GeoShapeFieldMapper extends AbstractShapeGeometryFieldMapper { - - private static final DeprecationLogger DEPRECATION_LOGGER = DeprecationLogger.getLogger(GeoShapeFieldMapper.class); - - public static final String CONTENT_TYPE = "geo_shape"; - - private static Builder builder(FieldMapper in) { - return ((GeoShapeFieldMapper) in).builder; - } - - public static class Builder extends FieldMapper.Builder { - - final Parameter indexed = Parameter.indexParam(m -> builder(m).indexed.get(), true); - - final Parameter> ignoreMalformed; - final Parameter> ignoreZValue = ignoreZValueParam(m -> builder(m).ignoreZValue.get()); - final Parameter> coerce; - final Parameter> orientation = orientationParam(m -> builder(m).orientation.get()); - - final Parameter> meta = Parameter.metaParam(); - - public Builder(String name, boolean ignoreMalformedByDefault, boolean coerceByDefault) { - super(name); - this.ignoreMalformed = ignoreMalformedParam(m -> builder(m).ignoreMalformed.get(), ignoreMalformedByDefault); - this.coerce = coerceParam(m -> builder(m).coerce.get(), coerceByDefault); - } - - public Builder ignoreZValue(boolean ignoreZValue) { - this.ignoreZValue.setValue(Explicit.explicitBoolean(ignoreZValue)); - return this; - } - - @Override - protected Parameter[] getParameters() { - return new Parameter[] { indexed, ignoreMalformed, ignoreZValue, coerce, orientation, meta }; - } - - @Override - public GeoShapeFieldMapper build(MapperBuilderContext context) { - if (multiFieldsBuilder.hasMultiFields()) { - DEPRECATION_LOGGER.warn( - DeprecationCategory.MAPPINGS, - "geo_shape_multifields", - "Adding multifields to [geo_shape] mappers has no effect and will be forbidden in future" - ); - } - GeometryParser geometryParser = new GeometryParser( - orientation.get().value().getAsBoolean(), - coerce.get().value(), - ignoreZValue.get().value() - ); - GeoShapeParser geoShapeParser = new GeoShapeParser(geometryParser, orientation.get().value()); - GeoShapeFieldType ft = new GeoShapeFieldType( - context.buildFullName(leafName()), - indexed.get(), - orientation.get().value(), - geoShapeParser, - meta.get() - ); - return new GeoShapeFieldMapper( - leafName(), - ft, - multiFieldsBuilder.build(this, context), - copyTo, - new GeoShapeIndexer(orientation.get().value(), context.buildFullName(leafName())), - geoShapeParser, - this - ); - } - } - - public static class GeoShapeFieldType extends AbstractShapeGeometryFieldType implements GeoShapeQueryable { - - public GeoShapeFieldType(String name, boolean indexed, Orientation orientation, Parser parser, Map meta) { - super(name, indexed, false, false, parser, orientation, meta); - } - - @Override - public String typeName() { - return CONTENT_TYPE; - } - - @Override - public Query geoShapeQuery(SearchExecutionContext context, String fieldName, ShapeRelation relation, LatLonGeometry... geometries) { - // CONTAINS queries are not supported by VECTOR strategy for indices created before version 7.5.0 (Lucene 8.3.0) - if (relation == ShapeRelation.CONTAINS && context.indexVersionCreated().before(IndexVersions.V_7_5_0)) { - throw new QueryShardException( - context, - ShapeRelation.CONTAINS + " query relation not supported for Field [" + fieldName + "]." - ); - } - return LatLonShape.newGeometryQuery(fieldName, relation.getLuceneRelation(), geometries); - } - - @Override - protected Function, List> getFormatter(String format) { - return GeometryFormatterFactory.getFormatter(format, Function.identity()); - } - } - - @Deprecated - public static Mapper.TypeParser PARSER = (name, node, parserContext) -> { - boolean ignoreMalformedByDefault = IGNORE_MALFORMED_SETTING.get(parserContext.getSettings()); - boolean coerceByDefault = COERCE_SETTING.get(parserContext.getSettings()); - FieldMapper.Builder builder = new Builder(name, ignoreMalformedByDefault, coerceByDefault); - builder.parse(name, parserContext, node); - return builder; - }; - - private final Builder builder; - private final GeoShapeIndexer indexer; - - public GeoShapeFieldMapper( - String simpleName, - MappedFieldType mappedFieldType, - MultiFields multiFields, - CopyTo copyTo, - GeoShapeIndexer indexer, - Parser parser, - Builder builder - ) { - super( - simpleName, - mappedFieldType, - builder.ignoreMalformed.get(), - builder.coerce.get(), - builder.ignoreZValue.get(), - builder.orientation.get(), - multiFields, - copyTo, - parser - ); - this.builder = builder; - this.indexer = indexer; - } - - @Override - public FieldMapper.Builder getMergeBuilder() { - return new Builder(leafName(), builder.ignoreMalformed.getDefaultValue().value(), builder.coerce.getDefaultValue().value()).init( - this - ); - } - - @Override - protected void index(DocumentParserContext context, Geometry geometry) { - if (geometry == null) { - return; - } - context.doc().addAll(indexer.indexShape(geometry)); - context.addToFieldNames(fieldType().name()); - } - - @Override - public GeoShapeFieldType fieldType() { - return (GeoShapeFieldType) super.fieldType(); - } - - @Override - protected String contentType() { - return CONTENT_TYPE; - } - -} diff --git a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java index 989c92e909ce2..d27c0acdb6b2e 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/vectors/DenseVectorFieldMapper.java @@ -98,7 +98,7 @@ public class DenseVectorFieldMapper extends FieldMapper { public static final String COSINE_MAGNITUDE_FIELD_SUFFIX = "._magnitude"; private static final float EPS = 1e-3f; - static boolean isNotUnitVector(float magnitude) { + public static boolean isNotUnitVector(float magnitude) { return Math.abs(magnitude - 1.0f) > EPS; } diff --git a/server/src/main/java/org/elasticsearch/index/seqno/ReplicationTracker.java b/server/src/main/java/org/elasticsearch/index/seqno/ReplicationTracker.java index 0b3b15670ef78..247c2fd70761e 100644 --- a/server/src/main/java/org/elasticsearch/index/seqno/ReplicationTracker.java +++ b/server/src/main/java/org/elasticsearch/index/seqno/ReplicationTracker.java @@ -280,7 +280,7 @@ public synchronized RetentionLeases getRetentionLeases(final boolean expireLease private long getMinimumReasonableRetainedSeqNo() { final SafeCommitInfo safeCommitInfo = safeCommitInfoSupplier.get(); - return safeCommitInfo.localCheckpoint + 1 - Math.round(Math.ceil(safeCommitInfo.docCount * fileBasedRecoveryThreshold)); + return safeCommitInfo.localCheckpoint() + 1 - Math.round(Math.ceil(safeCommitInfo.docCount() * fileBasedRecoveryThreshold)); // NB safeCommitInfo.docCount is a very low-level count of the docs in the index, and in particular if this shard contains nested // docs then safeCommitInfo.docCount counts every child doc separately from the parent doc. However every part of a nested document // has the same seqno, so we may be overestimating the cost of a file-based recovery when compared to an ops-based recovery and diff --git a/server/src/main/java/org/elasticsearch/index/seqno/SequenceNumbers.java b/server/src/main/java/org/elasticsearch/index/seqno/SequenceNumbers.java index 0cd451f6be2cf..bb4ef40d28129 100644 --- a/server/src/main/java/org/elasticsearch/index/seqno/SequenceNumbers.java +++ b/server/src/main/java/org/elasticsearch/index/seqno/SequenceNumbers.java @@ -103,15 +103,7 @@ public static long max(final long maxSeqNo, final long seqNo) { } } - public static final class CommitInfo { - public final long maxSeqNo; - public final long localCheckpoint; - - public CommitInfo(long maxSeqNo, long localCheckpoint) { - this.maxSeqNo = maxSeqNo; - this.localCheckpoint = localCheckpoint; - } - + public record CommitInfo(long maxSeqNo, long localCheckpoint) { @Override public String toString() { return "CommitInfo{maxSeqNo=" + maxSeqNo + ", localCheckpoint=" + localCheckpoint + '}'; diff --git a/server/src/main/java/org/elasticsearch/index/shard/IndexShard.java b/server/src/main/java/org/elasticsearch/index/shard/IndexShard.java index 881f4602be1c7..73cbca36a69c8 100644 --- a/server/src/main/java/org/elasticsearch/index/shard/IndexShard.java +++ b/server/src/main/java/org/elasticsearch/index/shard/IndexShard.java @@ -1856,8 +1856,8 @@ private void doLocalRecovery( return; } - assert safeCommit.get().localCheckpoint <= globalCheckpoint : safeCommit.get().localCheckpoint + " > " + globalCheckpoint; - if (safeCommit.get().localCheckpoint == globalCheckpoint) { + assert safeCommit.get().localCheckpoint() <= globalCheckpoint : safeCommit.get().localCheckpoint() + " > " + globalCheckpoint; + if (safeCommit.get().localCheckpoint() == globalCheckpoint) { logger.trace( "skip local recovery as the safe commit is up to date; safe commit {} global checkpoint {}", safeCommit.get(), @@ -1876,7 +1876,7 @@ private void doLocalRecovery( globalCheckpoint ); recoveryState.getTranslog().totalLocal(0); - recoveryStartingSeqNoListener.onResponse(safeCommit.get().localCheckpoint + 1); + recoveryStartingSeqNoListener.onResponse(safeCommit.get().localCheckpoint() + 1); return; } @@ -1915,7 +1915,7 @@ private void doLocalRecovery( // we need to find the safe commit again as we should have created a new one during the local recovery final Optional newSafeCommit = store.findSafeIndexCommit(globalCheckpoint); assert newSafeCommit.isPresent() : "no safe commit found after local recovery"; - return newSafeCommit.get().localCheckpoint + 1; + return newSafeCommit.get().localCheckpoint() + 1; } catch (Exception e) { logger.debug( () -> format( diff --git a/server/src/main/java/org/elasticsearch/index/shard/RemoveCorruptedShardDataCommand.java b/server/src/main/java/org/elasticsearch/index/shard/RemoveCorruptedShardDataCommand.java index ace891f9aead6..3783b64a0a04f 100644 --- a/server/src/main/java/org/elasticsearch/index/shard/RemoveCorruptedShardDataCommand.java +++ b/server/src/main/java/org/elasticsearch/index/shard/RemoveCorruptedShardDataCommand.java @@ -396,7 +396,7 @@ protected static void addNewHistoryCommit(Directory indexDirectory, Terminal ter // We can only safely do it because we will generate a new history uuid this shard. final SequenceNumbers.CommitInfo commitInfo = SequenceNumbers.loadSeqNoInfoFromLuceneCommit(userData.entrySet()); // Also advances the local checkpoint of the last commit to its max_seqno. - userData.put(SequenceNumbers.LOCAL_CHECKPOINT_KEY, Long.toString(commitInfo.maxSeqNo)); + userData.put(SequenceNumbers.LOCAL_CHECKPOINT_KEY, Long.toString(commitInfo.maxSeqNo())); } // commit the new history id diff --git a/server/src/main/java/org/elasticsearch/index/store/FsDirectoryFactory.java b/server/src/main/java/org/elasticsearch/index/store/FsDirectoryFactory.java index 37150ea748225..05c3554b47602 100644 --- a/server/src/main/java/org/elasticsearch/index/store/FsDirectoryFactory.java +++ b/server/src/main/java/org/elasticsearch/index/store/FsDirectoryFactory.java @@ -21,7 +21,6 @@ import org.apache.lucene.store.SimpleFSLockFactory; import org.elasticsearch.common.settings.Setting; import org.elasticsearch.common.settings.Setting.Property; -import org.elasticsearch.common.util.FeatureFlag; import org.elasticsearch.core.IOUtils; import org.elasticsearch.index.IndexModule; import org.elasticsearch.index.IndexSettings; @@ -36,8 +35,6 @@ public class FsDirectoryFactory implements IndexStorePlugin.DirectoryFactory { - private static final FeatureFlag MADV_RANDOM_FEATURE_FLAG = new FeatureFlag("madv_random"); - public static final Setting INDEX_LOCK_FACTOR_SETTING = new Setting<>("index.store.fs.fs_lock", "native", (s) -> { return switch (s) { case "native" -> NativeFSLockFactory.INSTANCE; @@ -69,20 +66,12 @@ protected Directory newFSDirectory(Path location, LockFactory lockFactory, Index // Use Lucene defaults final FSDirectory primaryDirectory = FSDirectory.open(location, lockFactory); if (primaryDirectory instanceof MMapDirectory mMapDirectory) { - Directory dir = new HybridDirectory(lockFactory, setPreload(mMapDirectory, lockFactory, preLoadExtensions)); - if (MADV_RANDOM_FEATURE_FLAG.isEnabled() == false) { - dir = disableRandomAdvice(dir); - } - return dir; + return new HybridDirectory(lockFactory, setPreload(mMapDirectory, lockFactory, preLoadExtensions)); } else { return primaryDirectory; } case MMAPFS: - Directory dir = setPreload(new MMapDirectory(location, lockFactory), lockFactory, preLoadExtensions); - if (MADV_RANDOM_FEATURE_FLAG.isEnabled() == false) { - dir = disableRandomAdvice(dir); - } - return dir; + return setPreload(new MMapDirectory(location, lockFactory), lockFactory, preLoadExtensions); case SIMPLEFS: case NIOFS: return new NIOFSDirectory(location, lockFactory); @@ -104,23 +93,6 @@ public static MMapDirectory setPreload(MMapDirectory mMapDirectory, LockFactory return mMapDirectory; } - /** - * Return a {@link FilterDirectory} around the provided {@link Directory} that forcefully disables {@link IOContext#RANDOM random - * access}. - */ - static Directory disableRandomAdvice(Directory dir) { - return new FilterDirectory(dir) { - @Override - public IndexInput openInput(String name, IOContext context) throws IOException { - if (context.randomAccess) { - context = IOContext.READ; - } - assert context.randomAccess == false; - return super.openInput(name, context); - } - }; - } - /** * Returns true iff the directory is a hybrid fs directory */ diff --git a/server/src/main/java/org/elasticsearch/index/store/Store.java b/server/src/main/java/org/elasticsearch/index/store/Store.java index 5a33084e3ea83..b9c50edf50216 100644 --- a/server/src/main/java/org/elasticsearch/index/store/Store.java +++ b/server/src/main/java/org/elasticsearch/index/store/Store.java @@ -1529,7 +1529,7 @@ public Optional findSafeIndexCommit(long globalCheck final IndexCommit safeCommit = CombinedDeletionPolicy.findSafeCommitPoint(commits, globalCheckpoint); final SequenceNumbers.CommitInfo commitInfo = SequenceNumbers.loadSeqNoInfoFromLuceneCommit(safeCommit.getUserData().entrySet()); // all operations of the safe commit must be at most the global checkpoint. - if (commitInfo.maxSeqNo <= globalCheckpoint) { + if (commitInfo.maxSeqNo() <= globalCheckpoint) { return Optional.of(commitInfo); } else { return Optional.empty(); diff --git a/server/src/main/java/org/elasticsearch/index/translog/BaseTranslogReader.java b/server/src/main/java/org/elasticsearch/index/translog/BaseTranslogReader.java index d2c862bbf35d7..3be2532e3c3aa 100644 --- a/server/src/main/java/org/elasticsearch/index/translog/BaseTranslogReader.java +++ b/server/src/main/java/org/elasticsearch/index/translog/BaseTranslogReader.java @@ -149,8 +149,8 @@ public long getLastModifiedTime() throws IOException { * Reads a single operation from the given location. */ Translog.Operation read(Translog.Location location) throws IOException { - assert location.generation == this.generation : "generation mismatch expected: " + generation + " got: " + location.generation; - ByteBuffer buffer = ByteBuffer.allocate(location.size); - return read(checksummedStream(buffer, location.translogLocation, location.size, null)); + assert location.generation() == this.generation : "generation mismatch expected: " + generation + " got: " + location.generation(); + ByteBuffer buffer = ByteBuffer.allocate(location.size()); + return read(checksummedStream(buffer, location.translogLocation(), location.size(), null)); } } diff --git a/server/src/main/java/org/elasticsearch/index/translog/Translog.java b/server/src/main/java/org/elasticsearch/index/translog/Translog.java index a079a852021bd..c02a810ed4952 100644 --- a/server/src/main/java/org/elasticsearch/index/translog/Translog.java +++ b/server/src/main/java/org/elasticsearch/index/translog/Translog.java @@ -964,20 +964,10 @@ public TranslogDeletionPolicy getDeletionPolicy() { return deletionPolicy; } - public static class Location implements Comparable { + public record Location(long generation, long translogLocation, int size) implements Comparable { public static Location EMPTY = new Location(0, 0, 0); - public final long generation; - public final long translogLocation; - public final int size; - - public Location(long generation, long translogLocation, int size) { - this.generation = generation; - this.translogLocation = translogLocation; - this.size = size; - } - @Override public String toString() { return "[generation: " + generation + ", location: " + translogLocation + ", size: " + size + "]"; @@ -985,38 +975,10 @@ public String toString() { @Override public int compareTo(Location o) { - if (generation == o.generation) { - return Long.compare(translogLocation, o.translogLocation); - } - return Long.compare(generation, o.generation); - } - - @Override - public boolean equals(Object o) { - if (this == o) { - return true; - } - if (o == null || getClass() != o.getClass()) { - return false; - } - - Location location = (Location) o; - - if (generation != location.generation) { - return false; + int result = Long.compare(generation, o.generation); + if (result == 0) { + result = Long.compare(translogLocation, o.translogLocation); } - if (translogLocation != location.translogLocation) { - return false; - } - return size == location.size; - - } - - @Override - public int hashCode() { - int result = Long.hashCode(generation); - result = 31 * result + Long.hashCode(translogLocation); - result = 31 * result + size; return result; } } @@ -1819,16 +1781,7 @@ void closeFilesIfNoPendingRetentionLocks() throws IOException { /** * References a transaction log generation */ - public static final class TranslogGeneration { - public final String translogUUID; - public final long translogFileGeneration; - - public TranslogGeneration(String translogUUID, long translogFileGeneration) { - this.translogUUID = translogUUID; - this.translogFileGeneration = translogFileGeneration; - } - - } + public record TranslogGeneration(String translogUUID, long translogFileGeneration) {} /** * Returns the current generation of this translog. This corresponds to the latest uncommitted translog generation diff --git a/server/src/main/java/org/elasticsearch/indices/recovery/RecoverySourceHandler.java b/server/src/main/java/org/elasticsearch/indices/recovery/RecoverySourceHandler.java index 538cfdabef324..df2a9d16ebd6a 100644 --- a/server/src/main/java/org/elasticsearch/indices/recovery/RecoverySourceHandler.java +++ b/server/src/main/java/org/elasticsearch/indices/recovery/RecoverySourceHandler.java @@ -1052,7 +1052,7 @@ boolean hasSameLegacySyncId(Store.MetadataSnapshot source, Store.MetadataSnapsho } SequenceNumbers.CommitInfo sourceSeqNos = SequenceNumbers.loadSeqNoInfoFromLuceneCommit(source.commitUserData().entrySet()); SequenceNumbers.CommitInfo targetSeqNos = SequenceNumbers.loadSeqNoInfoFromLuceneCommit(target.commitUserData().entrySet()); - if (sourceSeqNos.localCheckpoint != targetSeqNos.localCheckpoint || targetSeqNos.maxSeqNo != sourceSeqNos.maxSeqNo) { + if (sourceSeqNos.localCheckpoint() != targetSeqNos.localCheckpoint() || targetSeqNos.maxSeqNo() != sourceSeqNos.maxSeqNo()) { final String message = "try to recover " + request.shardId() + " with sync id but " diff --git a/server/src/main/java/org/elasticsearch/repositories/ShardGenerations.java b/server/src/main/java/org/elasticsearch/repositories/ShardGenerations.java index 4c34f2e192a26..0dcb28278a66d 100644 --- a/server/src/main/java/org/elasticsearch/repositories/ShardGenerations.java +++ b/server/src/main/java/org/elasticsearch/repositories/ShardGenerations.java @@ -8,6 +8,8 @@ package org.elasticsearch.repositories; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; import org.elasticsearch.cluster.SnapshotsInProgress; import org.elasticsearch.common.Strings; import org.elasticsearch.core.Nullable; @@ -30,6 +32,8 @@ */ public final class ShardGenerations { + private static final Logger logger = LogManager.getLogger(ShardGenerations.class); + public static final ShardGenerations EMPTY = new ShardGenerations(Collections.emptyMap()); /** @@ -88,7 +92,7 @@ public Collection indices() { } /** - * Computes the obsolete shard index generations that can be deleted once this instance was written to the repository. + * Computes the obsolete shard index generations that can be deleted once this instance is written to the repository. * Note: This method should only be used when finalizing a snapshot and we can safely assume that data has only been added but not * removed from shard paths. * @@ -109,6 +113,13 @@ public Map> obsoleteShardGenerations(Shar // Since this method assumes only additions and no removals of shards, a null updated generation means no update if (updatedGeneration != null && oldGeneration != null && oldGeneration.equals(updatedGeneration) == false) { obsoleteShardIndices.put(i, oldGeneration); + logger.debug( + "Marking snapshot generation [{}] for cleanup. The new generation is [{}]. Index [{}], shard ID [{}]", + oldGeneration, + updatedGeneration, + indexId, + i + ); } } result.put(indexId, Collections.unmodifiableMap(obsoleteShardIndices)); diff --git a/server/src/main/java/org/elasticsearch/repositories/blobstore/BlobStoreRepository.java b/server/src/main/java/org/elasticsearch/repositories/blobstore/BlobStoreRepository.java index 8f55bf16c1674..5b7a11969973d 100644 --- a/server/src/main/java/org/elasticsearch/repositories/blobstore/BlobStoreRepository.java +++ b/server/src/main/java/org/elasticsearch/repositories/blobstore/BlobStoreRepository.java @@ -3946,4 +3946,16 @@ public boolean hasAtomicOverwrites() { public int getReadBufferSizeInBytes() { return bufferSize; } + + /** + * @return extra information to be included in the exception message emitted on failure of a repository analysis. + */ + public String getAnalysisFailureExtraDetail() { + return Strings.format( + """ + Elasticsearch observed the storage system underneath this repository behaved incorrectly which indicates it is not \ + suitable for use with Elasticsearch snapshots. See [%s] for further information.""", + ReferenceDocs.SNAPSHOT_REPOSITORY_ANALYSIS + ); + } } diff --git a/server/src/main/java/org/elasticsearch/script/field/vectors/ByteBinaryDenseVector.java b/server/src/main/java/org/elasticsearch/script/field/vectors/ByteBinaryDenseVector.java index f2ff8fbccd2fb..e5c2d6a370f12 100644 --- a/server/src/main/java/org/elasticsearch/script/field/vectors/ByteBinaryDenseVector.java +++ b/server/src/main/java/org/elasticsearch/script/field/vectors/ByteBinaryDenseVector.java @@ -102,7 +102,7 @@ public double l1Norm(List queryVector) { @Override public int hamming(byte[] queryVector) { - return VectorUtil.xorBitCount(queryVector, vectorValue); + return ESVectorUtil.xorBitCount(queryVector, vectorValue); } @Override diff --git a/server/src/main/java/org/elasticsearch/script/field/vectors/ByteKnnDenseVector.java b/server/src/main/java/org/elasticsearch/script/field/vectors/ByteKnnDenseVector.java index e0ba032826aa1..0145eb3eae04b 100644 --- a/server/src/main/java/org/elasticsearch/script/field/vectors/ByteKnnDenseVector.java +++ b/server/src/main/java/org/elasticsearch/script/field/vectors/ByteKnnDenseVector.java @@ -103,7 +103,7 @@ public double l1Norm(List queryVector) { @Override public int hamming(byte[] queryVector) { - return VectorUtil.xorBitCount(queryVector, docVector); + return ESVectorUtil.xorBitCount(queryVector, docVector); } @Override diff --git a/server/src/main/java/org/elasticsearch/script/field/vectors/ESVectorUtil.java b/server/src/main/java/org/elasticsearch/script/field/vectors/ESVectorUtil.java new file mode 100644 index 0000000000000..7d9542bccf357 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/script/field/vectors/ESVectorUtil.java @@ -0,0 +1,72 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ + +package org.elasticsearch.script.field.vectors; + +import org.apache.lucene.util.BitUtil; +import org.apache.lucene.util.Constants; + +/** + * This class consists of a single utility method that provides XOR bit count computed over signed bytes. + * Remove this class when Lucene version > 9.11 is released, and replace with Lucene's VectorUtil directly. + */ +public class ESVectorUtil { + + /** + * For xorBitCount we stride over the values as either 64-bits (long) or 32-bits (int) at a time. + * On ARM Long::bitCount is not vectorized, and therefore produces less than optimal code, when + * compared to Integer::bitCount. While Long::bitCount is optimal on x64. + */ + static final boolean XOR_BIT_COUNT_STRIDE_AS_INT = Constants.OS_ARCH.equals("aarch64"); + + /** + * XOR bit count computed over signed bytes. + * + * @param a bytes containing a vector + * @param b bytes containing another vector, of the same dimension + * @return the value of the XOR bit count of the two vectors + */ + public static int xorBitCount(byte[] a, byte[] b) { + if (a.length != b.length) { + throw new IllegalArgumentException("vector dimensions differ: " + a.length + "!=" + b.length); + } + if (XOR_BIT_COUNT_STRIDE_AS_INT) { + return xorBitCountInt(a, b); + } else { + return xorBitCountLong(a, b); + } + } + + /** XOR bit count striding over 4 bytes at a time. */ + static int xorBitCountInt(byte[] a, byte[] b) { + int distance = 0, i = 0; + for (final int upperBound = a.length & -Integer.BYTES; i < upperBound; i += Integer.BYTES) { + distance += Integer.bitCount((int) BitUtil.VH_NATIVE_INT.get(a, i) ^ (int) BitUtil.VH_NATIVE_INT.get(b, i)); + } + // tail: + for (; i < a.length; i++) { + distance += Integer.bitCount((a[i] ^ b[i]) & 0xFF); + } + return distance; + } + + /** XOR bit count striding over 8 bytes at a time. */ + static int xorBitCountLong(byte[] a, byte[] b) { + int distance = 0, i = 0; + for (final int upperBound = a.length & -Long.BYTES; i < upperBound; i += Long.BYTES) { + distance += Long.bitCount((long) BitUtil.VH_NATIVE_LONG.get(a, i) ^ (long) BitUtil.VH_NATIVE_LONG.get(b, i)); + } + // tail: + for (; i < a.length; i++) { + distance += Integer.bitCount((a[i] ^ b[i]) & 0xFF); + } + return distance; + } + + private ESVectorUtil() {} +} diff --git a/server/src/main/java/org/elasticsearch/search/SearchService.java b/server/src/main/java/org/elasticsearch/search/SearchService.java index e26167c589eed..e396f0aa68cf2 100644 --- a/server/src/main/java/org/elasticsearch/search/SearchService.java +++ b/server/src/main/java/org/elasticsearch/search/SearchService.java @@ -1845,7 +1845,13 @@ public AggregationReduceContext.Builder aggReduceContextBuilder(Supplier isCanceled, - AggregatorFactories.Builder builders + AggregatorFactories.Builder builders, + IntConsumer multiBucketConsumer ) { super(bigArrays, scriptService, isCanceled, builders); + this.multiBucketConsumer = multiBucketConsumer; } - public ForPartial(BigArrays bigArrays, ScriptService scriptService, Supplier isCanceled, AggregationBuilder builder) { + public ForPartial( + BigArrays bigArrays, + ScriptService scriptService, + Supplier isCanceled, + AggregationBuilder builder, + IntConsumer multiBucketConsumer + ) { super(bigArrays, scriptService, isCanceled, builder); + this.multiBucketConsumer = multiBucketConsumer; } @Override @@ -158,7 +169,9 @@ public boolean isFinalReduce() { } @Override - protected void consumeBucketCountAndMaybeBreak(int size) {} + protected void consumeBucketCountAndMaybeBreak(int size) { + multiBucketConsumer.accept(size); + } @Override public PipelineTree pipelineTreeRoot() { @@ -167,7 +180,7 @@ public PipelineTree pipelineTreeRoot() { @Override protected AggregationReduceContext forSubAgg(AggregationBuilder sub) { - return new ForPartial(bigArrays(), scriptService(), isCanceled(), sub); + return new ForPartial(bigArrays(), scriptService(), isCanceled(), sub, multiBucketConsumer); } } diff --git a/server/src/main/java/org/elasticsearch/search/aggregations/MultiBucketConsumerService.java b/server/src/main/java/org/elasticsearch/search/aggregations/MultiBucketConsumerService.java index c876f971a7c65..a6f634ec371b1 100644 --- a/server/src/main/java/org/elasticsearch/search/aggregations/MultiBucketConsumerService.java +++ b/server/src/main/java/org/elasticsearch/search/aggregations/MultiBucketConsumerService.java @@ -134,10 +134,37 @@ public int getCount() { } } - public MultiBucketConsumer create() { + /** + * Similar to {@link MultiBucketConsumer} but it only checks the parent circuit breaker every 1024 calls. + * It provides protection for OOM during partial reductions. + */ + private static class MultiBucketConsumerPartialReduction implements IntConsumer { + private final CircuitBreaker breaker; + + // aggregations execute in a single thread so no atomic here + private int callCount = 0; + + private MultiBucketConsumerPartialReduction(CircuitBreaker breaker) { + this.breaker = breaker; + } + + @Override + public void accept(int value) { + // check parent circuit breaker every 1024 calls + if ((++callCount & 0x3FF) == 0) { + breaker.addEstimateBytesAndMaybeBreak(0, "allocated_buckets"); + } + } + } + + public IntConsumer createForFinal() { return new MultiBucketConsumer(maxBucket, breaker); } + public IntConsumer createForPartial() { + return new MultiBucketConsumerPartialReduction(breaker); + } + public int getLimit() { return maxBucket; } diff --git a/server/src/main/java/org/elasticsearch/search/vectors/KnnSearchBuilder.java b/server/src/main/java/org/elasticsearch/search/vectors/KnnSearchBuilder.java index 601c55293418d..348a65d0c4960 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/KnnSearchBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/KnnSearchBuilder.java @@ -298,6 +298,10 @@ public String getField() { return field; } + public List getFilterQueries() { + return filterQueries; + } + public KnnSearchBuilder addFilterQuery(QueryBuilder filterQuery) { Objects.requireNonNull(filterQuery); this.filterQueries.add(filterQuery); diff --git a/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java b/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java index 0f64859e877f4..f1b1c24c50788 100644 --- a/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java +++ b/server/src/main/java/org/elasticsearch/search/vectors/KnnVectorQueryBuilder.java @@ -10,6 +10,7 @@ import org.apache.lucene.search.BooleanClause; import org.apache.lucene.search.BooleanQuery; +import org.apache.lucene.search.MatchNoDocsQuery; import org.apache.lucene.search.Query; import org.apache.lucene.search.join.BitSetProducer; import org.apache.lucene.search.join.ToChildBlockJoinQuery; @@ -436,7 +437,7 @@ protected Query doToQuery(SearchExecutionContext context) throws IOException { ? Math.round(Math.min(NUM_CANDS_MULTIPLICATIVE_FACTOR * requestSize, NUM_CANDS_LIMIT)) : numCands; if (fieldType == null) { - throw new IllegalArgumentException("field [" + fieldName + "] does not exist in the mapping"); + return new MatchNoDocsQuery(); } if (fieldType instanceof DenseVectorFieldType == false) { diff --git a/server/src/main/java/org/elasticsearch/snapshots/SnapshotShardsService.java b/server/src/main/java/org/elasticsearch/snapshots/SnapshotShardsService.java index 7b3a83dfc9bb3..1529ef556037a 100644 --- a/server/src/main/java/org/elasticsearch/snapshots/SnapshotShardsService.java +++ b/server/src/main/java/org/elasticsearch/snapshots/SnapshotShardsService.java @@ -241,7 +241,7 @@ private void handleUpdatedSnapshotsInProgressEntry(String localNodeId, boolean r } if (removingLocalNode) { - pauseShardSnapshots(localNodeId, entry); + pauseShardSnapshotsForNodeRemoval(localNodeId, entry); } else { startNewShardSnapshots(localNodeId, entry); } @@ -318,7 +318,7 @@ private void startNewShardSnapshots(String localNodeId, SnapshotsInProgress.Entr threadPool.executor(ThreadPool.Names.SNAPSHOT).execute(() -> shardSnapshotTasks.forEach(Runnable::run)); } - private void pauseShardSnapshots(String localNodeId, SnapshotsInProgress.Entry entry) { + private void pauseShardSnapshotsForNodeRemoval(String localNodeId, SnapshotsInProgress.Entry entry) { final var localShardSnapshots = shardSnapshots.getOrDefault(entry.snapshot(), Map.of()); for (final Map.Entry shardEntry : entry.shards().entrySet()) { @@ -545,8 +545,8 @@ private String description() { public static String getShardStateId(IndexShard indexShard, IndexCommit snapshotIndexCommit) throws IOException { final Map userCommitData = snapshotIndexCommit.getUserData(); final SequenceNumbers.CommitInfo seqNumInfo = SequenceNumbers.loadSeqNoInfoFromLuceneCommit(userCommitData.entrySet()); - final long maxSeqNo = seqNumInfo.maxSeqNo; - if (maxSeqNo != seqNumInfo.localCheckpoint || maxSeqNo != indexShard.getLastSyncedGlobalCheckpoint()) { + final long maxSeqNo = seqNumInfo.maxSeqNo(); + if (maxSeqNo != seqNumInfo.localCheckpoint() || maxSeqNo != indexShard.getLastSyncedGlobalCheckpoint()) { return null; } return userCommitData.get(Engine.HISTORY_UUID_KEY) @@ -606,8 +606,9 @@ private void syncShardStatsOnNewMaster(List entries) } else if (stage == Stage.PAUSED) { // but we think the shard has paused - we need to make new master know that logger.debug(""" - [{}] new master thinks the shard [{}] is still running but the shard paused locally, updating status on \ - master""", snapshot.snapshot(), shardId); + new master thinks that shard [{}] snapshot [{}], with shard generation [{}], is still running, but the \ + shard snapshot is paused locally, updating status on master + """, shardId, snapshot.snapshot(), localShard.getValue().generation()); notifyUnsuccessfulSnapshotShard( snapshot.snapshot(), shardId, @@ -648,6 +649,14 @@ private void notifyUnsuccessfulSnapshotShard( shardId, new ShardSnapshotStatus(clusterService.localNode().getId(), shardState, generation, failure) ); + if (shardState == ShardState.PAUSED_FOR_NODE_REMOVAL) { + logger.debug( + "Pausing shard [{}] snapshot [{}], with shard generation [{}], because this node is marked for removal", + shardId, + snapshot, + generation + ); + } } /** Updates the shard snapshot status by sending a {@link UpdateIndexShardSnapshotStatusRequest} to the master node */ diff --git a/server/src/main/java/org/elasticsearch/snapshots/SnapshotsService.java b/server/src/main/java/org/elasticsearch/snapshots/SnapshotsService.java index cd7516a8f1232..9178050ff2a0b 100644 --- a/server/src/main/java/org/elasticsearch/snapshots/SnapshotsService.java +++ b/server/src/main/java/org/elasticsearch/snapshots/SnapshotsService.java @@ -999,39 +999,42 @@ public ClusterState execute(ClusterState currentState) { // We keep a cache of shards that failed in this map. If we fail a shardId for a given repository because of // a node leaving or shard becoming unassigned for one snapshot, we will also fail it for all subsequent enqueued // snapshots for the same repository + // // TODO: the code in this state update duplicates large chunks of the logic in #SHARD_STATE_EXECUTOR. // We should refactor it to ideally also go through #SHARD_STATE_EXECUTOR by hand-crafting shard state updates // that encapsulate nodes leaving or indices having been deleted and passing them to the executor instead. - SnapshotsInProgress updated = snapshots; + SnapshotsInProgress updatedSnapshots = snapshots; + for (final List snapshotsInRepo : snapshots.entriesByRepo()) { boolean changed = false; final List updatedEntriesForRepo = new ArrayList<>(); final Map knownFailures = new HashMap<>(); - final String repository = snapshotsInRepo.get(0).repository(); - for (SnapshotsInProgress.Entry snapshot : snapshotsInRepo) { - if (statesToUpdate.contains(snapshot.state())) { - if (snapshot.isClone()) { - if (snapshot.shardsByRepoShardId().isEmpty()) { + final String repositoryName = snapshotsInRepo.get(0).repository(); + for (SnapshotsInProgress.Entry snapshotEntry : snapshotsInRepo) { + if (statesToUpdate.contains(snapshotEntry.state())) { + if (snapshotEntry.isClone()) { + if (snapshotEntry.shardsByRepoShardId().isEmpty()) { // Currently initializing clone - if (initializingClones.contains(snapshot.snapshot())) { - updatedEntriesForRepo.add(snapshot); + if (initializingClones.contains(snapshotEntry.snapshot())) { + updatedEntriesForRepo.add(snapshotEntry); } else { - logger.debug("removing not yet start clone operation [{}]", snapshot); + logger.debug("removing not yet start clone operation [{}]", snapshotEntry); changed = true; } } else { // see if any clones may have had a shard become available for execution because of failures - if (deletes.hasExecutingDeletion(repository)) { + if (deletes.hasExecutingDeletion(repositoryName)) { // Currently executing a delete for this repo, no need to try and update any clone operations. // The logic for finishing the delete will update running clones with the latest changes. - updatedEntriesForRepo.add(snapshot); + updatedEntriesForRepo.add(snapshotEntry); continue; } ImmutableOpenMap.Builder clones = null; InFlightShardSnapshotStates inFlightShardSnapshotStates = null; for (Map.Entry failureEntry : knownFailures.entrySet()) { final RepositoryShardId repositoryShardId = failureEntry.getKey(); - final ShardSnapshotStatus existingStatus = snapshot.shardsByRepoShardId().get(repositoryShardId); + final ShardSnapshotStatus existingStatus = snapshotEntry.shardsByRepoShardId() + .get(repositoryShardId); if (ShardSnapshotStatus.UNASSIGNED_QUEUED.equals(existingStatus)) { if (inFlightShardSnapshotStates == null) { inFlightShardSnapshotStates = InFlightShardSnapshotStates.forEntries(updatedEntriesForRepo); @@ -1044,7 +1047,7 @@ public ClusterState execute(ClusterState currentState) { continue; } if (clones == null) { - clones = ImmutableOpenMap.builder(snapshot.shardsByRepoShardId()); + clones = ImmutableOpenMap.builder(snapshotEntry.shardsByRepoShardId()); } // We can use the generation from the shard failure to start the clone operation here // because #processWaitingShardsAndRemovedNodes adds generations to failure statuses that @@ -1060,50 +1063,54 @@ public ClusterState execute(ClusterState currentState) { } if (clones != null) { changed = true; - updatedEntriesForRepo.add(snapshot.withClones(clones.build())); + updatedEntriesForRepo.add(snapshotEntry.withClones(clones.build())); } else { - updatedEntriesForRepo.add(snapshot); + updatedEntriesForRepo.add(snapshotEntry); } } } else { + // Not a clone, and the snapshot is in STARTED or ABORTED state. + ImmutableOpenMap shards = processWaitingShardsAndRemovedNodes( - snapshot, + snapshotEntry, routingTable, nodes, snapshots::isNodeIdForRemoval, knownFailures ); if (shards != null) { - final SnapshotsInProgress.Entry updatedSnapshot = snapshot.withShardStates(shards); + final SnapshotsInProgress.Entry updatedSnapshot = snapshotEntry.withShardStates(shards); changed = true; if (updatedSnapshot.state().completed()) { finishedSnapshots.add(updatedSnapshot); } updatedEntriesForRepo.add(updatedSnapshot); } else { - updatedEntriesForRepo.add(snapshot); + updatedEntriesForRepo.add(snapshotEntry); } } - } else if (snapshot.repositoryStateId() == RepositoryData.UNKNOWN_REPO_GEN) { + } else if (snapshotEntry.repositoryStateId() == RepositoryData.UNKNOWN_REPO_GEN) { // BwC path, older versions could create entries with unknown repo GEN in INIT or ABORTED state that did not // yet write anything to the repository physically. This means we can simply remove these from the cluster // state without having to do any additional cleanup. changed = true; - logger.debug("[{}] was found in dangling INIT or ABORTED state", snapshot); + logger.debug("[{}] was found in dangling INIT or ABORTED state", snapshotEntry); } else { - if (snapshot.state().completed() || completed(snapshot.shardsByRepoShardId().values())) { - finishedSnapshots.add(snapshot); + // Now we're down to completed or un-modified snapshots + + if (snapshotEntry.state().completed() || completed(snapshotEntry.shardsByRepoShardId().values())) { + finishedSnapshots.add(snapshotEntry); } - updatedEntriesForRepo.add(snapshot); + updatedEntriesForRepo.add(snapshotEntry); } } if (changed) { - updated = updated.withUpdatedEntriesForRepo(repository, updatedEntriesForRepo); + updatedSnapshots = updatedSnapshots.withUpdatedEntriesForRepo(repositoryName, updatedEntriesForRepo); } } final ClusterState res = readyDeletions( - updated != snapshots - ? ClusterState.builder(currentState).putCustom(SnapshotsInProgress.TYPE, updated).build() + updatedSnapshots != snapshots + ? ClusterState.builder(currentState).putCustom(SnapshotsInProgress.TYPE, updatedSnapshots).build() : currentState ).v1(); for (SnapshotDeletionsInProgress.Entry delete : SnapshotDeletionsInProgress.get(res).getEntries()) { @@ -1151,31 +1158,39 @@ public void clusterStateProcessed(ClusterState oldState, ClusterState newState) }); } + /** + * Walks through the snapshot entries' shard snapshots and creates applies updates from looking at removed nodes or indexes and known + * failed shard snapshots on the same shard IDs. + * + * @param nodeIdRemovalPredicate identify any nodes that are marked for removal / in shutdown mode + * @param knownFailures already known failed shard snapshots, but more may be found in this method + * @return an updated map of shard statuses + */ private static ImmutableOpenMap processWaitingShardsAndRemovedNodes( - SnapshotsInProgress.Entry entry, + SnapshotsInProgress.Entry snapshotEntry, RoutingTable routingTable, DiscoveryNodes nodes, Predicate nodeIdRemovalPredicate, Map knownFailures ) { - assert entry.isClone() == false : "clones take a different path"; + assert snapshotEntry.isClone() == false : "clones take a different path"; boolean snapshotChanged = false; ImmutableOpenMap.Builder shards = ImmutableOpenMap.builder(); - for (Map.Entry shardEntry : entry.shardsByRepoShardId().entrySet()) { - ShardSnapshotStatus shardStatus = shardEntry.getValue(); - ShardId shardId = entry.shardId(shardEntry.getKey()); + for (Map.Entry shardSnapshotEntry : snapshotEntry.shardsByRepoShardId().entrySet()) { + ShardSnapshotStatus shardStatus = shardSnapshotEntry.getValue(); + ShardId shardId = snapshotEntry.shardId(shardSnapshotEntry.getKey()); if (shardStatus.equals(ShardSnapshotStatus.UNASSIGNED_QUEUED)) { // this shard snapshot is waiting for a previous snapshot to finish execution for this shard - final ShardSnapshotStatus knownFailure = knownFailures.get(shardEntry.getKey()); + final ShardSnapshotStatus knownFailure = knownFailures.get(shardSnapshotEntry.getKey()); if (knownFailure == null) { final IndexRoutingTable indexShardRoutingTable = routingTable.index(shardId.getIndex()); if (indexShardRoutingTable == null) { // shard became unassigned while queued after a delete or clone operation so we can fail as missing here - assert entry.partial(); + assert snapshotEntry.partial(); snapshotChanged = true; logger.debug("failing snapshot of shard [{}] because index got deleted", shardId); shards.put(shardId, ShardSnapshotStatus.MISSING); - knownFailures.put(shardEntry.getKey(), ShardSnapshotStatus.MISSING); + knownFailures.put(shardSnapshotEntry.getKey(), ShardSnapshotStatus.MISSING); } else { // if no failure is known for the shard we keep waiting shards.put(shardId, shardStatus); @@ -1187,6 +1202,7 @@ private static ImmutableOpenMap processWaitingShar shards.put(shardId, knownFailure); } } else if (shardStatus.state() == ShardState.WAITING || shardStatus.state() == ShardState.PAUSED_FOR_NODE_REMOVAL) { + // The shard primary wasn't assigned, or the shard snapshot was paused because the node was shutting down. IndexRoutingTable indexShardRoutingTable = routingTable.index(shardId.getIndex()); if (indexShardRoutingTable != null) { IndexShardRoutingTable shardRouting = indexShardRoutingTable.shard(shardId.id()); @@ -1208,7 +1224,10 @@ private static ImmutableOpenMap processWaitingShar } else if (shardRouting.primaryShard().started()) { // Shard that we were waiting for has started on a node, let's process it snapshotChanged = true; - logger.trace("starting shard that we were waiting for [{}] on node [{}]", shardId, shardStatus.nodeId()); + logger.debug(""" + Starting shard [{}] with shard generation [{}] that we were waiting to start on node [{}]. Previous \ + shard state [{}] + """, shardId, shardStatus.generation(), shardStatus.nodeId(), shardStatus.state()); shards.put(shardId, new ShardSnapshotStatus(primaryNodeId, shardStatus.generation())); continue; } else if (shardRouting.primaryShard().initializing() || shardRouting.primaryShard().relocating()) { @@ -1218,7 +1237,7 @@ private static ImmutableOpenMap processWaitingShar } } } - // Shard that we were waiting for went into unassigned state or disappeared - giving up + // Shard that we were waiting for went into unassigned state or disappeared (index or shard is gone) - giving up snapshotChanged = true; logger.warn("failing snapshot of shard [{}] on unassigned shard [{}]", shardId, shardStatus.nodeId()); final ShardSnapshotStatus failedState = new ShardSnapshotStatus( @@ -1228,7 +1247,7 @@ private static ImmutableOpenMap processWaitingShar "shard is unassigned" ); shards.put(shardId, failedState); - knownFailures.put(shardEntry.getKey(), failedState); + knownFailures.put(shardSnapshotEntry.getKey(), failedState); } else if (shardStatus.state().completed() == false && shardStatus.nodeId() != null) { if (nodes.nodeExists(shardStatus.nodeId())) { shards.put(shardId, shardStatus); @@ -1243,7 +1262,7 @@ private static ImmutableOpenMap processWaitingShar "node left the cluster during snapshot" ); shards.put(shardId, failedState); - knownFailures.put(shardEntry.getKey(), failedState); + knownFailures.put(shardSnapshotEntry.getKey(), failedState); } } else { shards.put(shardId, shardStatus); diff --git a/server/src/main/java/org/elasticsearch/transport/TransportService.java b/server/src/main/java/org/elasticsearch/transport/TransportService.java index c3d53855a9c75..33ea35ecffd94 100644 --- a/server/src/main/java/org/elasticsearch/transport/TransportService.java +++ b/server/src/main/java/org/elasticsearch/transport/TransportService.java @@ -17,6 +17,7 @@ import org.elasticsearch.action.ActionListenerResponseHandler; import org.elasticsearch.cluster.ClusterName; import org.elasticsearch.cluster.node.DiscoveryNode; +import org.elasticsearch.common.ReferenceDocs; import org.elasticsearch.common.Strings; import org.elasticsearch.common.component.AbstractLifecycleComponent; import org.elasticsearch.common.io.stream.RecyclerBytesStreamOutput; @@ -518,7 +519,19 @@ public ConnectionManager.ConnectionValidator connectionValidator(DiscoveryNode n handshake(newConnection, actualProfile.getHandshakeTimeout(), Predicates.always(), listener.map(resp -> { final DiscoveryNode remote = resp.discoveryNode; if (node.equals(remote) == false) { - throw new ConnectTransportException(node, "handshake failed. unexpected remote node " + remote); + throw new ConnectTransportException( + node, + Strings.format( + """ + Connecting to [%s] failed: expected to connect to [%s] but found [%s] instead. Ensure that each node has \ + its own distinct publish address, and that your network is configured so that every connection to a node's \ + publish address is routed to the correct node. See %s for more information.""", + node.getAddress(), + node.descriptionWithoutAttributes(), + remote.descriptionWithoutAttributes(), + ReferenceDocs.NETWORK_BINDING_AND_PUBLISHING + ) + ); } return null; })); diff --git a/server/src/main/resources/org/elasticsearch/common/reference-docs-links.json b/server/src/main/resources/org/elasticsearch/common/reference-docs-links.json index f3e5bd7a375f1..febcaec1ba057 100644 --- a/server/src/main/resources/org/elasticsearch/common/reference-docs-links.json +++ b/server/src/main/resources/org/elasticsearch/common/reference-docs-links.json @@ -33,5 +33,9 @@ "CONTACT_SUPPORT": "troubleshooting.html#troubleshooting-contact-support", "UNASSIGNED_SHARDS": "red-yellow-cluster-status.html", "EXECUTABLE_JNA_TMPDIR": "executable-jna-tmpdir.html", - "NETWORK_THREADING_MODEL": "modules-network.html#modules-network-threading-model" + "NETWORK_THREADING_MODEL": "modules-network.html#modules-network-threading-model", + "ALLOCATION_EXPLAIN_API": "cluster-allocation-explain.html", + "NETWORK_BINDING_AND_PUBLISHING": "modules-network.html#modules-network-binding-publishing", + "SNAPSHOT_REPOSITORY_ANALYSIS": "repo-analysis-api.html", + "S3_COMPATIBLE_REPOSITORIES": "repository-s3.html#repository-s3-compatible-services" } diff --git a/server/src/test/java/org/elasticsearch/action/ActionListenerTests.java b/server/src/test/java/org/elasticsearch/action/ActionListenerTests.java index 0543bce08a4f0..463203c1357b9 100644 --- a/server/src/test/java/org/elasticsearch/action/ActionListenerTests.java +++ b/server/src/test/java/org/elasticsearch/action/ActionListenerTests.java @@ -23,7 +23,6 @@ import java.io.IOException; import java.util.ArrayList; import java.util.List; -import java.util.concurrent.CyclicBarrier; import java.util.concurrent.ExecutionException; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; @@ -308,25 +307,13 @@ public String toString() { }); assertThat(listener.toString(), equalTo("notifyOnce[inner-listener]")); - final var threads = new Thread[between(1, 10)]; - final var startBarrier = new CyclicBarrier(threads.length); - for (int i = 0; i < threads.length; i++) { - threads[i] = new Thread(() -> { - safeAwait(startBarrier); - if (randomBoolean()) { - listener.onResponse(null); - } else { - listener.onFailure(new RuntimeException("test")); - } - }); - } - - for (Thread thread : threads) { - thread.start(); - } - for (Thread thread : threads) { - thread.join(); - } + startInParallel(between(1, 10), i -> { + if (randomBoolean()) { + listener.onResponse(null); + } else { + listener.onFailure(new RuntimeException("test")); + } + }); assertTrue(completed.get()); } diff --git a/server/src/test/java/org/elasticsearch/action/admin/cluster/allocation/ClusterAllocationExplainActionTests.java b/server/src/test/java/org/elasticsearch/action/admin/cluster/allocation/ClusterAllocationExplainActionTests.java index eb1a64ef66bbd..d78dbae509b63 100644 --- a/server/src/test/java/org/elasticsearch/action/admin/cluster/allocation/ClusterAllocationExplainActionTests.java +++ b/server/src/test/java/org/elasticsearch/action/admin/cluster/allocation/ClusterAllocationExplainActionTests.java @@ -188,7 +188,9 @@ public void testFindAnyUnassignedShardToExplain() { allOf( // no point in asserting the precise wording of the message into this test, but we care that it contains these bits: containsString("No shard was specified in the request"), - containsString("specify the target shard in the request") + containsString("specify the target shard in the request"), + containsString("https://www.elastic.co/guide/en/elasticsearch/reference"), + containsString("cluster-allocation-explain.html") ) ); } diff --git a/server/src/test/java/org/elasticsearch/action/admin/cluster/allocation/ClusterAllocationExplanationTests.java b/server/src/test/java/org/elasticsearch/action/admin/cluster/allocation/ClusterAllocationExplanationTests.java index ed81f6750aa27..463446f8b36ed 100644 --- a/server/src/test/java/org/elasticsearch/action/admin/cluster/allocation/ClusterAllocationExplanationTests.java +++ b/server/src/test/java/org/elasticsearch/action/admin/cluster/allocation/ClusterAllocationExplanationTests.java @@ -149,7 +149,9 @@ public void testRandomShardExplanationToXContent() throws Exception { allOf( // no point in asserting the precise wording of the message into this test, but we care that the note contains these bits: containsString("No shard was specified in the explain API request"), - containsString("specify the target shard in the request") + containsString("specify the target shard in the request"), + containsString("https://www.elastic.co/guide/en/elasticsearch/reference"), + containsString("cluster-allocation-explain.html") ) ); diff --git a/server/src/test/java/org/elasticsearch/action/admin/indices/close/CloseIndexRequestTests.java b/server/src/test/java/org/elasticsearch/action/admin/indices/close/CloseIndexRequestTests.java index b3caf93fbcddf..24c0f9d97800b 100644 --- a/server/src/test/java/org/elasticsearch/action/admin/indices/close/CloseIndexRequestTests.java +++ b/server/src/test/java/org/elasticsearch/action/admin/indices/close/CloseIndexRequestTests.java @@ -49,6 +49,9 @@ public void testBwcSerialization() throws Exception { in.setTransportVersion(out.getTransportVersion()); assertEquals(request.getParentTask(), TaskId.readFromStream(in)); assertEquals(request.masterNodeTimeout(), in.readTimeValue()); + if (in.getTransportVersion().onOrAfter(TransportVersions.VERSIONED_MASTER_NODE_REQUESTS)) { + assertEquals(request.masterTerm(), in.readVLong()); + } assertEquals(request.ackTimeout(), in.readTimeValue()); assertArrayEquals(request.indices(), in.readStringArray()); final IndicesOptions indicesOptions = IndicesOptions.readIndicesOptions(in); @@ -75,6 +78,9 @@ public void testBwcSerialization() throws Exception { out.setTransportVersion(version); sample.getParentTask().writeTo(out); out.writeTimeValue(sample.masterNodeTimeout()); + if (out.getTransportVersion().onOrAfter(TransportVersions.VERSIONED_MASTER_NODE_REQUESTS)) { + out.writeVLong(sample.masterTerm()); + } out.writeTimeValue(sample.ackTimeout()); out.writeStringArray(sample.indices()); sample.indicesOptions().writeIndicesOptions(out); diff --git a/server/src/test/java/org/elasticsearch/action/index/IndexRequestTests.java b/server/src/test/java/org/elasticsearch/action/index/IndexRequestTests.java index 6106dbf1fbc5a..c05cb054ce391 100644 --- a/server/src/test/java/org/elasticsearch/action/index/IndexRequestTests.java +++ b/server/src/test/java/org/elasticsearch/action/index/IndexRequestTests.java @@ -217,22 +217,6 @@ public void testSerializeDynamicTemplates() throws Exception { IndexRequest serialized = new IndexRequest(in); assertThat(serialized.getDynamicTemplates(), anEmptyMap()); } - // old version - { - Map dynamicTemplates = IntStream.range(0, randomIntBetween(1, 10)) - .boxed() - .collect(Collectors.toMap(n -> "field-" + n, n -> "name-" + n)); - indexRequest.setDynamicTemplates(dynamicTemplates); - TransportVersion ver = TransportVersionUtils.randomVersionBetween( - random(), - TransportVersions.V_7_0_0, - TransportVersionUtils.getPreviousVersion(TransportVersions.V_7_13_0) - ); - BytesStreamOutput out = new BytesStreamOutput(); - out.setTransportVersion(ver); - IllegalArgumentException error = expectThrows(IllegalArgumentException.class, () -> indexRequest.writeTo(out)); - assertThat(error.getMessage(), equalTo("[dynamic_templates] parameter requires all nodes on 7.13.0 or later")); - } // new version { Map dynamicTemplates = IntStream.range(0, randomIntBetween(0, 10)) diff --git a/server/src/test/java/org/elasticsearch/action/search/QueryPhaseResultConsumerTests.java b/server/src/test/java/org/elasticsearch/action/search/QueryPhaseResultConsumerTests.java index db32213ff97b7..ab7d9f180eae4 100644 --- a/server/src/test/java/org/elasticsearch/action/search/QueryPhaseResultConsumerTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/QueryPhaseResultConsumerTests.java @@ -53,7 +53,13 @@ public void setup() { searchPhaseController = new SearchPhaseController((t, s) -> new AggregationReduceContext.Builder() { @Override public AggregationReduceContext forPartialReduction() { - return new AggregationReduceContext.ForPartial(BigArrays.NON_RECYCLING_INSTANCE, null, t, mock(AggregationBuilder.class)); + return new AggregationReduceContext.ForPartial( + BigArrays.NON_RECYCLING_INSTANCE, + null, + t, + mock(AggregationBuilder.class), + b -> {} + ); } public AggregationReduceContext forFinalReduction() { diff --git a/server/src/test/java/org/elasticsearch/action/search/SearchPhaseControllerTests.java b/server/src/test/java/org/elasticsearch/action/search/SearchPhaseControllerTests.java index 43bca4bae2f3f..118a7055cd782 100644 --- a/server/src/test/java/org/elasticsearch/action/search/SearchPhaseControllerTests.java +++ b/server/src/test/java/org/elasticsearch/action/search/SearchPhaseControllerTests.java @@ -122,7 +122,7 @@ public void setup() { @Override public AggregationReduceContext forPartialReduction() { reductions.add(false); - return new AggregationReduceContext.ForPartial(BigArrays.NON_RECYCLING_INSTANCE, null, t, agg); + return new AggregationReduceContext.ForPartial(BigArrays.NON_RECYCLING_INSTANCE, null, t, agg, b -> {}); } public AggregationReduceContext forFinalReduction() { diff --git a/server/src/test/java/org/elasticsearch/action/support/replication/TransportWriteActionTests.java b/server/src/test/java/org/elasticsearch/action/support/replication/TransportWriteActionTests.java index 5530ec61fea33..340ca87968db0 100644 --- a/server/src/test/java/org/elasticsearch/action/support/replication/TransportWriteActionTests.java +++ b/server/src/test/java/org/elasticsearch/action/support/replication/TransportWriteActionTests.java @@ -92,7 +92,6 @@ public class TransportWriteActionTests extends ESTestCase { private ClusterService clusterService; private IndexShard indexShard; - private Translog.Location location; @BeforeClass public static void beforeClass() { @@ -102,7 +101,6 @@ public static void beforeClass() { @Before public void initCommonMocks() { indexShard = mock(IndexShard.class); - location = mock(Translog.Location.class); clusterService = createClusterService(threadPool); when(indexShard.refresh(any())).thenReturn(new Engine.RefreshResult(true, randomNonNegativeLong(), 1)); ReplicationGroup replicationGroup = mock(ReplicationGroup.class); @@ -483,7 +481,14 @@ protected void dispatchedShardOperationOnPrimary( if (withDocumentFailureOnPrimary) { throw new RuntimeException("simulated"); } else { - return new WritePrimaryResult<>(request, new TestResponse(), location, primary, logger, postWriteRefresh); + return new WritePrimaryResult<>( + request, + new TestResponse(), + Translog.Location.EMPTY, + primary, + logger, + postWriteRefresh + ); } }); } @@ -495,7 +500,7 @@ protected void dispatchedShardOperationOnReplica(TestRequest request, IndexShard if (withDocumentFailureOnReplica) { replicaResult = new WriteReplicaResult<>(request, null, new RuntimeException("simulated"), replica, logger); } else { - replicaResult = new WriteReplicaResult<>(request, location, null, replica, logger); + replicaResult = new WriteReplicaResult<>(request, Translog.Location.EMPTY, null, replica, logger); } return replicaResult; }); diff --git a/server/src/test/java/org/elasticsearch/common/util/BitArrayTests.java b/server/src/test/java/org/elasticsearch/common/util/BitArrayTests.java index f81a4bd2f4a18..e3f2522de4813 100644 --- a/server/src/test/java/org/elasticsearch/common/util/BitArrayTests.java +++ b/server/src/test/java/org/elasticsearch/common/util/BitArrayTests.java @@ -51,6 +51,27 @@ public void testRandom() { } } + public void testRandomSetValue() { + try (BitArray bitArray = new BitArray(1, BigArrays.NON_RECYCLING_INSTANCE)) { + int numBits = randomIntBetween(1000, 10000); + for (int step = 0; step < 3; step++) { + boolean[] bits = new boolean[numBits]; + List slots = new ArrayList<>(); + for (int i = 0; i < numBits; i++) { + bits[i] = randomBoolean(); + slots.add(i); + } + Collections.shuffle(slots, random()); + for (int i : slots) { + bitArray.set(i, bits[i]); + } + for (int i = 0; i < numBits; i++) { + assertEquals(bitArray.get(i), bits[i]); + } + } + } + } + public void testVeryLarge() { assumeThat(Runtime.getRuntime().maxMemory(), greaterThanOrEqualTo(ByteSizeUnit.MB.toBytes(512))); try (BitArray bitArray = new BitArray(1, BigArrays.NON_RECYCLING_INSTANCE)) { @@ -183,6 +204,78 @@ public void testGetAndSet() { } } + public void testFillTrueRandom() { + try (BitArray bitArray = new BitArray(1, BigArrays.NON_RECYCLING_INSTANCE)) { + int from = randomIntBetween(0, 1000); + int to = randomIntBetween(from, 1000); + + bitArray.fill(0, 1000, false); + bitArray.fill(from, to, true); + + for (int i = 0; i < 1000; i++) { + if (i < from || i >= to) { + assertFalse(bitArray.get(i)); + } else { + assertTrue(bitArray.get(i)); + } + } + } + } + + public void testFillFalseRandom() { + try (BitArray bitArray = new BitArray(1, BigArrays.NON_RECYCLING_INSTANCE)) { + int from = randomIntBetween(0, 1000); + int to = randomIntBetween(from, 1000); + + bitArray.fill(0, 1000, true); + bitArray.fill(from, to, false); + + for (int i = 0; i < 1000; i++) { + if (i < from || i >= to) { + assertTrue(bitArray.get(i)); + } else { + assertFalse(bitArray.get(i)); + } + } + } + } + + public void testFillTrueSingleWord() { + try (BitArray bitArray = new BitArray(1, BigArrays.NON_RECYCLING_INSTANCE)) { + int from = 8; + int to = 56; + + bitArray.fill(0, 64, false); + bitArray.fill(from, to, true); + + for (int i = 0; i < 64; i++) { + if (i < from || i >= to) { + assertFalse(bitArray.get(i)); + } else { + assertTrue(bitArray.get(i)); + } + } + } + } + + public void testFillFalseSingleWord() { + try (BitArray bitArray = new BitArray(1, BigArrays.NON_RECYCLING_INSTANCE)) { + int from = 8; + int to = 56; + + bitArray.fill(0, 64, true); + bitArray.fill(from, to, false); + + for (int i = 0; i < 64; i++) { + if (i < from || i >= to) { + assertTrue(bitArray.get(i)); + } else { + assertFalse(bitArray.get(i)); + } + } + } + } + public void testSerialize() throws Exception { int initial = randomIntBetween(1, 100_000); BitArray bits1 = new BitArray(initial, BigArrays.NON_RECYCLING_INSTANCE); diff --git a/server/src/test/java/org/elasticsearch/common/util/concurrent/AsyncIOProcessorTests.java b/server/src/test/java/org/elasticsearch/common/util/concurrent/AsyncIOProcessorTests.java index 65bcb473f7d22..0392a3f5ab4e1 100644 --- a/server/src/test/java/org/elasticsearch/common/util/concurrent/AsyncIOProcessorTests.java +++ b/server/src/test/java/org/elasticsearch/common/util/concurrent/AsyncIOProcessorTests.java @@ -54,32 +54,19 @@ protected void write(List>> candidates) throws }; Semaphore semaphore = new Semaphore(Integer.MAX_VALUE); final int count = randomIntBetween(1000, 20000); - Thread[] thread = new Thread[randomIntBetween(3, 10)]; - CountDownLatch latch = new CountDownLatch(thread.length); - for (int i = 0; i < thread.length; i++) { - thread[i] = new Thread() { - @Override - public void run() { - try { - latch.countDown(); - latch.await(); - for (int i = 0; i < count; i++) { - semaphore.acquire(); - processor.put(new Object(), (ex) -> semaphore.release()); - } - } catch (Exception ex) { - throw new RuntimeException(ex); - } + final int threads = randomIntBetween(3, 10); + startInParallel(threads, t -> { + for (int i = 0; i < count; i++) { + try { + semaphore.acquire(); + processor.put(new Object(), (ex) -> semaphore.release()); + } catch (Exception ex) { + throw new RuntimeException(ex); } - }; - thread[i].start(); - } - - for (int i = 0; i < thread.length; i++) { - thread[i].join(); - } + } + }); safeAcquire(10, semaphore); - assertEquals(count * thread.length, received.get()); + assertEquals(count * threads, received.get()); } public void testRandomFail() throws InterruptedException { @@ -102,37 +89,24 @@ protected void write(List>> candidates) throws }; Semaphore semaphore = new Semaphore(Integer.MAX_VALUE); final int count = randomIntBetween(1000, 20000); - Thread[] thread = new Thread[randomIntBetween(3, 10)]; - CountDownLatch latch = new CountDownLatch(thread.length); - for (int i = 0; i < thread.length; i++) { - thread[i] = new Thread() { - @Override - public void run() { - try { - latch.countDown(); - latch.await(); - for (int i = 0; i < count; i++) { - semaphore.acquire(); - processor.put(new Object(), (ex) -> { - if (ex != null) { - actualFailed.incrementAndGet(); - } - semaphore.release(); - }); + final int threads = randomIntBetween(3, 10); + startInParallel(threads, t -> { + try { + for (int i = 0; i < count; i++) { + semaphore.acquire(); + processor.put(new Object(), (ex) -> { + if (ex != null) { + actualFailed.incrementAndGet(); } - } catch (Exception ex) { - throw new RuntimeException(ex); - } + semaphore.release(); + }); } - }; - thread[i].start(); - } - - for (int i = 0; i < thread.length; i++) { - thread[i].join(); - } + } catch (Exception ex) { + throw new RuntimeException(ex); + } + }); safeAcquire(Integer.MAX_VALUE, semaphore); - assertEquals(count * thread.length, received.get()); + assertEquals(count * threads, received.get()); assertEquals(actualFailed.get(), failed.get()); } @@ -226,7 +200,7 @@ public void run() { threads.forEach(t -> assertFalse(t.isAlive())); } - public void testSlowConsumer() { + public void testSlowConsumer() throws InterruptedException { AtomicInteger received = new AtomicInteger(0); AtomicInteger notified = new AtomicInteger(0); @@ -240,39 +214,23 @@ protected void write(List>> candidates) throws int threadCount = randomIntBetween(2, 10); CyclicBarrier barrier = new CyclicBarrier(threadCount); Semaphore serializePutSemaphore = new Semaphore(1); - List threads = IntStream.range(0, threadCount).mapToObj(i -> new Thread(getTestName() + "_" + i) { - { - setDaemon(true); - } - - @Override - public void run() { - try { - assertTrue(serializePutSemaphore.tryAcquire(10, TimeUnit.SECONDS)); - } catch (InterruptedException e) { - throw new RuntimeException(e); - } - processor.put(new Object(), (e) -> { - serializePutSemaphore.release(); - try { - barrier.await(10, TimeUnit.SECONDS); - } catch (InterruptedException | BrokenBarrierException | TimeoutException ex) { - throw new RuntimeException(ex); - } - notified.incrementAndGet(); - }); - } - }).toList(); - threads.forEach(Thread::start); - threads.forEach(t -> { + runInParallel(threadCount, t -> { try { - t.join(20000); + assertTrue(serializePutSemaphore.tryAcquire(10, TimeUnit.SECONDS)); } catch (InterruptedException e) { throw new RuntimeException(e); } + processor.put(new Object(), (e) -> { + serializePutSemaphore.release(); + try { + barrier.await(10, TimeUnit.SECONDS); + } catch (InterruptedException | BrokenBarrierException | TimeoutException ex) { + throw new RuntimeException(ex); + } + notified.incrementAndGet(); + }); }); assertEquals(threadCount, notified.get()); assertEquals(threadCount, received.get()); - threads.forEach(t -> assertFalse(t.isAlive())); } } diff --git a/server/src/test/java/org/elasticsearch/discovery/HandshakingTransportAddressConnectorTests.java b/server/src/test/java/org/elasticsearch/discovery/HandshakingTransportAddressConnectorTests.java index 8ca96aff9c3e5..5c6afc1e805ce 100644 --- a/server/src/test/java/org/elasticsearch/discovery/HandshakingTransportAddressConnectorTests.java +++ b/server/src/test/java/org/elasticsearch/discovery/HandshakingTransportAddressConnectorTests.java @@ -18,6 +18,8 @@ import org.elasticsearch.cluster.ClusterName; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.node.DiscoveryNodeUtils; +import org.elasticsearch.common.ReferenceDocs; +import org.elasticsearch.common.Strings; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.transport.TransportAddress; import org.elasticsearch.core.Nullable; @@ -159,13 +161,16 @@ public void testLogsFullConnectionFailureAfterSuccessfulHandshake() throws Excep "message", HandshakingTransportAddressConnector.class.getCanonicalName(), Level.WARN, - "completed handshake with [" - + remoteNode.descriptionWithoutAttributes() - + "] at [" - + discoveryAddress - + "] but followup connection to [" - + remoteNodeAddress - + "] failed" + Strings.format( + """ + Successfully discovered master-eligible node [%s] at address [%s] but could not connect to it at its publish \ + address of [%s]. Each node in a cluster must be accessible at its publish address by all other nodes in the \ + cluster. See %s for more information.""", + remoteNode.descriptionWithoutAttributes(), + discoveryAddress, + remoteNodeAddress, + ReferenceDocs.NETWORK_BINDING_AND_PUBLISHING + ) ) ); diff --git a/server/src/test/java/org/elasticsearch/index/engine/FlushListenersTests.java b/server/src/test/java/org/elasticsearch/index/engine/FlushListenersTests.java index 9c345eb923ab4..bff978f8e79d8 100644 --- a/server/src/test/java/org/elasticsearch/index/engine/FlushListenersTests.java +++ b/server/src/test/java/org/elasticsearch/index/engine/FlushListenersTests.java @@ -29,8 +29,8 @@ public void testFlushListenerCompletedImmediatelyIfFlushAlreadyOccurred() { ); flushListeners.afterFlush(generation, lastWriteLocation); Translog.Location waitLocation = new Translog.Location( - lastWriteLocation.generation - randomLongBetween(0, 2), - lastWriteLocation.generation - randomLongBetween(10, 90), + lastWriteLocation.generation() - randomLongBetween(0, 2), + lastWriteLocation.generation() - randomLongBetween(10, 90), 2 ); PlainActionFuture future = new PlainActionFuture<>(); @@ -48,8 +48,8 @@ public void testFlushListenerCompletedAfterLocationFlushed() { Integer.MAX_VALUE ); Translog.Location waitLocation = new Translog.Location( - lastWriteLocation.generation - randomLongBetween(0, 2), - lastWriteLocation.generation - randomLongBetween(10, 90), + lastWriteLocation.generation() - randomLongBetween(0, 2), + lastWriteLocation.generation() - randomLongBetween(10, 90), 2 ); PlainActionFuture future = new PlainActionFuture<>(); @@ -61,13 +61,13 @@ public void testFlushListenerCompletedAfterLocationFlushed() { long generation2 = generation + 1; Translog.Location secondLastWriteLocation = new Translog.Location( - lastWriteLocation.generation, - lastWriteLocation.translogLocation + 10, + lastWriteLocation.generation(), + lastWriteLocation.translogLocation() + 10, Integer.MAX_VALUE ); Translog.Location waitLocation2 = new Translog.Location( - lastWriteLocation.generation, - lastWriteLocation.translogLocation + 4, + lastWriteLocation.generation(), + lastWriteLocation.translogLocation() + 4, 2 ); diff --git a/server/src/test/java/org/elasticsearch/index/engine/InternalEngineTests.java b/server/src/test/java/org/elasticsearch/index/engine/InternalEngineTests.java index a89ac5bc5b74e..c668cfbb502a2 100644 --- a/server/src/test/java/org/elasticsearch/index/engine/InternalEngineTests.java +++ b/server/src/test/java/org/elasticsearch/index/engine/InternalEngineTests.java @@ -1249,7 +1249,7 @@ public void testSyncTranslogConcurrently() throws Exception { SequenceNumbers.CommitInfo commitInfo = SequenceNumbers.loadSeqNoInfoFromLuceneCommit( safeCommit.getIndexCommit().getUserData().entrySet() ); - assertThat(commitInfo.localCheckpoint, equalTo(engine.getProcessedLocalCheckpoint())); + assertThat(commitInfo.localCheckpoint(), equalTo(engine.getProcessedLocalCheckpoint())); } }; final Thread[] threads = new Thread[randomIntBetween(2, 4)]; @@ -3414,7 +3414,7 @@ protected void commitIndexWriter(IndexWriter writer, Translog translog) throws I final long localCheckpoint = Long.parseLong( engine.getLastCommittedSegmentInfos().userData.get(SequenceNumbers.LOCAL_CHECKPOINT_KEY) ); - final long committedGen = engine.getTranslog().getMinGenerationForSeqNo(localCheckpoint + 1).translogFileGeneration; + final long committedGen = engine.getTranslog().getMinGenerationForSeqNo(localCheckpoint + 1).translogFileGeneration(); for (int gen = 1; gen < committedGen; gen++) { final Path genFile = translogPath.resolve(Translog.getFilename(gen)); assertFalse(genFile + " wasn't cleaned up", Files.exists(genFile)); @@ -3601,7 +3601,7 @@ public void testRecoverFromForeignTranslog() throws IOException { seqNo -> {} ); translog.add(TranslogOperationsUtils.indexOp("SomeBogusId", 0, primaryTerm.get())); - assertEquals(generation.translogFileGeneration, translog.currentFileGeneration()); + assertEquals(generation.translogFileGeneration(), translog.currentFileGeneration()); translog.close(); EngineConfig config = engine.config(); @@ -5232,7 +5232,7 @@ public void testMinGenerationForSeqNo() throws IOException, BrokenBarrierExcepti * This sequence number landed in the last generation, but the lower and upper bounds for an earlier generation straddle * this sequence number. */ - assertThat(translog.getMinGenerationForSeqNo(3 * i + 1).translogFileGeneration, equalTo(i + generation)); + assertThat(translog.getMinGenerationForSeqNo(3 * i + 1).translogFileGeneration(), equalTo(i + generation)); } int i = 0; @@ -5855,7 +5855,7 @@ public void testShouldPeriodicallyFlushOnSize() throws Exception { final Translog translog = engine.getTranslog(); final IntSupplier uncommittedTranslogOperationsSinceLastCommit = () -> { long localCheckpoint = Long.parseLong(engine.getLastCommittedSegmentInfos().userData.get(SequenceNumbers.LOCAL_CHECKPOINT_KEY)); - return translog.totalOperationsByMinGen(translog.getMinGenerationForSeqNo(localCheckpoint + 1).translogFileGeneration); + return translog.totalOperationsByMinGen(translog.getMinGenerationForSeqNo(localCheckpoint + 1).translogFileGeneration()); }; final long extraTranslogSizeInNewEngine = engine.getTranslog().stats().getUncommittedSizeInBytes() - Translog.DEFAULT_HEADER_SIZE_IN_BYTES; @@ -7417,7 +7417,7 @@ public void testMaxDocsOnPrimary() throws Exception { assertNotNull(result.getFailure()); assertThat( result.getFailure().getMessage(), - containsString("Number of documents in the index can't exceed [" + maxDocs + "]") + containsString("Number of documents in the shard cannot exceed [" + maxDocs + "]") ); assertThat(result.getSeqNo(), equalTo(UNASSIGNED_SEQ_NO)); assertThat(engine.getLocalCheckpointTracker().getMaxSeqNo(), equalTo(maxSeqNo)); diff --git a/server/src/test/java/org/elasticsearch/index/mapper/DocumentParserContextTests.java b/server/src/test/java/org/elasticsearch/index/mapper/DocumentParserContextTests.java index ab1c93cd98277..2826243e4c866 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/DocumentParserContextTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/DocumentParserContextTests.java @@ -11,7 +11,9 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.index.IndexVersion; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xcontent.json.JsonXContent; import java.io.IOException; @@ -81,4 +83,54 @@ public void testSwitchParser() throws IOException { assertEquals(parser, newContext.parser()); assertEquals("1", newContext.indexSettings().getSettings().get("index.mapping.total_fields.limit")); } + + public void testCreateDynamicMapperBuilderContextFromEmptyContext() throws IOException { + var resultFromEmptyParserContext = context.createDynamicMapperBuilderContext(); + + assertEquals("hey", resultFromEmptyParserContext.buildFullName("hey")); + assertFalse(resultFromEmptyParserContext.isSourceSynthetic()); + assertFalse(resultFromEmptyParserContext.isDataStream()); + assertFalse(resultFromEmptyParserContext.parentObjectContainsDimensions()); + assertEquals(ObjectMapper.Defaults.DYNAMIC, resultFromEmptyParserContext.getDynamic()); + assertEquals(MapperService.MergeReason.MAPPING_UPDATE, resultFromEmptyParserContext.getMergeReason()); + assertFalse(resultFromEmptyParserContext.isInNestedContext()); + } + + public void testCreateDynamicMapperBuilderContext() throws IOException { + var mapping = XContentBuilder.builder(XContentType.JSON.xContent()) + .startObject() + .startObject("_doc") + .startObject("_source") + .field("mode", "synthetic") + .endObject() + .startObject(DataStreamTimestampFieldMapper.NAME) + .field("enabled", "true") + .endObject() + .startObject("properties") + .startObject(DataStreamTimestampFieldMapper.DEFAULT_PATH) + .field("type", "date") + .endObject() + .startObject("foo") + .field("type", "passthrough") + .field("time_series_dimension", "true") + .field("priority", "100") + .endObject() + .endObject() + .endObject() + .endObject(); + var documentMapper = new MapperServiceTestCase() { + }.createDocumentMapper(mapping); + var parserContext = new TestDocumentParserContext(documentMapper.mappers(), null); + parserContext.path().add("foo"); + + var resultFromParserContext = parserContext.createDynamicMapperBuilderContext(); + + assertEquals("foo.hey", resultFromParserContext.buildFullName("hey")); + assertTrue(resultFromParserContext.isSourceSynthetic()); + assertTrue(resultFromParserContext.isDataStream()); + assertTrue(resultFromParserContext.parentObjectContainsDimensions()); + assertEquals(ObjectMapper.Defaults.DYNAMIC, resultFromParserContext.getDynamic()); + assertEquals(MapperService.MergeReason.MAPPING_UPDATE, resultFromParserContext.getMergeReason()); + assertFalse(resultFromParserContext.isInNestedContext()); + } } diff --git a/server/src/test/java/org/elasticsearch/index/mapper/GeoShapeFieldMapperTests.java b/server/src/test/java/org/elasticsearch/index/mapper/GeoShapeFieldMapperTests.java deleted file mode 100644 index 92da99bc059a2..0000000000000 --- a/server/src/test/java/org/elasticsearch/index/mapper/GeoShapeFieldMapperTests.java +++ /dev/null @@ -1,229 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0 and the Server Side Public License, v 1; you may not use this file except - * in compliance with, at your election, the Elastic License 2.0 or the Server - * Side Public License, v 1. - */ -package org.elasticsearch.index.mapper; - -import org.elasticsearch.common.Strings; -import org.elasticsearch.common.geo.Orientation; -import org.elasticsearch.plugins.Plugin; -import org.elasticsearch.test.TestGeoShapeFieldMapperPlugin; -import org.elasticsearch.xcontent.ToXContent; -import org.elasticsearch.xcontent.XContentBuilder; -import org.junit.AssumptionViolatedException; - -import java.io.IOException; -import java.util.Collection; -import java.util.Collections; -import java.util.List; - -import static org.hamcrest.Matchers.containsString; -import static org.hamcrest.Matchers.equalTo; -import static org.hamcrest.Matchers.hasSize; -import static org.hamcrest.Matchers.instanceOf; - -public class GeoShapeFieldMapperTests extends MapperTestCase { - - @Override - protected void registerParameters(ParameterChecker checker) throws IOException { - checker.registerUpdateCheck(b -> b.field("orientation", "right"), m -> { - GeoShapeFieldMapper gsfm = (GeoShapeFieldMapper) m; - assertEquals(Orientation.RIGHT, gsfm.orientation()); - }); - checker.registerUpdateCheck(b -> b.field("ignore_z_value", false), m -> { - GeoShapeFieldMapper gpfm = (GeoShapeFieldMapper) m; - assertFalse(gpfm.ignoreZValue()); - }); - checker.registerUpdateCheck(b -> b.field("coerce", true), m -> { - GeoShapeFieldMapper gpfm = (GeoShapeFieldMapper) m; - assertTrue(gpfm.coerce.value()); - }); - } - - @Override - protected Collection getPlugins() { - return List.of(new TestGeoShapeFieldMapperPlugin()); - } - - @Override - protected void minimalMapping(XContentBuilder b) throws IOException { - b.field("type", "geo_shape"); - } - - @Override - protected boolean supportsStoredFields() { - return false; - } - - @Override - protected Object getSampleValueForDocument() { - return "POINT (14.0 15.0)"; - } - - public void testDefaultConfiguration() throws IOException { - DocumentMapper mapper = createDocumentMapper(fieldMapping(this::minimalMapping)); - Mapper fieldMapper = mapper.mappers().getMapper("field"); - assertThat(fieldMapper, instanceOf(GeoShapeFieldMapper.class)); - GeoShapeFieldMapper geoShapeFieldMapper = (GeoShapeFieldMapper) fieldMapper; - assertThat(geoShapeFieldMapper.fieldType().orientation(), equalTo(Orientation.RIGHT)); - assertThat(geoShapeFieldMapper.fieldType().hasDocValues(), equalTo(false)); - } - - /** - * Test that orientation parameter correctly parses - */ - public void testOrientationParsing() throws IOException { - DocumentMapper mapper = createDocumentMapper(fieldMapping(b -> b.field("type", "geo_shape").field("orientation", "left"))); - Mapper fieldMapper = mapper.mappers().getMapper("field"); - assertThat(fieldMapper, instanceOf(GeoShapeFieldMapper.class)); - - Orientation orientation = ((GeoShapeFieldMapper) fieldMapper).fieldType().orientation(); - assertThat(orientation, equalTo(Orientation.CLOCKWISE)); - assertThat(orientation, equalTo(Orientation.LEFT)); - assertThat(orientation, equalTo(Orientation.CW)); - - // explicit right orientation test - mapper = createDocumentMapper(fieldMapping(b -> b.field("type", "geo_shape").field("orientation", "right"))); - fieldMapper = mapper.mappers().getMapper("field"); - assertThat(fieldMapper, instanceOf(GeoShapeFieldMapper.class)); - - orientation = ((GeoShapeFieldMapper) fieldMapper).fieldType().orientation(); - assertThat(orientation, equalTo(Orientation.COUNTER_CLOCKWISE)); - assertThat(orientation, equalTo(Orientation.RIGHT)); - assertThat(orientation, equalTo(Orientation.CCW)); - } - - /** - * Test that coerce parameter correctly parses - */ - public void testCoerceParsing() throws IOException { - DocumentMapper mapper = createDocumentMapper(fieldMapping(b -> b.field("type", "geo_shape").field("coerce", true))); - Mapper fieldMapper = mapper.mappers().getMapper("field"); - assertThat(fieldMapper, instanceOf(GeoShapeFieldMapper.class)); - boolean coerce = ((GeoShapeFieldMapper) fieldMapper).coerce(); - assertThat(coerce, equalTo(true)); - - // explicit false coerce test - mapper = createDocumentMapper(fieldMapping(b -> b.field("type", "geo_shape").field("coerce", false))); - fieldMapper = mapper.mappers().getMapper("field"); - assertThat(fieldMapper, instanceOf(GeoShapeFieldMapper.class)); - coerce = ((GeoShapeFieldMapper) fieldMapper).coerce(); - assertThat(coerce, equalTo(false)); - } - - /** - * Test that accept_z_value parameter correctly parses - */ - public void testIgnoreZValue() throws IOException { - DocumentMapper mapper = createDocumentMapper(fieldMapping(b -> b.field("type", "geo_shape").field("ignore_z_value", true))); - Mapper fieldMapper = mapper.mappers().getMapper("field"); - assertThat(fieldMapper, instanceOf(GeoShapeFieldMapper.class)); - - boolean ignoreZValue = ((GeoShapeFieldMapper) fieldMapper).ignoreZValue(); - assertThat(ignoreZValue, equalTo(true)); - - // explicit false accept_z_value test - mapper = createDocumentMapper(fieldMapping(b -> b.field("type", "geo_shape").field("ignore_z_value", false))); - fieldMapper = mapper.mappers().getMapper("field"); - assertThat(fieldMapper, instanceOf(GeoShapeFieldMapper.class)); - - ignoreZValue = ((GeoShapeFieldMapper) fieldMapper).ignoreZValue(); - assertThat(ignoreZValue, equalTo(false)); - } - - @Override - protected boolean supportsIgnoreMalformed() { - return true; - } - - @Override - protected List exampleMalformedValues() { - return List.of( - exampleMalformedValue("Bad shape").errorMatches("Unknown geometry type: bad"), - exampleMalformedValue( - "POLYGON ((18.9401790919516 -33.9681188869036, 18.9401790919516 -33.9681188869036, 18.9401790919517 " - + "-33.9681188869036, 18.9401790919517 -33.9681188869036, 18.9401790919516 -33.9681188869036))" - ).errorMatches("at least three non-collinear points required") - ); - } - - public void testGeoShapeMapperMerge() throws Exception { - MapperService mapperService = createMapperService(fieldMapping(b -> b.field("type", "geo_shape").field("orientation", "ccw"))); - Mapper fieldMapper = mapperService.documentMapper().mappers().getMapper("field"); - assertThat(fieldMapper, instanceOf(GeoShapeFieldMapper.class)); - GeoShapeFieldMapper geoShapeFieldMapper = (GeoShapeFieldMapper) fieldMapper; - assertThat(geoShapeFieldMapper.fieldType().orientation(), equalTo(Orientation.CCW)); - - // change mapping; orientation - merge(mapperService, fieldMapping(b -> b.field("type", "geo_shape").field("orientation", "cw"))); - fieldMapper = mapperService.documentMapper().mappers().getMapper("field"); - assertThat(fieldMapper, instanceOf(GeoShapeFieldMapper.class)); - geoShapeFieldMapper = (GeoShapeFieldMapper) fieldMapper; - assertThat(geoShapeFieldMapper.fieldType().orientation(), equalTo(Orientation.CW)); - } - - public void testSerializeDefaults() throws Exception { - DocumentMapper mapper = createDocumentMapper(fieldMapping(this::minimalMapping)); - assertThat( - Strings.toString( - mapper.mappers().getMapper("field"), - new ToXContent.MapParams(Collections.singletonMap("include_defaults", "true")) - ), - containsString("\"orientation\":\"" + Orientation.RIGHT + "\"") - ); - } - - public void testGeoShapeArrayParsing() throws Exception { - DocumentMapper mapper = createDocumentMapper(fieldMapping(this::minimalMapping)); - ParsedDocument document = mapper.parse(source(b -> { - b.startArray("field"); - { - b.startObject().field("type", "Point").startArray("coordinates").value(176.0).value(15.0).endArray().endObject(); - b.startObject().field("type", "Point").startArray("coordinates").value(76.0).value(-15.0).endArray().endObject(); - } - b.endArray(); - })); - assertThat(document.docs(), hasSize(1)); - assertThat(document.docs().get(0).getFields("field"), hasSize(2)); - } - - public void testMultiFieldsDeprecationWarning() throws Exception { - createDocumentMapper(fieldMapping(b -> { - minimalMapping(b); - b.startObject("fields"); - b.startObject("keyword").field("type", "keyword").endObject(); - b.endObject(); - })); - assertWarnings("Adding multifields to [geo_shape] mappers has no effect and will be forbidden in future"); - } - - @Override - protected boolean supportsMeta() { - return false; - } - - protected void assertSearchable(MappedFieldType fieldType) { - // always searchable even if it uses TextSearchInfo.NONE - assertTrue(fieldType.isIndexed()); - assertTrue(fieldType.isSearchable()); - } - - @Override - protected Object generateRandomInputValue(MappedFieldType ft) { - assumeFalse("Test implemented in a follow up", true); - return null; - } - - @Override - protected SyntheticSourceSupport syntheticSourceSupport(boolean ignoreMalformed) { - throw new AssumptionViolatedException("not supported"); - } - - @Override - protected IngestScriptSupport ingestScriptSupport() { - throw new AssumptionViolatedException("not supported"); - } -} diff --git a/server/src/test/java/org/elasticsearch/index/mapper/GeoShapeFieldTypeTests.java b/server/src/test/java/org/elasticsearch/index/mapper/GeoShapeFieldTypeTests.java deleted file mode 100644 index b4dce62d16f37..0000000000000 --- a/server/src/test/java/org/elasticsearch/index/mapper/GeoShapeFieldTypeTests.java +++ /dev/null @@ -1,68 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0 and the Server Side Public License, v 1; you may not use this file except - * in compliance with, at your election, the Elastic License 2.0 or the Server - * Side Public License, v 1. - */ - -package org.elasticsearch.index.mapper; - -import java.io.IOException; -import java.util.List; -import java.util.Map; - -public class GeoShapeFieldTypeTests extends FieldTypeTestCase { - - public void testFetchSourceValue() throws IOException { - MappedFieldType mapper = new GeoShapeFieldMapper.Builder("field", true, true).build(MapperBuilderContext.root(false, false)) - .fieldType(); - - Map jsonLineString = Map.of("type", "LineString", "coordinates", List.of(List.of(42.0, 27.1), List.of(30.0, 50.0))); - Map jsonPoint = Map.of("type", "Point", "coordinates", List.of(14.0, 15.0)); - Map jsonMalformed = Map.of("type", "Point", "coordinates", "foo"); - String wktLineString = "LINESTRING (42.0 27.1, 30.0 50.0)"; - String wktPoint = "POINT (14.0 15.0)"; - String wktMalformed = "POINT foo"; - - // Test a single shape in geojson format. - Object sourceValue = jsonLineString; - assertEquals(List.of(jsonLineString), fetchSourceValue(mapper, sourceValue, null)); - assertEquals(List.of(wktLineString), fetchSourceValue(mapper, sourceValue, "wkt")); - - // Test a malformed single shape in geojson format - sourceValue = jsonMalformed; - assertEquals(List.of(), fetchSourceValue(mapper, sourceValue, null)); - assertEquals(List.of(), fetchSourceValue(mapper, sourceValue, "wkt")); - - // Test a list of shapes in geojson format. - sourceValue = List.of(jsonLineString, jsonPoint); - assertEquals(List.of(jsonLineString, jsonPoint), fetchSourceValue(mapper, sourceValue, null)); - assertEquals(List.of(wktLineString, wktPoint), fetchSourceValue(mapper, sourceValue, "wkt")); - - // Test a list of shapes including one malformed in geojson format - sourceValue = List.of(jsonLineString, jsonMalformed, jsonPoint); - assertEquals(List.of(jsonLineString, jsonPoint), fetchSourceValue(mapper, sourceValue, null)); - assertEquals(List.of(wktLineString, wktPoint), fetchSourceValue(mapper, sourceValue, "wkt")); - - // Test a single shape in wkt format. - sourceValue = wktLineString; - assertEquals(List.of(jsonLineString), fetchSourceValue(mapper, sourceValue, null)); - assertEquals(List.of(wktLineString), fetchSourceValue(mapper, sourceValue, "wkt")); - - // Test a single malformed shape in wkt format - sourceValue = wktMalformed; - assertEquals(List.of(), fetchSourceValue(mapper, sourceValue, null)); - assertEquals(List.of(), fetchSourceValue(mapper, sourceValue, "wkt")); - - // Test a list of shapes in wkt format. - sourceValue = List.of(wktLineString, wktPoint); - assertEquals(List.of(jsonLineString, jsonPoint), fetchSourceValue(mapper, sourceValue, null)); - assertEquals(List.of(wktLineString, wktPoint), fetchSourceValue(mapper, sourceValue, "wkt")); - - // Test a list of shapes including one malformed in wkt format - sourceValue = List.of(wktLineString, wktMalformed, wktPoint); - assertEquals(List.of(jsonLineString, jsonPoint), fetchSourceValue(mapper, sourceValue, null)); - assertEquals(List.of(wktLineString, wktPoint), fetchSourceValue(mapper, sourceValue, "wkt")); - } -} diff --git a/server/src/test/java/org/elasticsearch/index/replication/RecoveryDuringReplicationTests.java b/server/src/test/java/org/elasticsearch/index/replication/RecoveryDuringReplicationTests.java index ff6b27924404e..7d018c23597b7 100644 --- a/server/src/test/java/org/elasticsearch/index/replication/RecoveryDuringReplicationTests.java +++ b/server/src/test/java/org/elasticsearch/index/replication/RecoveryDuringReplicationTests.java @@ -272,11 +272,11 @@ public void testRecoveryAfterPrimaryPromotion() throws Exception { assertThat(newReplica.recoveryState().getIndex().fileDetails(), empty()); assertThat( newReplica.recoveryState().getTranslog().totalLocal(), - equalTo(Math.toIntExact(globalCheckpointOnOldPrimary - safeCommitOnOldPrimary.get().localCheckpoint)) + equalTo(Math.toIntExact(globalCheckpointOnOldPrimary - safeCommitOnOldPrimary.get().localCheckpoint())) ); assertThat( newReplica.recoveryState().getTranslog().recoveredOperations(), - equalTo(Math.toIntExact(totalDocs - 1 - safeCommitOnOldPrimary.get().localCheckpoint)) + equalTo(Math.toIntExact(totalDocs - 1 - safeCommitOnOldPrimary.get().localCheckpoint())) ); } else { assertThat(newReplica.recoveryState().getIndex().fileDetails(), not(empty())); diff --git a/server/src/test/java/org/elasticsearch/index/shard/IndexShardTests.java b/server/src/test/java/org/elasticsearch/index/shard/IndexShardTests.java index 9d53b95e01db3..29f39134d2bcf 100644 --- a/server/src/test/java/org/elasticsearch/index/shard/IndexShardTests.java +++ b/server/src/test/java/org/elasticsearch/index/shard/IndexShardTests.java @@ -659,7 +659,7 @@ public void testPrimaryFillsSeqNoGapsOnPromotion() throws Exception { public void testPrimaryPromotionRollsGeneration() throws Exception { final IndexShard indexShard = newStartedShard(false); - final long currentTranslogGeneration = getTranslog(indexShard).getGeneration().translogFileGeneration; + final long currentTranslogGeneration = getTranslog(indexShard).getGeneration().translogFileGeneration(); // promote the replica final ShardRouting replicaRouting = indexShard.routingEntry(); @@ -698,7 +698,7 @@ public void onFailure(Exception e) { }, threadPool.generic()); latch.await(); - assertThat(getTranslog(indexShard).getGeneration().translogFileGeneration, equalTo(currentTranslogGeneration + 1)); + assertThat(getTranslog(indexShard).getGeneration().translogFileGeneration(), equalTo(currentTranslogGeneration + 1)); assertThat(TestTranslog.getCurrentTerm(getTranslog(indexShard)), equalTo(newPrimaryTerm)); closeShards(indexShard); @@ -995,7 +995,7 @@ public void testOperationPermitOnReplicaShards() throws Exception { } final long primaryTerm = indexShard.getPendingPrimaryTerm(); - final long translogGen = engineClosed ? -1 : getTranslog(indexShard).getGeneration().translogFileGeneration; + final long translogGen = engineClosed ? -1 : getTranslog(indexShard).getGeneration().translogFileGeneration(); final Releasable operation1; final Releasable operation2; @@ -1115,7 +1115,7 @@ private void finish() { assertTrue(onResponse.get()); assertNull(onFailure.get()); assertThat( - getTranslog(indexShard).getGeneration().translogFileGeneration, + getTranslog(indexShard).getGeneration().translogFileGeneration(), // if rollback happens we roll translog twice: one when we flush a commit before opening a read-only engine // and one after replaying translog (upto the global checkpoint); otherwise we roll translog once. either(equalTo(translogGen + 1)).or(equalTo(translogGen + 2)) diff --git a/server/src/test/java/org/elasticsearch/index/store/FsDirectoryFactoryTests.java b/server/src/test/java/org/elasticsearch/index/store/FsDirectoryFactoryTests.java index 8bc90b3000dc8..b1222213a505d 100644 --- a/server/src/test/java/org/elasticsearch/index/store/FsDirectoryFactoryTests.java +++ b/server/src/test/java/org/elasticsearch/index/store/FsDirectoryFactoryTests.java @@ -8,12 +8,9 @@ package org.elasticsearch.index.store; import org.apache.lucene.store.AlreadyClosedException; -import org.apache.lucene.store.ByteBuffersDirectory; import org.apache.lucene.store.Directory; import org.apache.lucene.store.FilterDirectory; import org.apache.lucene.store.IOContext; -import org.apache.lucene.store.IndexInput; -import org.apache.lucene.store.IndexOutput; import org.apache.lucene.store.MMapDirectory; import org.apache.lucene.store.NIOFSDirectory; import org.apache.lucene.store.NoLockFactory; @@ -69,29 +66,6 @@ public void testPreload() throws IOException { } } - public void testDisableRandomAdvice() throws IOException { - Directory dir = new FilterDirectory(new ByteBuffersDirectory()) { - @Override - public IndexInput openInput(String name, IOContext context) throws IOException { - assertFalse(context.randomAccess); - return super.openInput(name, context); - } - }; - Directory noRandomAccessDir = FsDirectoryFactory.disableRandomAdvice(dir); - try (IndexOutput out = noRandomAccessDir.createOutput("foo", IOContext.DEFAULT)) { - out.writeInt(42); - } - // Test the tester - expectThrows(AssertionError.class, () -> dir.openInput("foo", IOContext.RANDOM)); - - // The wrapped directory shouldn't fail regardless of the IOContext - for (IOContext context : Arrays.asList(IOContext.READ, IOContext.DEFAULT, IOContext.READONCE, IOContext.RANDOM)) { - try (IndexInput in = noRandomAccessDir.openInput("foo", context)) { - assertEquals(42, in.readInt()); - } - } - } - private Directory newDirectory(Settings settings) throws IOException { IndexSettings idxSettings = IndexSettingsModule.newIndexSettings("foo", settings); Path tempDir = createTempDir().resolve(idxSettings.getUUID()).resolve("0"); diff --git a/server/src/test/java/org/elasticsearch/index/translog/TranslogTests.java b/server/src/test/java/org/elasticsearch/index/translog/TranslogTests.java index cd7e637d58bcc..8a277e400ad6c 100644 --- a/server/src/test/java/org/elasticsearch/index/translog/TranslogTests.java +++ b/server/src/test/java/org/elasticsearch/index/translog/TranslogTests.java @@ -1250,7 +1250,7 @@ public void testLocationComparison() throws IOException { max = max(max, location); } - assertEquals(max.generation, translog.currentFileGeneration()); + assertEquals(max.generation(), translog.currentFileGeneration()); try (Translog.Snapshot snap = new SortedSnapshot(translog.newSnapshot())) { Translog.Operation next; Translog.Operation maxOp = null; @@ -1655,17 +1655,17 @@ public void testTranslogOperationListener() throws IOException { try (Translog translog = createTranslog(config)) { Location location1 = translog.add(indexOp(randomAlphaOfLength(10), 0, primaryTerm.get())); Location location2 = translog.add(TranslogOperationsUtils.indexOp(randomAlphaOfLength(10), 1, primaryTerm.get())); - long firstGeneration = translog.getGeneration().translogFileGeneration; - assertThat(location1.generation, equalTo(firstGeneration)); - assertThat(location2.generation, equalTo(firstGeneration)); + long firstGeneration = translog.getGeneration().translogFileGeneration(); + assertThat(location1.generation(), equalTo(firstGeneration)); + assertThat(location2.generation(), equalTo(firstGeneration)); translog.rollGeneration(); Location location3 = translog.add(TranslogOperationsUtils.indexOp(randomAlphaOfLength(10), 3, primaryTerm.get())); Location location4 = translog.add(TranslogOperationsUtils.indexOp(randomAlphaOfLength(10), 2, primaryTerm.get())); - long secondGeneration = translog.getGeneration().translogFileGeneration; - assertThat(location3.generation, equalTo(secondGeneration)); - assertThat(location4.generation, equalTo(secondGeneration)); + long secondGeneration = translog.getGeneration().translogFileGeneration(); + assertThat(location3.generation(), equalTo(secondGeneration)); + assertThat(location4.generation(), equalTo(secondGeneration)); assertThat(seqNos, equalTo(List.of(0L, 1L, 3L, 2L))); assertThat(locations, equalTo(List.of(location1, location2, location3, location4))); @@ -1741,7 +1741,7 @@ public void testBasicRecovery() throws IOException { } else { translog = new Translog( config, - translogGeneration.translogUUID, + translogGeneration.translogUUID(), translog.getDeletionPolicy(), () -> SequenceNumbers.NO_OPS_PERFORMED, primaryTerm::get, @@ -1749,7 +1749,7 @@ public void testBasicRecovery() throws IOException { ); assertEquals( "lastCommitted must be 1 less than current", - translogGeneration.translogFileGeneration + 1, + translogGeneration.translogFileGeneration() + 1, translog.currentFileGeneration() ); assertFalse(translog.syncNeeded()); @@ -1758,7 +1758,7 @@ public void testBasicRecovery() throws IOException { assertEquals( "expected operation" + i + " to be in the previous translog but wasn't", translog.currentFileGeneration() - 1, - locations.get(i).generation + locations.get(i).generation() ); Translog.Operation next = snapshot.next(); assertNotNull("operation " + i + " must be non-null", next); @@ -1782,9 +1782,9 @@ public void testRecoveryUncommitted() throws IOException { assertEquals( "expected this to be the first roll (1 gen is on creation, 2 when opened)", 2L, - translogGeneration.translogFileGeneration + translogGeneration.translogFileGeneration() ); - assertNotNull(translogGeneration.translogUUID); + assertNotNull(translogGeneration.translogUUID()); } } if (sync) { @@ -1808,7 +1808,7 @@ public void testRecoveryUncommitted() throws IOException { assertNotNull(translogGeneration); assertEquals( "lastCommitted must be 2 less than current - we never finished the commit", - translogGeneration.translogFileGeneration + 2, + translogGeneration.translogFileGeneration() + 2, translog.currentFileGeneration() ); assertFalse(translog.syncNeeded()); @@ -1835,7 +1835,7 @@ public void testRecoveryUncommitted() throws IOException { assertNotNull(translogGeneration); assertEquals( "lastCommitted must be 3 less than current - we never finished the commit and run recovery twice", - translogGeneration.translogFileGeneration + 3, + translogGeneration.translogFileGeneration() + 3, translog.currentFileGeneration() ); assertFalse(translog.syncNeeded()); @@ -1869,9 +1869,9 @@ public void testRecoveryUncommittedFileExists() throws IOException { assertEquals( "expected this to be the first roll (1 gen is on creation, 2 when opened)", 2L, - translogGeneration.translogFileGeneration + translogGeneration.translogFileGeneration() ); - assertNotNull(translogGeneration.translogUUID); + assertNotNull(translogGeneration.translogUUID()); } } if (sync) { @@ -1899,7 +1899,7 @@ public void testRecoveryUncommittedFileExists() throws IOException { assertNotNull(translogGeneration); assertEquals( "lastCommitted must be 2 less than current - we never finished the commit", - translogGeneration.translogFileGeneration + 2, + translogGeneration.translogFileGeneration() + 2, translog.currentFileGeneration() ); assertFalse(translog.syncNeeded()); @@ -1927,7 +1927,7 @@ public void testRecoveryUncommittedFileExists() throws IOException { assertNotNull(translogGeneration); assertEquals( "lastCommitted must be 3 less than current - we never finished the commit and run recovery twice", - translogGeneration.translogFileGeneration + 3, + translogGeneration.translogFileGeneration() + 3, translog.currentFileGeneration() ); assertFalse(translog.syncNeeded()); @@ -1960,9 +1960,9 @@ public void testRecoveryUncommittedCorruptedCheckpoint() throws IOException { assertEquals( "expected this to be the first roll (1 gen is on creation, 2 when opened)", 2L, - translogGeneration.translogFileGeneration + translogGeneration.translogFileGeneration() ); - assertNotNull(translogGeneration.translogUUID); + assertNotNull(translogGeneration.translogUUID()); } } translog.sync(); @@ -2015,7 +2015,7 @@ public void testRecoveryUncommittedCorruptedCheckpoint() throws IOException { assertNotNull(translogGeneration); assertEquals( "lastCommitted must be 2 less than current - we never finished the commit", - translogGeneration.translogFileGeneration + 2, + translogGeneration.translogFileGeneration() + 2, translog.currentFileGeneration() ); assertFalse(translog.syncNeeded()); @@ -2284,7 +2284,7 @@ public void testOpenForeignTranslog() throws IOException { Translog.TranslogGeneration translogGeneration = translog.getGeneration(); translog.close(); - final String foreignTranslog = randomRealisticUnicodeOfCodepointLengthBetween(1, translogGeneration.translogUUID.length()); + final String foreignTranslog = randomRealisticUnicodeOfCodepointLengthBetween(1, translogGeneration.translogUUID().length()); try { new Translog( config, @@ -2507,7 +2507,7 @@ public void testFailFlush() throws IOException { ) { assertEquals( "lastCommitted must be 1 less than current", - translogGeneration.translogFileGeneration + 1, + translogGeneration.translogFileGeneration() + 1, tlog.currentFileGeneration() ); assertFalse(tlog.syncNeeded()); @@ -2518,7 +2518,7 @@ public void testFailFlush() throws IOException { assertEquals( "expected operation" + i + " to be in the previous translog but wasn't", tlog.currentFileGeneration() - 1, - locations.get(i).generation + locations.get(i).generation() ); Translog.Operation next = snapshot.next(); assertNotNull("operation " + i + " must be non-null", next); @@ -2540,7 +2540,7 @@ public void testTranslogOpsCountIsCorrect() throws IOException { assertEquals( "expected operation" + i + " to be in the current translog but wasn't", translog.currentFileGeneration(), - locations.get(i).generation + locations.get(i).generation() ); Translog.Operation next = snapshot.next(); assertNotNull("operation " + i + " must be non-null", next); @@ -2640,7 +2640,7 @@ protected void afterAdd() throws IOException { assertFalse(translog.isOpen()); final Checkpoint checkpoint = Checkpoint.read(config.getTranslogPath().resolve(Translog.CHECKPOINT_FILE_NAME)); // drop all that haven't been synced - writtenOperations.removeIf(next -> checkpoint.offset < (next.location.translogLocation + next.location.size)); + writtenOperations.removeIf(next -> checkpoint.offset < (next.location.translogLocation() + next.location.size())); try ( Translog tlog = new Translog( config, @@ -2664,7 +2664,7 @@ protected void afterAdd() throws IOException { assertEquals( "expected operation" + i + " to be in the previous translog but wasn't", tlog.currentFileGeneration() - 1, - writtenOperations.get(i).location.generation + writtenOperations.get(i).location.generation() ); Translog.Operation next = snapshot.next(); assertNotNull("operation " + i + " must be non-null", next); @@ -2695,7 +2695,7 @@ public void testRecoveryFromAFutureGenerationCleansUp() throws IOException { translog.rollGeneration(); } } - long minRetainedGen = translog.getMinGenerationForSeqNo(localCheckpoint + 1).translogFileGeneration; + long minRetainedGen = translog.getMinGenerationForSeqNo(localCheckpoint + 1).translogFileGeneration(); // engine blows up, after committing the above generation translog.close(); TranslogConfig config = translog.getConfig(); @@ -2753,7 +2753,7 @@ public void testRecoveryFromFailureOnTrimming() throws IOException { } } deletionPolicy.setLocalCheckpointOfSafeCommit(localCheckpoint); - minGenForRecovery = translog.getMinGenerationForSeqNo(localCheckpoint + 1).translogFileGeneration; + minGenForRecovery = translog.getMinGenerationForSeqNo(localCheckpoint + 1).translogFileGeneration(); fail.failRandomly(); try { translog.trimUnreferencedReaders(); @@ -2777,7 +2777,7 @@ public void testRecoveryFromFailureOnTrimming() throws IOException { assertThat(translog.getMinFileGeneration(), greaterThanOrEqualTo(1L)); assertThat(translog.getMinFileGeneration(), lessThanOrEqualTo(minGenForRecovery)); assertFilePresences(translog); - minGenForRecovery = translog.getMinGenerationForSeqNo(localCheckpoint + 1).translogFileGeneration; + minGenForRecovery = translog.getMinGenerationForSeqNo(localCheckpoint + 1).translogFileGeneration(); translog.trimUnreferencedReaders(); assertThat(translog.getMinFileGeneration(), equalTo(minGenForRecovery)); assertFilePresences(translog); @@ -3539,7 +3539,7 @@ public void testMinSeqNoBasedAPI() throws IOException { translog.rollGeneration(); for (long seqNo = 0; seqNo < operations; seqNo++) { final Set> seenSeqNos = new HashSet<>(); - final long generation = translog.getMinGenerationForSeqNo(seqNo).translogFileGeneration; + final long generation = translog.getMinGenerationForSeqNo(seqNo).translogFileGeneration(); int expectedSnapshotOps = 0; for (long g = generation; g < translog.currentFileGeneration(); g++) { if (seqNoPerGeneration.containsKey(g) == false) { @@ -3924,7 +3924,7 @@ public void testSyncConcurrently() throws Exception { assertThat("seq# " + op.seqNo() + " was not marked as persisted", persistedSeqNos, hasItem(op.seqNo())); } Checkpoint checkpoint = translog.getLastSyncedCheckpoint(); - assertThat(checkpoint.offset, greaterThanOrEqualTo(location.translogLocation)); + assertThat(checkpoint.offset, greaterThanOrEqualTo(location.translogLocation())); for (Translog.Operation op : ops) { assertThat(checkpoint.minSeqNo, lessThanOrEqualTo(op.seqNo())); assertThat(checkpoint.maxSeqNo, greaterThanOrEqualTo(op.seqNo())); diff --git a/server/src/test/java/org/elasticsearch/indices/recovery/PeerRecoveryTargetServiceTests.java b/server/src/test/java/org/elasticsearch/indices/recovery/PeerRecoveryTargetServiceTests.java index 4266b514bf544..8001c8c901829 100644 --- a/server/src/test/java/org/elasticsearch/indices/recovery/PeerRecoveryTargetServiceTests.java +++ b/server/src/test/java/org/elasticsearch/indices/recovery/PeerRecoveryTargetServiceTests.java @@ -223,8 +223,8 @@ public void testPrepareIndexForPeerRecovery() throws Exception { Optional safeCommit = shard.store().findSafeIndexCommit(globalCheckpoint); assertTrue(safeCommit.isPresent()); int expectedTotalLocal = 0; - if (safeCommit.get().localCheckpoint < globalCheckpoint) { - try (Translog.Snapshot snapshot = getTranslog(shard).newSnapshot(safeCommit.get().localCheckpoint + 1, globalCheckpoint)) { + if (safeCommit.get().localCheckpoint() < globalCheckpoint) { + try (Translog.Snapshot snapshot = getTranslog(shard).newSnapshot(safeCommit.get().localCheckpoint() + 1, globalCheckpoint)) { Translog.Operation op; while ((op = snapshot.next()) != null) { if (op.seqNo() <= globalCheckpoint) { @@ -276,7 +276,7 @@ public void testPrepareIndexForPeerRecovery() throws Exception { replica.markAsRecovering("for testing", new RecoveryState(replica.routingEntry(), localNode, localNode)); replica.prepareForIndexRecovery(); if (safeCommit.isPresent()) { - assertThat(recoverLocallyUpToGlobalCheckpoint(replica), equalTo(safeCommit.get().localCheckpoint + 1)); + assertThat(recoverLocallyUpToGlobalCheckpoint(replica), equalTo(safeCommit.get().localCheckpoint() + 1)); assertThat(replica.recoveryState().getTranslog().totalLocal(), equalTo(0)); } else { assertThat(recoverLocallyUpToGlobalCheckpoint(replica), equalTo(UNASSIGNED_SEQ_NO)); @@ -313,7 +313,7 @@ public void testClosedIndexSkipsLocalRecovery() throws Exception { ); replica.markAsRecovering("for testing", new RecoveryState(replica.routingEntry(), localNode, localNode)); replica.prepareForIndexRecovery(); - assertThat(recoverLocallyUpToGlobalCheckpoint(replica), equalTo(safeCommit.get().localCheckpoint + 1)); + assertThat(recoverLocallyUpToGlobalCheckpoint(replica), equalTo(safeCommit.get().localCheckpoint() + 1)); assertThat(replica.recoveryState().getTranslog().totalLocal(), equalTo(0)); assertThat(replica.recoveryState().getTranslog().recoveredOperations(), equalTo(0)); assertThat(replica.getLastKnownGlobalCheckpoint(), equalTo(UNASSIGNED_SEQ_NO)); diff --git a/server/src/test/java/org/elasticsearch/indices/recovery/RecoveryTests.java b/server/src/test/java/org/elasticsearch/indices/recovery/RecoveryTests.java index fc8f1988a732b..47c9c5e85f7b9 100644 --- a/server/src/test/java/org/elasticsearch/indices/recovery/RecoveryTests.java +++ b/server/src/test/java/org/elasticsearch/indices/recovery/RecoveryTests.java @@ -252,7 +252,7 @@ public void testDifferentHistoryUUIDDisablesOPsRecovery() throws Exception { replica.getPendingPrimaryTerm() ); } else { - translogUUIDtoUse = translogGeneration.translogUUID; + translogUUIDtoUse = translogGeneration.translogUUID(); } try (IndexWriter writer = new IndexWriter(replica.store().directory(), iwc)) { userData.put(Engine.HISTORY_UUID_KEY, historyUUIDtoUse); diff --git a/server/src/test/java/org/elasticsearch/plugins/UberModuleClassLoaderTests.java b/server/src/test/java/org/elasticsearch/plugins/UberModuleClassLoaderTests.java index e3cd11c8f3b68..ecc2f458cdd60 100644 --- a/server/src/test/java/org/elasticsearch/plugins/UberModuleClassLoaderTests.java +++ b/server/src/test/java/org/elasticsearch/plugins/UberModuleClassLoaderTests.java @@ -427,12 +427,12 @@ public String getTestString() { package p; import java.util.ServiceLoader; - import java.util.random.RandomGenerator; + import java.nio.file.spi.FileSystemProvider; public class ServiceCaller { public static String demo() { // check no error if we load a service from the jdk - ServiceLoader randomLoader = ServiceLoader.load(RandomGenerator.class); + ServiceLoader fileSystemLoader = ServiceLoader.load(FileSystemProvider.class); ServiceLoader loader = ServiceLoader.load(MyService.class, ServiceCaller.class.getClassLoader()); return loader.findFirst().get().getTestString(); diff --git a/server/src/test/java/org/elasticsearch/search/SearchServiceTests.java b/server/src/test/java/org/elasticsearch/search/SearchServiceTests.java index 4609c7327c798..7ddcc88facb2a 100644 --- a/server/src/test/java/org/elasticsearch/search/SearchServiceTests.java +++ b/server/src/test/java/org/elasticsearch/search/SearchServiceTests.java @@ -47,7 +47,10 @@ import org.elasticsearch.cluster.routing.ShardRouting; import org.elasticsearch.cluster.routing.ShardRoutingState; import org.elasticsearch.cluster.routing.TestShardRouting; +import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.UUIDs; +import org.elasticsearch.common.breaker.CircuitBreakingException; +import org.elasticsearch.common.breaker.NoopCircuitBreaker; import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.common.bytes.BytesReference; import org.elasticsearch.common.io.stream.StreamInput; @@ -151,6 +154,7 @@ import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Consumer; import java.util.function.Function; +import java.util.function.IntConsumer; import java.util.function.Supplier; import static java.util.Collections.emptyList; @@ -1985,6 +1989,38 @@ public void testCreateReduceContext() { } } + public void testMultiBucketConsumerServiceCB() { + MultiBucketConsumerService service = new MultiBucketConsumerService( + getInstanceFromNode(ClusterService.class), + Settings.EMPTY, + new NoopCircuitBreaker("test") { + + @Override + public void addEstimateBytesAndMaybeBreak(long bytes, String label) throws CircuitBreakingException { + throw new CircuitBreakingException("tripped", getDurability()); + } + } + ); + // for partial + { + IntConsumer consumer = service.createForPartial(); + for (int i = 0; i < 1023; i++) { + consumer.accept(0); + } + CircuitBreakingException ex = expectThrows(CircuitBreakingException.class, () -> consumer.accept(0)); + assertThat(ex.getMessage(), equalTo("tripped")); + } + // for final + { + IntConsumer consumer = service.createForFinal(); + for (int i = 0; i < 1023; i++) { + consumer.accept(0); + } + CircuitBreakingException ex = expectThrows(CircuitBreakingException.class, () -> consumer.accept(0)); + assertThat(ex.getMessage(), equalTo("tripped")); + } + } + public void testCreateSearchContext() throws IOException { String index = randomAlphaOfLengthBetween(5, 10).toLowerCase(Locale.ROOT); IndexService indexService = createIndex(index); diff --git a/server/src/test/java/org/elasticsearch/search/geo/GeoPointShapeQueryTests.java b/server/src/test/java/org/elasticsearch/search/geo/GeoPointShapeQueryTests.java index 779e0ad28433a..af408299c4150 100644 --- a/server/src/test/java/org/elasticsearch/search/geo/GeoPointShapeQueryTests.java +++ b/server/src/test/java/org/elasticsearch/search/geo/GeoPointShapeQueryTests.java @@ -37,7 +37,7 @@ protected SpatialQueryBuilders queryBuilder() { @Override protected String fieldTypeName() { - return "geo_shape"; + return "keyword"; } @Override diff --git a/server/src/test/java/org/elasticsearch/search/geo/GeoShapeQueryTests.java b/server/src/test/java/org/elasticsearch/search/geo/GeoShapeQueryTests.java deleted file mode 100644 index 7cb7d69ea4b6f..0000000000000 --- a/server/src/test/java/org/elasticsearch/search/geo/GeoShapeQueryTests.java +++ /dev/null @@ -1,11 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0 and the Server Side Public License, v 1; you may not use this file except - * in compliance with, at your election, the Elastic License 2.0 or the Server - * Side Public License, v 1. - */ - -package org.elasticsearch.search.geo; - -public class GeoShapeQueryTests extends GeoShapeQueryTestCase {} diff --git a/server/src/test/java/org/elasticsearch/search/profile/AbstractProfileBreakdownTests.java b/server/src/test/java/org/elasticsearch/search/profile/AbstractProfileBreakdownTests.java index b8b12357b085e..e988599fccc3b 100644 --- a/server/src/test/java/org/elasticsearch/search/profile/AbstractProfileBreakdownTests.java +++ b/server/src/test/java/org/elasticsearch/search/profile/AbstractProfileBreakdownTests.java @@ -11,7 +11,6 @@ import org.elasticsearch.test.ESTestCase; import java.util.Map; -import java.util.concurrent.CountDownLatch; import static org.hamcrest.Matchers.equalTo; @@ -107,35 +106,21 @@ public void testGetBreakdownAndNodeTime() { public void testMultiThreaded() throws InterruptedException { TestProfileBreakdown testBreakdown = new TestProfileBreakdown(); - Thread[] threads = new Thread[200]; - final CountDownLatch latch = new CountDownLatch(1); + final int threads = 200; int startsPerThread = between(1, 5); - for (int t = 0; t < threads.length; t++) { - final TestTimingTypes timingType = randomFrom(TestTimingTypes.values()); - threads[t] = new Thread(() -> { - try { - latch.await(); - } catch (InterruptedException e) { - throw new RuntimeException(e); - } - Timer timer = testBreakdown.getNewTimer(timingType); - for (int runs = 0; runs < startsPerThread; runs++) { - timer.start(); - timer.stop(); - } - }); - threads[t].start(); - } // starting all threads simultaneously increases the likelihood of failure in case we don't synchronize timer access properly - latch.countDown(); - for (Thread t : threads) { - t.join(); - } + startInParallel(threads, t -> { + final TestTimingTypes timingType = randomFrom(TestTimingTypes.values()); + Timer timer = testBreakdown.getNewTimer(timingType); + for (int runs = 0; runs < startsPerThread; runs++) { + timer.start(); + timer.stop(); + } + }); Map breakdownMap = testBreakdown.toBreakdownMap(); long totalCounter = breakdownMap.get(TestTimingTypes.ONE + "_count") + breakdownMap.get(TestTimingTypes.TWO + "_count") + breakdownMap.get(TestTimingTypes.THREE + "_count"); - assertEquals(threads.length * startsPerThread, totalCounter); - + assertEquals(threads * startsPerThread, totalCounter); } private void runTimerNTimes(Timer t, int n) { diff --git a/server/src/test/java/org/elasticsearch/search/retriever/KnnRetrieverBuilderParsingTests.java b/server/src/test/java/org/elasticsearch/search/retriever/KnnRetrieverBuilderParsingTests.java index cbbbe7d86f4e2..de35d765a1551 100644 --- a/server/src/test/java/org/elasticsearch/search/retriever/KnnRetrieverBuilderParsingTests.java +++ b/server/src/test/java/org/elasticsearch/search/retriever/KnnRetrieverBuilderParsingTests.java @@ -10,9 +10,14 @@ import org.elasticsearch.action.search.SearchRequest; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.index.query.MatchAllQueryBuilder; +import org.elasticsearch.index.query.MatchNoneQueryBuilder; import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.index.query.QueryRewriteContext; import org.elasticsearch.index.query.RandomQueryBuilder; +import org.elasticsearch.index.query.Rewriteable; import org.elasticsearch.search.SearchModule; +import org.elasticsearch.search.builder.SearchSourceBuilder; import org.elasticsearch.test.AbstractXContentTestCase; import org.elasticsearch.usage.SearchUsage; import org.elasticsearch.xcontent.NamedXContentRegistry; @@ -23,6 +28,10 @@ import java.util.List; import static org.elasticsearch.search.vectors.KnnSearchBuilderTests.randomVector; +import static org.hamcrest.Matchers.anyOf; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.instanceOf; +import static org.mockito.Mockito.mock; public class KnnRetrieverBuilderParsingTests extends AbstractXContentTestCase { @@ -34,7 +43,7 @@ public class KnnRetrieverBuilderParsingTests extends AbstractXContentTestCase { /** @@ -59,7 +71,7 @@ public static StandardRetrieverBuilder createRandomStandardRetrieverBuilder( } if (randomBoolean()) { - standardRetrieverBuilder.sortBuilders = SortBuilderTests.randomSortBuilderList(); + standardRetrieverBuilder.sortBuilders = SortBuilderTests.randomSortBuilderList(false); } if (randomBoolean()) { @@ -109,4 +121,52 @@ protected String[] getShuffleFieldsExceptions() { protected NamedXContentRegistry xContentRegistry() { return new NamedXContentRegistry(new SearchModule(Settings.EMPTY, List.of()).getNamedXContents()); } + + public void testRewrite() throws IOException { + for (int i = 0; i < 10; i++) { + StandardRetrieverBuilder standardRetriever = createTestInstance(); + SearchSourceBuilder source = new SearchSourceBuilder().retriever(standardRetriever); + QueryRewriteContext queryRewriteContext = mock(QueryRewriteContext.class); + source = Rewriteable.rewrite(source, queryRewriteContext); + assertNull(source.retriever()); + assertTrue(source.knnSearch().isEmpty()); + if (standardRetriever.queryBuilder != null) { + assertNotNull(source.query()); + if (standardRetriever.preFilterQueryBuilders.size() > 0) { + if (source.query() instanceof MatchAllQueryBuilder == false + && source.query() instanceof MatchNoneQueryBuilder == false) { + assertThat(source.query(), instanceOf(BoolQueryBuilder.class)); + BoolQueryBuilder bq = (BoolQueryBuilder) source.query(); + assertFalse(bq.must().isEmpty()); + assertThat(bq.must().size(), equalTo(1)); + assertThat(bq.must().get(0), equalTo(standardRetriever.queryBuilder)); + for (int j = 0; j < bq.filter().size(); j++) { + assertEqualQueryOrMatchAllNone(bq.filter().get(j), standardRetriever.preFilterQueryBuilders.get(j)); + } + } + } else { + assertEqualQueryOrMatchAllNone(source.query(), standardRetriever.queryBuilder); + } + } else if (standardRetriever.preFilterQueryBuilders.size() > 0) { + if (source.query() instanceof MatchAllQueryBuilder == false && source.query() instanceof MatchNoneQueryBuilder == false) { + assertNotNull(source.query()); + assertThat(source.query(), instanceOf(BoolQueryBuilder.class)); + BoolQueryBuilder bq = (BoolQueryBuilder) source.query(); + assertTrue(bq.must().isEmpty()); + for (int j = 0; j < bq.filter().size(); j++) { + assertEqualQueryOrMatchAllNone(bq.filter().get(j), standardRetriever.preFilterQueryBuilders.get(j)); + } + } + } else { + assertNull(source.query()); + } + if (standardRetriever.sortBuilders != null) { + assertThat(source.sorts().size(), equalTo(standardRetriever.sortBuilders.size())); + } + } + } + + private static void assertEqualQueryOrMatchAllNone(QueryBuilder actual, QueryBuilder expected) { + assertThat(actual, anyOf(instanceOf(MatchAllQueryBuilder.class), instanceOf(MatchNoneQueryBuilder.class), equalTo(expected))); + } } diff --git a/server/src/test/java/org/elasticsearch/search/sort/SortBuilderTests.java b/server/src/test/java/org/elasticsearch/search/sort/SortBuilderTests.java index eee98297c7a13..84f87b3f01881 100644 --- a/server/src/test/java/org/elasticsearch/search/sort/SortBuilderTests.java +++ b/server/src/test/java/org/elasticsearch/search/sort/SortBuilderTests.java @@ -119,7 +119,7 @@ public void testSingleFieldSort() throws IOException { public void testRandomSortBuilders() throws IOException { for (int runs = 0; runs < NUMBER_OF_RUNS; runs++) { Set expectedWarningHeaders = new HashSet<>(); - List> testBuilders = randomSortBuilderList(); + List> testBuilders = randomSortBuilderList(randomBoolean()); XContentBuilder xContentBuilder = XContentFactory.jsonBuilder(); xContentBuilder.startObject(); if (testBuilders.size() > 1) { @@ -171,7 +171,7 @@ public void testRandomSortBuilders() throws IOException { } } - public static List> randomSortBuilderList() { + public static List> randomSortBuilderList(boolean hasPIT) { int size = randomIntBetween(1, 5); List> list = new ArrayList<>(size); for (int i = 0; i < size; i++) { @@ -181,7 +181,7 @@ public static List> randomSortBuilderList() { case 2 -> SortBuilders.fieldSort(FieldSortBuilder.DOC_FIELD_NAME); case 3 -> GeoDistanceSortBuilderTests.randomGeoDistanceSortBuilder(); case 4 -> ScriptSortBuilderTests.randomScriptSortBuilder(); - case 5 -> SortBuilders.pitTiebreaker(); + case 5 -> hasPIT ? SortBuilders.pitTiebreaker() : ScriptSortBuilderTests.randomScriptSortBuilder(); default -> throw new IllegalStateException("unexpected randomization in tests"); }); } diff --git a/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java b/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java index f5d9f35e34695..f0899384dbc5e 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/AbstractKnnVectorQueryBuilderTestCase.java @@ -10,6 +10,7 @@ import org.apache.lucene.search.BooleanClause; import org.apache.lucene.search.BooleanQuery; +import org.apache.lucene.search.MatchNoDocsQuery; import org.apache.lucene.search.Query; import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; @@ -25,6 +26,7 @@ import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.index.query.QueryRewriteContext; +import org.elasticsearch.index.query.QueryShardException; import org.elasticsearch.index.query.Rewriteable; import org.elasticsearch.index.query.SearchExecutionContext; import org.elasticsearch.index.query.TermQueryBuilder; @@ -157,8 +159,16 @@ public void testWrongDimension() { public void testNonexistentField() { SearchExecutionContext context = createSearchExecutionContext(); KnnVectorQueryBuilder query = new KnnVectorQueryBuilder("nonexistent", new float[] { 1.0f, 1.0f, 1.0f }, 5, 10, null); - IllegalArgumentException e = expectThrows(IllegalArgumentException.class, () -> query.doToQuery(context)); - assertThat(e.getMessage(), containsString("field [nonexistent] does not exist in the mapping")); + context.setAllowUnmappedFields(false); + QueryShardException e = expectThrows(QueryShardException.class, () -> query.doToQuery(context)); + assertThat(e.getMessage(), containsString("No field mapping can be found for the field with name [nonexistent]")); + } + + public void testNonexistentFieldReturnEmpty() throws IOException { + SearchExecutionContext context = createSearchExecutionContext(); + KnnVectorQueryBuilder query = new KnnVectorQueryBuilder("nonexistent", new float[] { 1.0f, 1.0f, 1.0f }, 5, 10, null); + Query queryNone = query.doToQuery(context); + assertThat(queryNone, instanceOf(MatchNoDocsQuery.class)); } public void testWrongFieldType() { diff --git a/server/src/test/java/org/elasticsearch/search/vectors/ExactKnnQueryBuilderTests.java b/server/src/test/java/org/elasticsearch/search/vectors/ExactKnnQueryBuilderTests.java index 1e77e35b60a4c..5f4fb61718a7e 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/ExactKnnQueryBuilderTests.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/ExactKnnQueryBuilderTests.java @@ -14,17 +14,14 @@ import org.elasticsearch.common.compress.CompressedXContent; import org.elasticsearch.index.IndexVersions; import org.elasticsearch.index.mapper.MapperService; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.index.query.SearchExecutionContext; -import org.elasticsearch.plugins.Plugin; import org.elasticsearch.test.AbstractQueryTestCase; -import org.elasticsearch.test.TestGeoShapeFieldMapperPlugin; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentFactory; import java.io.IOException; import java.util.Arrays; -import java.util.Collection; -import java.util.List; public class ExactKnnQueryBuilderTests extends AbstractQueryTestCase { @@ -50,11 +47,6 @@ protected void initializeAdditionalMappings(MapperService mapperService) throws ); } - @Override - protected Collection> getPlugins() { - return List.of(TestGeoShapeFieldMapperPlugin.class); - } - @Override protected ExactKnnQueryBuilder doCreateTestQueryBuilder() { float[] query = new float[VECTOR_DIMENSION]; @@ -87,7 +79,9 @@ protected void doAssertLuceneQuery(ExactKnnQueryBuilder queryBuilder, Query quer DenseVectorQuery.Floats denseVectorQuery = (DenseVectorQuery.Floats) query; assertEquals(VECTOR_FIELD, denseVectorQuery.field); float[] expected = Arrays.copyOf(queryBuilder.getQuery().asFloatVector(), queryBuilder.getQuery().asFloatVector().length); - if (context.getIndexSettings().getIndexVersionCreated().onOrAfter(IndexVersions.NORMALIZED_VECTOR_COSINE)) { + float magnitude = VectorUtil.dotProduct(expected, expected); + if (context.getIndexSettings().getIndexVersionCreated().onOrAfter(IndexVersions.NORMALIZED_VECTOR_COSINE) + && DenseVectorFieldMapper.isNotUnitVector(magnitude)) { VectorUtil.l2normalize(expected); assertArrayEquals(expected, denseVectorQuery.getQuery(), 0.0f); } else { diff --git a/server/src/test/java/org/elasticsearch/search/vectors/KnnScoreDocQueryBuilderTests.java b/server/src/test/java/org/elasticsearch/search/vectors/KnnScoreDocQueryBuilderTests.java index 67bc6bde9c1af..d2a5859ae981f 100644 --- a/server/src/test/java/org/elasticsearch/search/vectors/KnnScoreDocQueryBuilderTests.java +++ b/server/src/test/java/org/elasticsearch/search/vectors/KnnScoreDocQueryBuilderTests.java @@ -28,14 +28,11 @@ import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryRewriteContext; import org.elasticsearch.index.query.SearchExecutionContext; -import org.elasticsearch.plugins.Plugin; import org.elasticsearch.test.AbstractQueryTestCase; -import org.elasticsearch.test.TestGeoShapeFieldMapperPlugin; import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; -import java.util.Collection; import java.util.Comparator; import java.util.List; @@ -49,11 +46,6 @@ public class KnnScoreDocQueryBuilderTests extends AbstractQueryTestCase { - @Override - protected Collection> getPlugins() { - return List.of(TestGeoShapeFieldMapperPlugin.class); - } - @Override protected KnnScoreDocQueryBuilder doCreateTestQueryBuilder() { List scoreDocs = new ArrayList<>(); diff --git a/server/src/test/java/org/elasticsearch/transport/TransportServiceHandshakeTests.java b/server/src/test/java/org/elasticsearch/transport/TransportServiceHandshakeTests.java index 761d369d6fc39..c5034f51d1e26 100644 --- a/server/src/test/java/org/elasticsearch/transport/TransportServiceHandshakeTests.java +++ b/server/src/test/java/org/elasticsearch/transport/TransportServiceHandshakeTests.java @@ -46,6 +46,7 @@ import static java.util.Collections.emptySet; import static org.elasticsearch.transport.AbstractSimpleTransportTestCase.IGNORE_DESERIALIZATION_ERRORS_SETTING; +import static org.hamcrest.Matchers.allOf; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.instanceOf; @@ -306,7 +307,18 @@ public void testNodeConnectWithDifferentNodeId() { ConnectTransportException.class, () -> AbstractSimpleTransportTestCase.connectToNode(transportServiceA, discoveryNode, TestProfiles.LIGHT_PROFILE) ); - assertThat(ex.getMessage(), containsString("unexpected remote node")); + assertThat( + ex.getMessage(), + allOf( + containsString("Connecting to [" + discoveryNode.getAddress() + "] failed"), + containsString("expected to connect to [" + discoveryNode.descriptionWithoutAttributes() + "]"), + containsString("found [" + transportServiceB.getLocalNode().descriptionWithoutAttributes() + "] instead"), + containsString("Ensure that each node has its own distinct publish address"), + containsString("routed to the correct node"), + containsString("https://www.elastic.co/guide/en/elasticsearch/reference/"), + containsString("modules-network.html") + ) + ); assertFalse(transportServiceA.nodeConnected(discoveryNode)); } diff --git a/test/framework/src/main/java/org/elasticsearch/index/engine/EngineTestCase.java b/test/framework/src/main/java/org/elasticsearch/index/engine/EngineTestCase.java index 1c7cabb541581..70738c510f62a 100644 --- a/test/framework/src/main/java/org/elasticsearch/index/engine/EngineTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/index/engine/EngineTestCase.java @@ -125,7 +125,6 @@ import java.util.List; import java.util.Map; import java.util.Set; -import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutionException; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; @@ -1179,33 +1178,24 @@ public static void assertOpsOnReplica( } public static void concurrentlyApplyOps(List ops, InternalEngine engine) throws InterruptedException { - Thread[] thread = new Thread[randomIntBetween(3, 5)]; - CountDownLatch startGun = new CountDownLatch(thread.length); + final int threadCount = randomIntBetween(3, 5); AtomicInteger offset = new AtomicInteger(-1); - for (int i = 0; i < thread.length; i++) { - thread[i] = new Thread(() -> { - startGun.countDown(); - safeAwait(startGun); - int docOffset; - while ((docOffset = offset.incrementAndGet()) < ops.size()) { - try { - applyOperation(engine, ops.get(docOffset)); - if ((docOffset + 1) % 4 == 0) { - engine.refresh("test"); - } - if (rarely()) { - engine.flush(); - } - } catch (IOException e) { - throw new AssertionError(e); + startInParallel(threadCount, i -> { + int docOffset; + while ((docOffset = offset.incrementAndGet()) < ops.size()) { + try { + applyOperation(engine, ops.get(docOffset)); + if ((docOffset + 1) % 4 == 0) { + engine.refresh("test"); + } + if (rarely()) { + engine.flush(); } + } catch (IOException e) { + throw new AssertionError(e); } - }); - thread[i].start(); - } - for (int i = 0; i < thread.length; i++) { - thread[i].join(); - } + } + }); } public static void applyOperations(Engine engine, List operations) throws IOException { diff --git a/test/framework/src/main/java/org/elasticsearch/search/aggregations/AggregatorTestCase.java b/test/framework/src/main/java/org/elasticsearch/search/aggregations/AggregatorTestCase.java index d39a8df80c26d..3bed333b135fb 100644 --- a/test/framework/src/main/java/org/elasticsearch/search/aggregations/AggregatorTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/search/aggregations/AggregatorTestCase.java @@ -89,7 +89,6 @@ import org.elasticsearch.index.mapper.FieldAliasMapper; import org.elasticsearch.index.mapper.FieldMapper; import org.elasticsearch.index.mapper.GeoPointFieldMapper; -import org.elasticsearch.index.mapper.GeoShapeFieldMapper; import org.elasticsearch.index.mapper.IdLoader; import org.elasticsearch.index.mapper.KeywordFieldMapper; import org.elasticsearch.index.mapper.MappedFieldType; @@ -198,7 +197,6 @@ public abstract class AggregatorTestCase extends ESTestCase { // A list of field types that should not be tested, or are not currently supported private static final List TYPE_TEST_BLACKLIST = List.of( ObjectMapper.CONTENT_TYPE, // Cannot aggregate objects - GeoShapeFieldMapper.CONTENT_TYPE, // Cannot aggregate geoshapes (yet) DenseVectorFieldMapper.CONTENT_TYPE, // Cannot aggregate dense vectors SparseVectorFieldMapper.CONTENT_TYPE, // Sparse vectors are no longer supported @@ -644,7 +642,8 @@ private A searchAndReduce( bigArraysForReduction, getMockScriptService(), () -> false, - builder + builder, + b -> {} ); AggregatorCollectorManager aggregatorCollectorManager = new AggregatorCollectorManager( aggregatorSupplier, @@ -669,7 +668,8 @@ private A searchAndReduce( bigArraysForReduction, getMockScriptService(), () -> false, - builder + builder, + b -> {} ); internalAggs = new ArrayList<>(internalAggs.subList(r, toReduceSize)); internalAggs.add(InternalAggregations.topLevelReduce(toReduce, reduceContext)); diff --git a/test/framework/src/main/java/org/elasticsearch/search/geo/BasePointShapeQueryTestCase.java b/test/framework/src/main/java/org/elasticsearch/search/geo/BasePointShapeQueryTestCase.java index 52d2f3f53a43e..6c84a9ba601cf 100644 --- a/test/framework/src/main/java/org/elasticsearch/search/geo/BasePointShapeQueryTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/search/geo/BasePointShapeQueryTestCase.java @@ -32,16 +32,12 @@ import org.elasticsearch.geometry.ShapeType; import org.elasticsearch.geometry.utils.WellKnownText; import org.elasticsearch.index.query.AbstractGeometryQueryBuilder; -import org.elasticsearch.plugins.Plugin; import org.elasticsearch.search.SearchHits; import org.elasticsearch.test.ESSingleNodeTestCase; -import org.elasticsearch.test.TestGeoShapeFieldMapperPlugin; import org.elasticsearch.xcontent.XContentFactory; import org.elasticsearch.xcontent.XContentType; import org.hamcrest.CoreMatchers; -import java.util.Collection; -import java.util.Collections; import java.util.List; import java.util.Map; @@ -67,11 +63,6 @@ public abstract class BasePointShapeQueryTestCase> getPlugins() { - return Collections.singleton(TestGeoShapeFieldMapperPlugin.class); - } - protected abstract void createMapping(String indexName, String fieldName, Settings settings) throws Exception; protected void createMapping(String indexName, String fieldName) throws Exception { diff --git a/test/framework/src/main/java/org/elasticsearch/test/ESSingleNodeTestCase.java b/test/framework/src/main/java/org/elasticsearch/test/ESSingleNodeTestCase.java index 8526acc851c72..7fdc5765a90e8 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/ESSingleNodeTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/test/ESSingleNodeTestCase.java @@ -69,6 +69,7 @@ import java.util.stream.Collectors; import java.util.stream.Stream; +import static org.elasticsearch.action.search.SearchTransportService.FREE_CONTEXT_ACTION_NAME; import static org.elasticsearch.cluster.coordination.ClusterBootstrapService.INITIAL_MASTER_NODES_SETTING; import static org.elasticsearch.discovery.SettingsBasedSeedHostsProvider.DISCOVERY_SEED_HOSTS_SETTING; import static org.elasticsearch.test.NodeRoles.dataNode; @@ -130,6 +131,8 @@ public void tearDown() throws Exception { logger.trace("[{}#{}]: cleaning up after test", getTestClass().getSimpleName(), getTestName()); awaitIndexShardCloseAsyncTasks(); ensureNoInitializingShards(); + ensureAllFreeContextActionsAreConsumed(); + SearchService searchService = getInstanceFromNode(SearchService.class); assertThat(searchService.getActiveContexts(), equalTo(0)); assertThat(searchService.getOpenScrollContexts(), equalTo(0)); @@ -455,6 +458,14 @@ protected void ensureNoInitializingShards() { assertFalse("timed out waiting for shards to initialize", actionGet.isTimedOut()); } + /** + * waits until all free_context actions have been handled by the generic thread pool + */ + protected void ensureAllFreeContextActionsAreConsumed() throws Exception { + logger.info("--> waiting for all free_context tasks to complete within a reasonable time"); + safeGet(clusterAdmin().prepareListTasks().setActions(FREE_CONTEXT_ACTION_NAME + "*").setWaitForCompletion(true).execute()); + } + /** * Whether we'd like to enable inter-segment search concurrency and increase the likelihood of leveraging it, by creating multiple * slices with a low amount of documents in them, which would not be allowed in production. diff --git a/test/framework/src/main/java/org/elasticsearch/test/ESTestCase.java b/test/framework/src/main/java/org/elasticsearch/test/ESTestCase.java index add0de1993233..7295dce7a257a 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/ESTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/test/ESTestCase.java @@ -38,6 +38,7 @@ import org.apache.lucene.tests.util.TestUtil; import org.apache.lucene.tests.util.TimeUnits; import org.apache.lucene.util.SetOnce; +import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.TransportVersion; import org.elasticsearch.action.ActionFuture; import org.elasticsearch.action.RequestBuilder; @@ -179,12 +180,14 @@ import java.util.concurrent.Executor; import java.util.concurrent.ExecutorService; import java.util.concurrent.Future; +import java.util.concurrent.FutureTask; import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicInteger; import java.util.function.BooleanSupplier; import java.util.function.Consumer; +import java.util.function.IntConsumer; import java.util.function.IntFunction; import java.util.function.Predicate; import java.util.function.Supplier; @@ -999,6 +1002,13 @@ public static int randomNonNegativeInt() { return randomInt() & Integer.MAX_VALUE; } + /** + * @return an int between Integer.MIN_VALUE and -1 (inclusive) chosen uniformly at random. + */ + public static int randomNegativeInt() { + return randomInt() | Integer.MIN_VALUE; + } + public static float randomFloat() { return random().nextFloat(); } @@ -2423,4 +2433,53 @@ public static T expectThrows(Class expectedType, Reques () -> builder.get().decRef() // dec ref if we unexpectedly fail to not leak transport response ); } + + /** + * Same as {@link #runInParallel(int, IntConsumer)} but also attempts to start all tasks at the same time by blocking execution on a + * barrier until all threads are started and ready to execute their task. + */ + public static void startInParallel(int numberOfTasks, IntConsumer taskFactory) throws InterruptedException { + final CyclicBarrier barrier = new CyclicBarrier(numberOfTasks); + runInParallel(numberOfTasks, i -> { + safeAwait(barrier); + taskFactory.accept(i); + }); + } + + /** + * Run {@code numberOfTasks} parallel tasks that were created by the given {@code taskFactory}. On of the tasks will be run on the + * calling thread, the rest will be run on a new thread. + * @param numberOfTasks number of tasks to run in parallel + * @param taskFactory task factory + */ + public static void runInParallel(int numberOfTasks, IntConsumer taskFactory) throws InterruptedException { + final ArrayList> futures = new ArrayList<>(numberOfTasks); + final Thread[] threads = new Thread[numberOfTasks - 1]; + for (int i = 0; i < numberOfTasks; i++) { + final int index = i; + var future = new FutureTask(() -> taskFactory.accept(index), null); + futures.add(future); + if (i == numberOfTasks - 1) { + future.run(); + } else { + threads[i] = new Thread(future); + threads[i].setName("runInParallel-T#" + i); + threads[i].start(); + } + } + for (Thread thread : threads) { + thread.join(); + } + Exception e = null; + for (Future future : futures) { + try { + future.get(); + } catch (Exception ex) { + e = ExceptionsHelper.useOrSuppress(e, ex); + } + } + if (e != null) { + throw new AssertionError(e); + } + } } diff --git a/test/framework/src/main/java/org/elasticsearch/test/InternalAggregationTestCase.java b/test/framework/src/main/java/org/elasticsearch/test/InternalAggregationTestCase.java index 12c5085cbcd73..4aed7ff4565cb 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/InternalAggregationTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/test/InternalAggregationTestCase.java @@ -77,7 +77,7 @@ public static AggregationReduceContext.Builder emptyReduceContextBuilder(Aggrega return new AggregationReduceContext.Builder() { @Override public AggregationReduceContext forPartialReduction() { - return new AggregationReduceContext.ForPartial(BigArrays.NON_RECYCLING_INSTANCE, null, () -> false, aggs); + return new AggregationReduceContext.ForPartial(BigArrays.NON_RECYCLING_INSTANCE, null, () -> false, aggs, b -> {}); } @Override @@ -95,7 +95,7 @@ public static AggregationReduceContext.Builder mockReduceContext(AggregationBuil return new AggregationReduceContext.Builder() { @Override public AggregationReduceContext forPartialReduction() { - return new AggregationReduceContext.ForPartial(BigArrays.NON_RECYCLING_INSTANCE, null, () -> false, agg); + return new AggregationReduceContext.ForPartial(BigArrays.NON_RECYCLING_INSTANCE, null, () -> false, agg, b -> {}); } @Override @@ -244,7 +244,8 @@ public void testReduceRandom() throws IOException { bigArrays, mockScriptService, () -> false, - inputs.builder() + inputs.builder(), + b -> {} ); @SuppressWarnings("unchecked") T reduced = (T) reduce(toPartialReduce, context); diff --git a/test/framework/src/main/java/org/elasticsearch/test/InternalTestCluster.java b/test/framework/src/main/java/org/elasticsearch/test/InternalTestCluster.java index bb78c43fca449..af37fb6feefbd 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/InternalTestCluster.java +++ b/test/framework/src/main/java/org/elasticsearch/test/InternalTestCluster.java @@ -61,8 +61,6 @@ import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.util.PageCacheRecycler; import org.elasticsearch.common.util.concurrent.EsExecutors; -import org.elasticsearch.common.util.concurrent.FutureUtils; -import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.core.IOUtils; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.Predicates; @@ -126,8 +124,6 @@ import java.util.TreeMap; import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutionException; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Future; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; @@ -148,6 +144,7 @@ import static org.elasticsearch.node.Node.INITIAL_STATE_TIMEOUT_SETTING; import static org.elasticsearch.test.ESTestCase.assertBusy; import static org.elasticsearch.test.ESTestCase.randomFrom; +import static org.elasticsearch.test.ESTestCase.runInParallel; import static org.elasticsearch.test.ESTestCase.safeAwait; import static org.elasticsearch.test.NodeRoles.dataOnlyNode; import static org.elasticsearch.test.NodeRoles.masterOnlyNode; @@ -246,8 +243,6 @@ public String toString() { private final NodeConfigurationSource nodeConfigurationSource; - private final ExecutorService executor; - private final boolean autoManageMasterNodes; private final Collection> mockPlugins; @@ -452,16 +447,6 @@ public InternalTestCluster( builder.put(NoMasterBlockService.NO_MASTER_BLOCK_SETTING.getKey(), randomFrom(random, "write", "metadata_write")); builder.put(DestructiveOperations.REQUIRES_NAME_SETTING.getKey(), false); defaultSettings = builder.build(); - executor = EsExecutors.newScaling( - "internal_test_cluster_executor", - 0, - Integer.MAX_VALUE, - 0, - TimeUnit.SECONDS, - true, - EsExecutors.daemonThreadFactory("test_" + clusterName), - new ThreadContext(Settings.EMPTY) - ); } /** @@ -931,7 +916,6 @@ public synchronized void close() throws IOException { } finally { nodes = Collections.emptyNavigableMap(); Loggers.setLevel(nodeConnectionLogger, initialLogLevel); - executor.shutdownNow(); } } } @@ -1760,18 +1744,10 @@ private synchronized void startAndPublishNodesAndClients(List nod .filter(nac -> nodes.containsKey(nac.name) == false) // filter out old masters .count(); rebuildUnicastHostFiles(nodeAndClients); // ensure that new nodes can find the existing nodes when they start - List> futures = nodeAndClients.stream().map(node -> executor.submit(node::startNode)).collect(Collectors.toList()); - try { - for (Future future : futures) { - future.get(); - } + runInParallel(nodeAndClients.size(), i -> nodeAndClients.get(i).startNode()); } catch (InterruptedException e) { throw new AssertionError("interrupted while starting nodes", e); - } catch (ExecutionException e) { - RuntimeException re = FutureUtils.rethrowExecutionException(e); - re.addSuppressed(new RuntimeException("failed to start nodes")); - throw re; } nodeAndClients.forEach(this::publishNode); diff --git a/test/framework/src/main/java/org/elasticsearch/test/TestGeoShapeFieldMapperPlugin.java b/test/framework/src/main/java/org/elasticsearch/test/TestGeoShapeFieldMapperPlugin.java deleted file mode 100644 index cd373432992d2..0000000000000 --- a/test/framework/src/main/java/org/elasticsearch/test/TestGeoShapeFieldMapperPlugin.java +++ /dev/null @@ -1,34 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0 and the Server Side Public License, v 1; you may not use this file except - * in compliance with, at your election, the Elastic License 2.0 or the Server - * Side Public License, v 1. - */ -package org.elasticsearch.test; - -import org.elasticsearch.index.mapper.GeoShapeFieldMapper; -import org.elasticsearch.index.mapper.Mapper; -import org.elasticsearch.plugins.MapperPlugin; -import org.elasticsearch.plugins.Plugin; - -import java.util.Collections; -import java.util.LinkedHashMap; -import java.util.Map; - -/** - * Some tests depend on the {@link org.elasticsearch.index.mapper.GeoShapeFieldMapper}. - * This mapper is registered in the spatial-extras module, but used in many integration - * tests in server code. The goal is to migrate all of the spatial/geo pieces to the spatial-extras - * module such that no tests in server depend on this test plugin - */ -@Deprecated -public class TestGeoShapeFieldMapperPlugin extends Plugin implements MapperPlugin { - - @Override - public Map getMappers() { - Map mappers = new LinkedHashMap<>(); - mappers.put(GeoShapeFieldMapper.CONTENT_TYPE, GeoShapeFieldMapper.PARSER); - return Collections.unmodifiableMap(mappers); - } -} diff --git a/test/framework/src/main/java/org/elasticsearch/test/rest/ESRestTestFeatureService.java b/test/framework/src/main/java/org/elasticsearch/test/rest/ESRestTestFeatureService.java index 78a4126ec09db..92d72afbf9d52 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/rest/ESRestTestFeatureService.java +++ b/test/framework/src/main/java/org/elasticsearch/test/rest/ESRestTestFeatureService.java @@ -86,19 +86,6 @@ public boolean clusterHasFeature(String featureId) { Matcher matcher = VERSION_FEATURE_PATTERN.matcher(featureId); if (matcher.matches()) { Version extractedVersion = Version.fromString(matcher.group(1)); - if (Version.V_8_15_0.before(extractedVersion)) { - // As of version 8.14.0 REST tests have been migrated to use features only. - // For migration purposes we provide a synthetic version feature gte_vX.Y.Z for any version at or before 8.15.0 - // allowing for some transition period. - throw new IllegalArgumentException( - Strings.format( - "Synthetic version features are only available before [%s] for migration purposes! " - + "Please add a cluster feature to an appropriate FeatureSpecification; test-only historical-features " - + "can be supplied via ESRestTestCase#additionalTestOnlyHistoricalFeatures()", - Version.V_8_15_0 - ) - ); - } return version.onOrAfter(extractedVersion); } diff --git a/test/framework/src/test/java/org/elasticsearch/test/test/ESTestCaseTests.java b/test/framework/src/test/java/org/elasticsearch/test/test/ESTestCaseTests.java index 125c0563577fc..714c9bcde0469 100644 --- a/test/framework/src/test/java/org/elasticsearch/test/test/ESTestCaseTests.java +++ b/test/framework/src/test/java/org/elasticsearch/test/test/ESTestCaseTests.java @@ -45,6 +45,7 @@ import static org.hamcrest.Matchers.greaterThanOrEqualTo; import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.lessThan; import static org.hamcrest.Matchers.lessThanOrEqualTo; import static org.hamcrest.Matchers.not; import static org.hamcrest.Matchers.notNullValue; @@ -185,6 +186,10 @@ public void testRandomNonNegativeInt() { assertThat(randomNonNegativeInt(), greaterThanOrEqualTo(0)); } + public void testRandomNegativeInt() { + assertThat(randomNegativeInt(), lessThan(0)); + } + public void testRandomValueOtherThan() { // "normal" way of calling where the value is not null int bad = randomInt(); diff --git a/test/test-clusters/src/main/java/org/elasticsearch/test/cluster/FeatureFlag.java b/test/test-clusters/src/main/java/org/elasticsearch/test/cluster/FeatureFlag.java index 49fb38b518dce..a8a33da27aebe 100644 --- a/test/test-clusters/src/main/java/org/elasticsearch/test/cluster/FeatureFlag.java +++ b/test/test-clusters/src/main/java/org/elasticsearch/test/cluster/FeatureFlag.java @@ -16,7 +16,12 @@ */ public enum FeatureFlag { TIME_SERIES_MODE("es.index_mode_feature_flag_registered=true", Version.fromString("8.0.0"), null), - FAILURE_STORE_ENABLED("es.failure_store_feature_flag_enabled=true", Version.fromString("8.12.0"), null); + FAILURE_STORE_ENABLED("es.failure_store_feature_flag_enabled=true", Version.fromString("8.12.0"), null), + INFERENCE_ADAPTIVE_ALLOCATIONS_ENABLED( + "es.inference_adaptive_allocations_feature_flag_enabled=true", + Version.fromString("8.16.0"), + null + ); public final String systemProperty; public final Version from; diff --git a/x-pack/plugin/analytics/src/test/java/org/elasticsearch/xpack/analytics/aggregations/metrics/HistogramPercentileAggregationTests.java b/x-pack/plugin/analytics/src/test/java/org/elasticsearch/xpack/analytics/aggregations/metrics/HistogramPercentileAggregationTests.java index f60466bcf43cc..7c6f85104b5f8 100644 --- a/x-pack/plugin/analytics/src/test/java/org/elasticsearch/xpack/analytics/aggregations/metrics/HistogramPercentileAggregationTests.java +++ b/x-pack/plugin/analytics/src/test/java/org/elasticsearch/xpack/analytics/aggregations/metrics/HistogramPercentileAggregationTests.java @@ -241,7 +241,6 @@ public void testTDigestHistogram() throws Exception { ); } - @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/110406") public void testBoxplotHistogram() throws Exception { int compression = TestUtil.nextInt(random(), 200, 300); setupTDigestHistogram(compression); diff --git a/x-pack/plugin/core/build.gradle b/x-pack/plugin/core/build.gradle index 0c65c7e4b6d29..1ed59d6fe3581 100644 --- a/x-pack/plugin/core/build.gradle +++ b/x-pack/plugin/core/build.gradle @@ -51,7 +51,6 @@ dependencies { // security deps api 'com.unboundid:unboundid-ldapsdk:6.0.3' - api "com.nimbusds:nimbus-jose-jwt:9.23" implementation project(":x-pack:plugin:core:template-resources") @@ -135,27 +134,7 @@ tasks.named("thirdPartyAudit").configure { //commons-logging provided dependencies 'javax.servlet.ServletContextEvent', 'javax.servlet.ServletContextListener', - 'javax.jms.Message', - // Optional dependency of nimbus-jose-jwt for handling Ed25519 signatures and ECDH with X25519 (RFC 8037) - 'com.google.crypto.tink.subtle.Ed25519Sign', - 'com.google.crypto.tink.subtle.Ed25519Sign$KeyPair', - 'com.google.crypto.tink.subtle.Ed25519Verify', - 'com.google.crypto.tink.subtle.X25519', - 'com.google.crypto.tink.subtle.XChaCha20Poly1305', - // optional dependencies for nimbus-jose-jwt - 'org.bouncycastle.asn1.pkcs.PrivateKeyInfo', - 'org.bouncycastle.asn1.x509.AlgorithmIdentifier', - 'org.bouncycastle.asn1.x509.SubjectPublicKeyInfo', - 'org.bouncycastle.cert.X509CertificateHolder', - 'org.bouncycastle.cert.jcajce.JcaX509CertificateHolder', - 'org.bouncycastle.crypto.InvalidCipherTextException', - 'org.bouncycastle.crypto.engines.AESEngine', - 'org.bouncycastle.crypto.modes.GCMBlockCipher', - 'org.bouncycastle.jcajce.provider.BouncyCastleFipsProvider', - 'org.bouncycastle.jce.provider.BouncyCastleProvider', - 'org.bouncycastle.openssl.PEMKeyPair', - 'org.bouncycastle.openssl.PEMParser', - 'org.bouncycastle.openssl.jcajce.JcaPEMKeyConverter' + 'javax.jms.Message' ) } diff --git a/x-pack/plugin/core/src/main/java/module-info.java b/x-pack/plugin/core/src/main/java/module-info.java index 282072417875b..72436bb9d5171 100644 --- a/x-pack/plugin/core/src/main/java/module-info.java +++ b/x-pack/plugin/core/src/main/java/module-info.java @@ -22,7 +22,6 @@ requires unboundid.ldapsdk; requires org.elasticsearch.tdigest; requires org.elasticsearch.xcore.templates; - requires com.nimbusds.jose.jwt; exports org.elasticsearch.index.engine.frozen; exports org.elasticsearch.license; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/DeleteInferenceEndpointAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/DeleteInferenceEndpointAction.java index dfb77ccd49fc2..e9d612751e48f 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/DeleteInferenceEndpointAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/DeleteInferenceEndpointAction.java @@ -11,8 +11,10 @@ import org.elasticsearch.action.ActionType; import org.elasticsearch.action.support.master.AcknowledgedRequest; import org.elasticsearch.action.support.master.AcknowledgedResponse; +import org.elasticsearch.common.Strings; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; import org.elasticsearch.inference.TaskType; import org.elasticsearch.xcontent.XContentBuilder; @@ -105,10 +107,16 @@ public static class Response extends AcknowledgedResponse { private final String PIPELINE_IDS = "pipelines"; Set pipelineIds; + private final String REFERENCED_INDEXES = "indexes"; + Set indexes; + private final String DRY_RUN_MESSAGE = "error_message"; // error message only returned in response for dry_run + String dryRunMessage; - public Response(boolean acknowledged, Set pipelineIds) { + public Response(boolean acknowledged, Set pipelineIds, Set semanticTextIndexes, @Nullable String dryRunMessage) { super(acknowledged); this.pipelineIds = pipelineIds; + this.indexes = semanticTextIndexes; + this.dryRunMessage = dryRunMessage; } public Response(StreamInput in) throws IOException { @@ -118,6 +126,15 @@ public Response(StreamInput in) throws IOException { } else { pipelineIds = Set.of(); } + + if (in.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_DONT_DELETE_WHEN_SEMANTIC_TEXT_EXISTS)) { + indexes = in.readCollectionAsSet(StreamInput::readString); + dryRunMessage = in.readOptionalString(); + } else { + indexes = Set.of(); + dryRunMessage = null; + } + } @Override @@ -126,23 +143,25 @@ public void writeTo(StreamOutput out) throws IOException { if (out.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_ENHANCE_DELETE_ENDPOINT)) { out.writeCollection(pipelineIds, StreamOutput::writeString); } + if (out.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_DONT_DELETE_WHEN_SEMANTIC_TEXT_EXISTS)) { + out.writeCollection(indexes, StreamOutput::writeString); + out.writeOptionalString(dryRunMessage); + } } @Override protected void addCustomFields(XContentBuilder builder, Params params) throws IOException { super.addCustomFields(builder, params); builder.field(PIPELINE_IDS, pipelineIds); + builder.field(REFERENCED_INDEXES, indexes); + if (dryRunMessage != null) { + builder.field(DRY_RUN_MESSAGE, dryRunMessage); + } } @Override public String toString() { - StringBuilder returnable = new StringBuilder(); - returnable.append("acknowledged: ").append(this.acknowledged); - returnable.append(", pipelineIdsByEndpoint: "); - for (String entry : pipelineIds) { - returnable.append(entry).append(", "); - } - return returnable.toString(); + return Strings.toString(this); } } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/CreateTrainedModelAssignmentAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/CreateTrainedModelAssignmentAction.java index 9b383b2652af4..c6976ab4b513e 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/CreateTrainedModelAssignmentAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/CreateTrainedModelAssignmentAction.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.core.ml.action; +import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionRequestValidationException; import org.elasticsearch.action.ActionResponse; import org.elasticsearch.action.ActionType; @@ -18,6 +19,7 @@ import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings; import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; @@ -34,15 +36,22 @@ private CreateTrainedModelAssignmentAction() { public static class Request extends MasterNodeRequest { private final StartTrainedModelDeploymentAction.TaskParams taskParams; + private final AdaptiveAllocationsSettings adaptiveAllocationsSettings; - public Request(StartTrainedModelDeploymentAction.TaskParams taskParams) { + public Request(StartTrainedModelDeploymentAction.TaskParams taskParams, AdaptiveAllocationsSettings adaptiveAllocationsSettings) { super(TRAPPY_IMPLICIT_DEFAULT_MASTER_NODE_TIMEOUT); this.taskParams = ExceptionsHelper.requireNonNull(taskParams, "taskParams"); + this.adaptiveAllocationsSettings = adaptiveAllocationsSettings; } public Request(StreamInput in) throws IOException { super(in); this.taskParams = new StartTrainedModelDeploymentAction.TaskParams(in); + if (in.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_ADAPTIVE_ALLOCATIONS)) { + this.adaptiveAllocationsSettings = in.readOptionalWriteable(AdaptiveAllocationsSettings::new); + } else { + this.adaptiveAllocationsSettings = null; + } } @Override @@ -54,6 +63,9 @@ public ActionRequestValidationException validate() { public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); taskParams.writeTo(out); + if (out.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_ADAPTIVE_ALLOCATIONS)) { + out.writeOptionalWriteable(adaptiveAllocationsSettings); + } } @Override @@ -61,17 +73,22 @@ public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; Request request = (Request) o; - return Objects.equals(taskParams, request.taskParams); + return Objects.equals(taskParams, request.taskParams) + && Objects.equals(adaptiveAllocationsSettings, request.adaptiveAllocationsSettings); } @Override public int hashCode() { - return Objects.hash(taskParams); + return Objects.hash(taskParams, adaptiveAllocationsSettings); } public StartTrainedModelDeploymentAction.TaskParams getTaskParams() { return taskParams; } + + public AdaptiveAllocationsSettings getAdaptiveAllocationsSettings() { + return adaptiveAllocationsSettings; + } } public static class Response extends ActionResponse implements ToXContentObject { diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java index ca9b86a90f875..59eaf4affa9a8 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentAction.java @@ -29,8 +29,11 @@ import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xpack.core.ml.MlConfigVersion; import org.elasticsearch.xpack.core.ml.inference.TrainedModelConfig; +import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsFeatureFlag; +import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings; import org.elasticsearch.xpack.core.ml.inference.assignment.AllocationStatus; import org.elasticsearch.xpack.core.ml.inference.assignment.Priority; +import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment; import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import org.elasticsearch.xpack.core.ml.utils.MlTaskParams; @@ -40,7 +43,6 @@ import java.util.Optional; import java.util.concurrent.TimeUnit; -import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg; import static org.elasticsearch.xpack.core.ml.MlTasks.trainedModelAssignmentTaskDescription; public class StartTrainedModelDeploymentAction extends ActionType { @@ -99,6 +101,7 @@ public static class Request extends MasterNodeRequest implements ToXCon public static final ParseField QUEUE_CAPACITY = TaskParams.QUEUE_CAPACITY; public static final ParseField CACHE_SIZE = TaskParams.CACHE_SIZE; public static final ParseField PRIORITY = TaskParams.PRIORITY; + public static final ParseField ADAPTIVE_ALLOCATIONS = TrainedModelAssignment.ADAPTIVE_ALLOCATIONS; public static final ObjectParser PARSER = new ObjectParser<>(NAME, Request::new); @@ -117,6 +120,14 @@ public static class Request extends MasterNodeRequest implements ToXCon ObjectParser.ValueType.VALUE ); PARSER.declareString(Request::setPriority, PRIORITY); + if (AdaptiveAllocationsFeatureFlag.isEnabled()) { + PARSER.declareObjectOrNull( + Request::setAdaptiveAllocationsSettings, + (p, c) -> AdaptiveAllocationsSettings.PARSER.parse(p, c).build(), + null, + ADAPTIVE_ALLOCATIONS + ); + } } public static Request parseRequest(String modelId, String deploymentId, XContentParser parser) { @@ -140,7 +151,8 @@ public static Request parseRequest(String modelId, String deploymentId, XContent private TimeValue timeout = DEFAULT_TIMEOUT; private AllocationStatus.State waitForState = DEFAULT_WAITFOR_STATE; private ByteSizeValue cacheSize; - private int numberOfAllocations = DEFAULT_NUM_ALLOCATIONS; + private Integer numberOfAllocations; + private AdaptiveAllocationsSettings adaptiveAllocationsSettings = null; private int threadsPerAllocation = DEFAULT_NUM_THREADS; private int queueCapacity = DEFAULT_QUEUE_CAPACITY; private Priority priority = DEFAULT_PRIORITY; @@ -160,7 +172,11 @@ public Request(StreamInput in) throws IOException { modelId = in.readString(); timeout = in.readTimeValue(); waitForState = in.readEnum(AllocationStatus.State.class); - numberOfAllocations = in.readVInt(); + if (in.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_ADAPTIVE_ALLOCATIONS)) { + numberOfAllocations = in.readOptionalVInt(); + } else { + numberOfAllocations = in.readVInt(); + } threadsPerAllocation = in.readVInt(); queueCapacity = in.readVInt(); if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_4_0)) { @@ -171,12 +187,16 @@ public Request(StreamInput in) throws IOException { } else { this.priority = Priority.NORMAL; } - if (in.getTransportVersion().onOrAfter(TransportVersions.V_8_8_0)) { this.deploymentId = in.readString(); } else { this.deploymentId = modelId; } + if (in.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_ADAPTIVE_ALLOCATIONS)) { + this.adaptiveAllocationsSettings = in.readOptionalWriteable(AdaptiveAllocationsSettings::new); + } else { + this.adaptiveAllocationsSettings = null; + } } public final void setModelId(String modelId) { @@ -212,14 +232,34 @@ public Request setWaitForState(AllocationStatus.State waitForState) { return this; } - public int getNumberOfAllocations() { + public Integer getNumberOfAllocations() { return numberOfAllocations; } - public void setNumberOfAllocations(int numberOfAllocations) { + public int computeNumberOfAllocations() { + if (numberOfAllocations != null) { + return numberOfAllocations; + } else { + if (adaptiveAllocationsSettings == null || adaptiveAllocationsSettings.getMinNumberOfAllocations() == null) { + return DEFAULT_NUM_ALLOCATIONS; + } else { + return adaptiveAllocationsSettings.getMinNumberOfAllocations(); + } + } + } + + public void setNumberOfAllocations(Integer numberOfAllocations) { this.numberOfAllocations = numberOfAllocations; } + public AdaptiveAllocationsSettings getAdaptiveAllocationsSettings() { + return adaptiveAllocationsSettings; + } + + public void setAdaptiveAllocationsSettings(AdaptiveAllocationsSettings adaptiveAllocationsSettings) { + this.adaptiveAllocationsSettings = adaptiveAllocationsSettings; + } + public int getThreadsPerAllocation() { return threadsPerAllocation; } @@ -258,7 +298,11 @@ public void writeTo(StreamOutput out) throws IOException { out.writeString(modelId); out.writeTimeValue(timeout); out.writeEnum(waitForState); - out.writeVInt(numberOfAllocations); + if (out.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_ADAPTIVE_ALLOCATIONS)) { + out.writeOptionalVInt(numberOfAllocations); + } else { + out.writeVInt(numberOfAllocations); + } out.writeVInt(threadsPerAllocation); out.writeVInt(queueCapacity); if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_4_0)) { @@ -270,6 +314,9 @@ public void writeTo(StreamOutput out) throws IOException { if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_8_0)) { out.writeString(deploymentId); } + if (out.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_ADAPTIVE_ALLOCATIONS)) { + out.writeOptionalWriteable(adaptiveAllocationsSettings); + } } @Override @@ -279,7 +326,12 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws builder.field(DEPLOYMENT_ID.getPreferredName(), deploymentId); builder.field(TIMEOUT.getPreferredName(), timeout.getStringRep()); builder.field(WAIT_FOR.getPreferredName(), waitForState); - builder.field(NUMBER_OF_ALLOCATIONS.getPreferredName(), numberOfAllocations); + if (numberOfAllocations != null) { + builder.field(NUMBER_OF_ALLOCATIONS.getPreferredName(), numberOfAllocations); + } + if (adaptiveAllocationsSettings != null) { + builder.field(ADAPTIVE_ALLOCATIONS.getPreferredName(), adaptiveAllocationsSettings); + } builder.field(THREADS_PER_ALLOCATION.getPreferredName(), threadsPerAllocation); builder.field(QUEUE_CAPACITY.getPreferredName(), queueCapacity); if (cacheSize != null) { @@ -301,12 +353,25 @@ public ActionRequestValidationException validate() { + Strings.arrayToCommaDelimitedString(VALID_WAIT_STATES) ); } - if (numberOfAllocations < 1) { - validationException.addValidationError("[" + NUMBER_OF_ALLOCATIONS + "] must be a positive integer"); + if (numberOfAllocations != null) { + if (numberOfAllocations < 1) { + validationException.addValidationError("[" + NUMBER_OF_ALLOCATIONS + "] must be a positive integer"); + } + if (adaptiveAllocationsSettings != null && adaptiveAllocationsSettings.getEnabled()) { + validationException.addValidationError( + "[" + NUMBER_OF_ALLOCATIONS + "] cannot be set if adaptive allocations is enabled" + ); + } } if (threadsPerAllocation < 1) { validationException.addValidationError("[" + THREADS_PER_ALLOCATION + "] must be a positive integer"); } + ActionRequestValidationException autoscaleException = adaptiveAllocationsSettings == null + ? null + : adaptiveAllocationsSettings.validate(); + if (autoscaleException != null) { + validationException.addValidationErrors(autoscaleException.validationErrors()); + } if (threadsPerAllocation > MAX_THREADS_PER_ALLOCATION || isPowerOf2(threadsPerAllocation) == false) { validationException.addValidationError( "[" + THREADS_PER_ALLOCATION + "] must be a power of 2 less than or equal to " + MAX_THREADS_PER_ALLOCATION @@ -322,7 +387,7 @@ public ActionRequestValidationException validate() { validationException.addValidationError("[" + TIMEOUT + "] must be positive"); } if (priority == Priority.LOW) { - if (numberOfAllocations > 1) { + if (numberOfAllocations != null && numberOfAllocations > 1) { validationException.addValidationError("[" + NUMBER_OF_ALLOCATIONS + "] must be 1 when [" + PRIORITY + "] is low"); } if (threadsPerAllocation > 1) { @@ -344,6 +409,7 @@ public int hashCode() { timeout, waitForState, numberOfAllocations, + adaptiveAllocationsSettings, threadsPerAllocation, queueCapacity, cacheSize, @@ -365,7 +431,8 @@ public boolean equals(Object obj) { && Objects.equals(timeout, other.timeout) && Objects.equals(waitForState, other.waitForState) && Objects.equals(cacheSize, other.cacheSize) - && numberOfAllocations == other.numberOfAllocations + && Objects.equals(numberOfAllocations, other.numberOfAllocations) + && Objects.equals(adaptiveAllocationsSettings, other.adaptiveAllocationsSettings) && threadsPerAllocation == other.threadsPerAllocation && queueCapacity == other.queueCapacity && priority == other.priority; @@ -430,7 +497,7 @@ public static boolean mayAssignToNode(@Nullable DiscoveryNode node) { PARSER.declareInt(ConstructingObjectParser.optionalConstructorArg(), THREADS_PER_ALLOCATION); PARSER.declareInt(ConstructingObjectParser.constructorArg(), QUEUE_CAPACITY); PARSER.declareField( - optionalConstructorArg(), + ConstructingObjectParser.optionalConstructorArg(), (p, c) -> ByteSizeValue.parseBytesSizeValue(p.text(), CACHE_SIZE.getPreferredName()), CACHE_SIZE, ObjectParser.ValueType.VALUE diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/UpdateTrainedModelDeploymentAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/UpdateTrainedModelDeploymentAction.java index 62a7d84c60a62..28152bc0d5556 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/UpdateTrainedModelDeploymentAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/action/UpdateTrainedModelDeploymentAction.java @@ -7,6 +7,7 @@ package org.elasticsearch.xpack.core.ml.action; +import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionRequestValidationException; import org.elasticsearch.action.ActionType; import org.elasticsearch.action.support.master.AcknowledgedRequest; @@ -19,12 +20,15 @@ import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsFeatureFlag; +import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings; import org.elasticsearch.xpack.core.ml.job.messages.Messages; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; import java.io.IOException; import java.util.Objects; +import static org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.Request.ADAPTIVE_ALLOCATIONS; import static org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.Request.MODEL_ID; import static org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.Request.NUMBER_OF_ALLOCATIONS; @@ -46,6 +50,14 @@ public static class Request extends AcknowledgedRequest implements ToXC static { PARSER.declareString(Request::setDeploymentId, MODEL_ID); PARSER.declareInt(Request::setNumberOfAllocations, NUMBER_OF_ALLOCATIONS); + if (AdaptiveAllocationsFeatureFlag.isEnabled()) { + PARSER.declareObjectOrNull( + Request::setAdaptiveAllocationsSettings, + (p, c) -> AdaptiveAllocationsSettings.PARSER.parse(p, c).build(), + AdaptiveAllocationsSettings.RESET_PLACEHOLDER, + ADAPTIVE_ALLOCATIONS + ); + } PARSER.declareString((r, val) -> r.ackTimeout(TimeValue.parseTimeValue(val, TIMEOUT.getPreferredName())), TIMEOUT); } @@ -62,7 +74,9 @@ public static Request parseRequest(String deploymentId, XContentParser parser) { } private String deploymentId; - private int numberOfAllocations; + private Integer numberOfAllocations; + private AdaptiveAllocationsSettings adaptiveAllocationsSettings; + private boolean isInternal; private Request() { super(TRAPPY_IMPLICIT_DEFAULT_MASTER_NODE_TIMEOUT, DEFAULT_ACK_TIMEOUT); @@ -76,7 +90,15 @@ public Request(String deploymentId) { public Request(StreamInput in) throws IOException { super(in); deploymentId = in.readString(); - numberOfAllocations = in.readVInt(); + if (in.getTransportVersion().before(TransportVersions.INFERENCE_ADAPTIVE_ALLOCATIONS)) { + numberOfAllocations = in.readVInt(); + adaptiveAllocationsSettings = null; + isInternal = false; + } else { + numberOfAllocations = in.readOptionalVInt(); + adaptiveAllocationsSettings = in.readOptionalWriteable(AdaptiveAllocationsSettings::new); + isInternal = in.readBoolean(); + } } public final void setDeploymentId(String deploymentId) { @@ -87,26 +109,53 @@ public String getDeploymentId() { return deploymentId; } - public void setNumberOfAllocations(int numberOfAllocations) { + public void setNumberOfAllocations(Integer numberOfAllocations) { this.numberOfAllocations = numberOfAllocations; } - public int getNumberOfAllocations() { + public Integer getNumberOfAllocations() { return numberOfAllocations; } + public void setAdaptiveAllocationsSettings(AdaptiveAllocationsSettings adaptiveAllocationsSettings) { + this.adaptiveAllocationsSettings = adaptiveAllocationsSettings; + } + + public boolean isInternal() { + return isInternal; + } + + public void setIsInternal(boolean isInternal) { + this.isInternal = isInternal; + } + + public AdaptiveAllocationsSettings getAdaptiveAllocationsSettings() { + return adaptiveAllocationsSettings; + } + @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); out.writeString(deploymentId); - out.writeVInt(numberOfAllocations); + if (out.getTransportVersion().before(TransportVersions.INFERENCE_ADAPTIVE_ALLOCATIONS)) { + out.writeVInt(numberOfAllocations); + } else { + out.writeOptionalVInt(numberOfAllocations); + out.writeOptionalWriteable(adaptiveAllocationsSettings); + out.writeBoolean(isInternal); + } } @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); builder.field(MODEL_ID.getPreferredName(), deploymentId); - builder.field(NUMBER_OF_ALLOCATIONS.getPreferredName(), numberOfAllocations); + if (numberOfAllocations != null) { + builder.field(NUMBER_OF_ALLOCATIONS.getPreferredName(), numberOfAllocations); + } + if (adaptiveAllocationsSettings != null) { + builder.field(ADAPTIVE_ALLOCATIONS.getPreferredName(), adaptiveAllocationsSettings); + } builder.endObject(); return builder; } @@ -114,15 +163,28 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws @Override public ActionRequestValidationException validate() { ActionRequestValidationException validationException = new ActionRequestValidationException(); - if (numberOfAllocations < 1) { - validationException.addValidationError("[" + NUMBER_OF_ALLOCATIONS + "] must be a positive integer"); + if (numberOfAllocations != null) { + if (numberOfAllocations < 1) { + validationException.addValidationError("[" + NUMBER_OF_ALLOCATIONS + "] must be a positive integer"); + } + if (isInternal == false && adaptiveAllocationsSettings != null && adaptiveAllocationsSettings.getEnabled()) { + validationException.addValidationError( + "[" + NUMBER_OF_ALLOCATIONS + "] cannot be set if adaptive allocations is enabled" + ); + } + } + ActionRequestValidationException autoscaleException = adaptiveAllocationsSettings == null + ? null + : adaptiveAllocationsSettings.validate(); + if (autoscaleException != null) { + validationException.addValidationErrors(autoscaleException.validationErrors()); } return validationException.validationErrors().isEmpty() ? null : validationException; } @Override public int hashCode() { - return Objects.hash(deploymentId, numberOfAllocations); + return Objects.hash(deploymentId, numberOfAllocations, adaptiveAllocationsSettings, isInternal); } @Override @@ -134,7 +196,10 @@ public boolean equals(Object obj) { return false; } Request other = (Request) obj; - return Objects.equals(deploymentId, other.deploymentId) && numberOfAllocations == other.numberOfAllocations; + return Objects.equals(deploymentId, other.deploymentId) + && Objects.equals(numberOfAllocations, other.numberOfAllocations) + && Objects.equals(adaptiveAllocationsSettings, other.adaptiveAllocationsSettings) + && isInternal == other.isInternal; } @Override diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/AdaptiveAllocationsFeatureFlag.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/AdaptiveAllocationsFeatureFlag.java new file mode 100644 index 0000000000000..a3b508c0534f9 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/AdaptiveAllocationsFeatureFlag.java @@ -0,0 +1,24 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.core.ml.inference.assignment; + +import org.elasticsearch.common.util.FeatureFlag; + +/** + * semantic_text feature flag. When the feature is complete, this flag will be removed. + */ +public class AdaptiveAllocationsFeatureFlag { + + private AdaptiveAllocationsFeatureFlag() {} + + private static final FeatureFlag FEATURE_FLAG = new FeatureFlag("inference_adaptive_allocations"); + + public static boolean isEnabled() { + return FEATURE_FLAG.isEnabled(); + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/AdaptiveAllocationsSettings.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/AdaptiveAllocationsSettings.java new file mode 100644 index 0000000000000..0b5a62ccb588c --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/AdaptiveAllocationsSettings.java @@ -0,0 +1,181 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.core.ml.inference.assignment; + +import org.elasticsearch.action.ActionRequestValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.xcontent.ObjectParser; +import org.elasticsearch.xcontent.ParseField; +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentParser; + +import java.io.IOException; +import java.util.Objects; + +public class AdaptiveAllocationsSettings implements ToXContentObject, Writeable { + + public static final AdaptiveAllocationsSettings RESET_PLACEHOLDER = new AdaptiveAllocationsSettings(false, -1, -1); + + public static final ParseField ENABLED = new ParseField("enabled"); + public static final ParseField MIN_NUMBER_OF_ALLOCATIONS = new ParseField("min_number_of_allocations"); + public static final ParseField MAX_NUMBER_OF_ALLOCATIONS = new ParseField("max_number_of_allocations"); + + public static final ObjectParser PARSER = new ObjectParser<>( + "autoscaling_settings", + AdaptiveAllocationsSettings.Builder::new + ); + + static { + PARSER.declareBoolean(Builder::setEnabled, ENABLED); + PARSER.declareIntOrNull(Builder::setMinNumberOfAllocations, -1, MIN_NUMBER_OF_ALLOCATIONS); + PARSER.declareIntOrNull(Builder::setMaxNumberOfAllocations, -1, MAX_NUMBER_OF_ALLOCATIONS); + } + + public static AdaptiveAllocationsSettings parseRequest(XContentParser parser) { + return PARSER.apply(parser, null).build(); + } + + public static class Builder { + private Boolean enabled; + private Integer minNumberOfAllocations; + private Integer maxNumberOfAllocations; + + public Builder() {} + + public Builder(AdaptiveAllocationsSettings settings) { + enabled = settings.enabled; + minNumberOfAllocations = settings.minNumberOfAllocations; + maxNumberOfAllocations = settings.maxNumberOfAllocations; + } + + public void setEnabled(Boolean enabled) { + this.enabled = enabled; + } + + public void setMinNumberOfAllocations(Integer minNumberOfAllocations) { + this.minNumberOfAllocations = minNumberOfAllocations; + } + + public void setMaxNumberOfAllocations(Integer maxNumberOfAllocations) { + this.maxNumberOfAllocations = maxNumberOfAllocations; + } + + public AdaptiveAllocationsSettings build() { + return new AdaptiveAllocationsSettings(enabled, minNumberOfAllocations, maxNumberOfAllocations); + } + } + + private final Boolean enabled; + private final Integer minNumberOfAllocations; + private final Integer maxNumberOfAllocations; + + public AdaptiveAllocationsSettings(Boolean enabled, Integer minNumberOfAllocations, Integer maxNumberOfAllocations) { + this.enabled = enabled; + this.minNumberOfAllocations = minNumberOfAllocations; + this.maxNumberOfAllocations = maxNumberOfAllocations; + } + + public AdaptiveAllocationsSettings(StreamInput in) throws IOException { + enabled = in.readOptionalBoolean(); + minNumberOfAllocations = in.readOptionalInt(); + maxNumberOfAllocations = in.readOptionalInt(); + } + + public Boolean getEnabled() { + return enabled; + } + + public Integer getMinNumberOfAllocations() { + return minNumberOfAllocations; + } + + public Integer getMaxNumberOfAllocations() { + return maxNumberOfAllocations; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + if (enabled != null) { + builder.field(ENABLED.getPreferredName(), enabled); + } + if (minNumberOfAllocations != null) { + builder.field(MIN_NUMBER_OF_ALLOCATIONS.getPreferredName(), minNumberOfAllocations); + } + if (maxNumberOfAllocations != null) { + builder.field(MAX_NUMBER_OF_ALLOCATIONS.getPreferredName(), maxNumberOfAllocations); + } + builder.endObject(); + return builder; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeOptionalBoolean(enabled); + out.writeOptionalInt(minNumberOfAllocations); + out.writeOptionalInt(maxNumberOfAllocations); + } + + public AdaptiveAllocationsSettings merge(AdaptiveAllocationsSettings updates) { + AdaptiveAllocationsSettings.Builder builder = new Builder(this); + if (updates.getEnabled() != null) { + builder.setEnabled(updates.enabled); + } + if (updates.minNumberOfAllocations != null) { + if (updates.minNumberOfAllocations == -1) { + builder.setMinNumberOfAllocations(null); + } else { + builder.setMinNumberOfAllocations(updates.minNumberOfAllocations); + } + } + if (updates.maxNumberOfAllocations != null) { + if (updates.maxNumberOfAllocations == -1) { + builder.setMaxNumberOfAllocations(null); + } else { + builder.setMaxNumberOfAllocations(updates.maxNumberOfAllocations); + } + } + return builder.build(); + } + + public ActionRequestValidationException validate() { + ActionRequestValidationException validationException = new ActionRequestValidationException(); + boolean hasMinNumberOfAllocations = (minNumberOfAllocations != null && minNumberOfAllocations != -1); + if (hasMinNumberOfAllocations && minNumberOfAllocations < 1) { + validationException.addValidationError("[" + MIN_NUMBER_OF_ALLOCATIONS + "] must be a positive integer or null"); + } + boolean hasMaxNumberOfAllocations = (maxNumberOfAllocations != null && maxNumberOfAllocations != -1); + if (hasMaxNumberOfAllocations && maxNumberOfAllocations < 1) { + validationException.addValidationError("[" + MAX_NUMBER_OF_ALLOCATIONS + "] must be a positive integer or null"); + } + if (hasMinNumberOfAllocations && hasMaxNumberOfAllocations && minNumberOfAllocations > maxNumberOfAllocations) { + validationException.addValidationError( + "[" + MIN_NUMBER_OF_ALLOCATIONS + "] must not be larger than [" + MAX_NUMBER_OF_ALLOCATIONS + "]" + ); + } + return validationException.validationErrors().isEmpty() ? null : validationException; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + AdaptiveAllocationsSettings that = (AdaptiveAllocationsSettings) o; + return Objects.equals(enabled, that.enabled) + && Objects.equals(minNumberOfAllocations, that.minNumberOfAllocations) + && Objects.equals(maxNumberOfAllocations, that.maxNumberOfAllocations); + } + + @Override + public int hashCode() { + return Objects.hash(enabled, minNumberOfAllocations, maxNumberOfAllocations); + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/AssignmentStats.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/AssignmentStats.java index d8e5d7a6d9603..aadaa5254ff15 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/AssignmentStats.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/AssignmentStats.java @@ -423,6 +423,8 @@ public int hashCode() { @Nullable private final Integer numberOfAllocations; @Nullable + private final AdaptiveAllocationsSettings adaptiveAllocationsSettings; + @Nullable private final Integer queueCapacity; @Nullable private final ByteSizeValue cacheSize; @@ -435,6 +437,7 @@ public AssignmentStats( String modelId, @Nullable Integer threadsPerAllocation, @Nullable Integer numberOfAllocations, + @Nullable AdaptiveAllocationsSettings adaptiveAllocationsSettings, @Nullable Integer queueCapacity, @Nullable ByteSizeValue cacheSize, Instant startTime, @@ -445,6 +448,7 @@ public AssignmentStats( this.modelId = modelId; this.threadsPerAllocation = threadsPerAllocation; this.numberOfAllocations = numberOfAllocations; + this.adaptiveAllocationsSettings = adaptiveAllocationsSettings; this.queueCapacity = queueCapacity; this.startTime = Objects.requireNonNull(startTime); this.nodeStats = nodeStats; @@ -479,6 +483,11 @@ public AssignmentStats(StreamInput in) throws IOException { } else { deploymentId = modelId; } + if (in.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_ADAPTIVE_ALLOCATIONS)) { + adaptiveAllocationsSettings = in.readOptionalWriteable(AdaptiveAllocationsSettings::new); + } else { + adaptiveAllocationsSettings = null; + } } public String getDeploymentId() { @@ -499,6 +508,11 @@ public Integer getNumberOfAllocations() { return numberOfAllocations; } + @Nullable + public AdaptiveAllocationsSettings getAdaptiveAllocationsSettings() { + return adaptiveAllocationsSettings; + } + @Nullable public Integer getQueueCapacity() { return queueCapacity; @@ -575,6 +589,9 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws if (numberOfAllocations != null) { builder.field(StartTrainedModelDeploymentAction.TaskParams.NUMBER_OF_ALLOCATIONS.getPreferredName(), numberOfAllocations); } + if (adaptiveAllocationsSettings != null) { + builder.field(StartTrainedModelDeploymentAction.Request.ADAPTIVE_ALLOCATIONS.getPreferredName(), adaptiveAllocationsSettings); + } if (queueCapacity != null) { builder.field(StartTrainedModelDeploymentAction.TaskParams.QUEUE_CAPACITY.getPreferredName(), queueCapacity); } @@ -649,6 +666,9 @@ public void writeTo(StreamOutput out) throws IOException { if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_8_0)) { out.writeString(deploymentId); } + if (out.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_ADAPTIVE_ALLOCATIONS)) { + out.writeOptionalWriteable(adaptiveAllocationsSettings); + } } @Override @@ -660,6 +680,7 @@ public boolean equals(Object o) { && Objects.equals(modelId, that.modelId) && Objects.equals(threadsPerAllocation, that.threadsPerAllocation) && Objects.equals(numberOfAllocations, that.numberOfAllocations) + && Objects.equals(adaptiveAllocationsSettings, that.adaptiveAllocationsSettings) && Objects.equals(queueCapacity, that.queueCapacity) && Objects.equals(startTime, that.startTime) && Objects.equals(state, that.state) @@ -677,6 +698,7 @@ public int hashCode() { modelId, threadsPerAllocation, numberOfAllocations, + adaptiveAllocationsSettings, queueCapacity, startTime, nodeStats, diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/TrainedModelAssignment.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/TrainedModelAssignment.java index b7219fbaa2061..60e0c0e86a828 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/TrainedModelAssignment.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/inference/assignment/TrainedModelAssignment.java @@ -23,6 +23,7 @@ import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xpack.core.common.time.TimeUtils; +import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction; import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction; import org.elasticsearch.xpack.core.ml.utils.ExceptionsHelper; @@ -52,6 +53,7 @@ public final class TrainedModelAssignment implements SimpleDiffable PARSER = new ConstructingObjectParser<>( @@ -64,7 +66,8 @@ public final class TrainedModelAssignment implements SimpleDiffable AdaptiveAllocationsSettings.PARSER.parse(p, c).build(), + null, + ADAPTIVE_ALLOCATIONS + ); } private final StartTrainedModelDeploymentAction.TaskParams taskParams; @@ -96,6 +105,7 @@ public final class TrainedModelAssignment implements SimpleDiffable assignableNodeIds) { int allocations = nodeRoutingTable.entrySet() .stream() @@ -301,12 +324,21 @@ public boolean equals(Object o) { && Objects.equals(reason, that.reason) && Objects.equals(assignmentState, that.assignmentState) && Objects.equals(startTime, that.startTime) - && maxAssignedAllocations == that.maxAssignedAllocations; + && maxAssignedAllocations == that.maxAssignedAllocations + && Objects.equals(adaptiveAllocationsSettings, that.adaptiveAllocationsSettings); } @Override public int hashCode() { - return Objects.hash(nodeRoutingTable, taskParams, assignmentState, reason, startTime, maxAssignedAllocations); + return Objects.hash( + nodeRoutingTable, + taskParams, + assignmentState, + reason, + startTime, + maxAssignedAllocations, + adaptiveAllocationsSettings + ); } @Override @@ -320,6 +352,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws } builder.timeField(START_TIME.getPreferredName(), startTime); builder.field(MAX_ASSIGNED_ALLOCATIONS.getPreferredName(), maxAssignedAllocations); + builder.field(ADAPTIVE_ALLOCATIONS.getPreferredName(), adaptiveAllocationsSettings); builder.endObject(); return builder; } @@ -334,6 +367,9 @@ public void writeTo(StreamOutput out) throws IOException { if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_4_0)) { out.writeVInt(maxAssignedAllocations); } + if (out.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_ADAPTIVE_ALLOCATIONS)) { + out.writeOptionalWriteable(adaptiveAllocationsSettings); + } } public Optional calculateAllocationStatus() { @@ -355,6 +391,7 @@ public static class Builder { private String reason; private Instant startTime; private int maxAssignedAllocations; + private AdaptiveAllocationsSettings adaptiveAllocationsSettings; public static Builder fromAssignment(TrainedModelAssignment assignment) { return new Builder( @@ -363,12 +400,20 @@ public static Builder fromAssignment(TrainedModelAssignment assignment) { assignment.assignmentState, assignment.reason, assignment.startTime, - assignment.maxAssignedAllocations + assignment.maxAssignedAllocations, + assignment.adaptiveAllocationsSettings ); } - public static Builder empty(StartTrainedModelDeploymentAction.TaskParams taskParams) { - return new Builder(taskParams); + public static Builder empty(CreateTrainedModelAssignmentAction.Request request) { + return new Builder(request.getTaskParams(), request.getAdaptiveAllocationsSettings()); + } + + public static Builder empty( + StartTrainedModelDeploymentAction.TaskParams taskParams, + AdaptiveAllocationsSettings adaptiveAllocationsSettings + ) { + return new Builder(taskParams, adaptiveAllocationsSettings); } private Builder( @@ -377,7 +422,8 @@ private Builder( AssignmentState assignmentState, String reason, Instant startTime, - int maxAssignedAllocations + int maxAssignedAllocations, + AdaptiveAllocationsSettings adaptiveAllocationsSettings ) { this.taskParams = taskParams; this.nodeRoutingTable = new LinkedHashMap<>(nodeRoutingTable); @@ -385,10 +431,11 @@ private Builder( this.reason = reason; this.startTime = startTime; this.maxAssignedAllocations = maxAssignedAllocations; + this.adaptiveAllocationsSettings = adaptiveAllocationsSettings; } - private Builder(StartTrainedModelDeploymentAction.TaskParams taskParams) { - this(taskParams, new LinkedHashMap<>(), AssignmentState.STARTING, null, Instant.now(), 0); + private Builder(StartTrainedModelDeploymentAction.TaskParams taskParams, AdaptiveAllocationsSettings adaptiveAllocationsSettings) { + this(taskParams, new LinkedHashMap<>(), AssignmentState.STARTING, null, Instant.now(), 0, adaptiveAllocationsSettings); } public Builder setStartTime(Instant startTime) { @@ -401,6 +448,11 @@ public Builder setMaxAssignedAllocations(int maxAssignedAllocations) { return this; } + public Builder setAdaptiveAllocationsSettings(AdaptiveAllocationsSettings adaptiveAllocationsSettings) { + this.adaptiveAllocationsSettings = adaptiveAllocationsSettings; + return this; + } + public Builder addRoutingEntry(String nodeId, RoutingInfo routingInfo) { if (nodeRoutingTable.containsKey(nodeId)) { throw new ResourceAlreadyExistsException( @@ -518,7 +570,15 @@ public Builder setNumberOfAllocations(int numberOfAllocations) { } public TrainedModelAssignment build() { - return new TrainedModelAssignment(taskParams, nodeRoutingTable, assignmentState, reason, startTime, maxAssignedAllocations); + return new TrainedModelAssignment( + taskParams, + nodeRoutingTable, + assignmentState, + reason, + startTime, + maxAssignedAllocations, + adaptiveAllocationsSettings + ); } } } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/SemanticTextInfoExtractor.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/SemanticTextInfoExtractor.java new file mode 100644 index 0000000000000..544c1e344c91f --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/ml/utils/SemanticTextInfoExtractor.java @@ -0,0 +1,50 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + * + * this file was contributed to by a Generative AI + */ + +package org.elasticsearch.xpack.core.ml.utils; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.cluster.metadata.IndexMetadata; +import org.elasticsearch.cluster.metadata.InferenceFieldMetadata; +import org.elasticsearch.cluster.metadata.Metadata; +import org.elasticsearch.transport.Transports; + +import java.util.HashSet; +import java.util.Map; +import java.util.Set; + +public class SemanticTextInfoExtractor { + private static final Logger logger = LogManager.getLogger(SemanticTextInfoExtractor.class); + + public static Set extractIndexesReferencingInferenceEndpoints(Metadata metadata, Set endpointIds) { + assert Transports.assertNotTransportThread("non-trivial nested loops over cluster state structures"); + assert endpointIds.isEmpty() == false; + assert metadata != null; + + Set referenceIndices = new HashSet<>(); + + Map indices = metadata.indices(); + + indices.forEach((indexName, indexMetadata) -> { + if (indexMetadata.getInferenceFields() != null) { + Map inferenceFields = indexMetadata.getInferenceFields(); + if (inferenceFields.entrySet() + .stream() + .anyMatch( + entry -> entry.getValue().getInferenceId() != null && endpointIds.contains(entry.getValue().getInferenceId()) + )) { + referenceIndices.add(indexName); + } + } + }); + + return referenceIndices; + } +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/action/Grant.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/action/Grant.java index b186ab45a7dc7..c98564251cd43 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/action/Grant.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/action/Grant.java @@ -7,19 +7,13 @@ package org.elasticsearch.xpack.core.security.action; -import org.elasticsearch.ElasticsearchSecurityException; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionRequestValidationException; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.settings.SecureString; -import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.xpack.core.security.authc.AuthenticationToken; -import org.elasticsearch.xpack.core.security.authc.jwt.JwtAuthenticationToken; import org.elasticsearch.xpack.core.security.authc.jwt.JwtRealmSettings; -import org.elasticsearch.xpack.core.security.authc.support.BearerToken; -import org.elasticsearch.xpack.core.security.authc.support.UsernamePasswordToken; import java.io.IOException; @@ -136,30 +130,6 @@ public void setClientAuthentication(ClientAuthentication clientAuthentication) { this.clientAuthentication = clientAuthentication; } - public AuthenticationToken getAuthenticationToken() { - assert validate(null) == null : "grant is invalid"; - return switch (type) { - case PASSWORD_GRANT_TYPE -> new UsernamePasswordToken(username, password); - case ACCESS_TOKEN_GRANT_TYPE -> { - SecureString clientAuthentication = this.clientAuthentication != null ? this.clientAuthentication.value() : null; - AuthenticationToken token = JwtAuthenticationToken.tryParseJwt(accessToken, clientAuthentication); - if (token != null) { - yield token; - } - if (clientAuthentication != null) { - clientAuthentication.close(); - throw new ElasticsearchSecurityException( - "[client_authentication] not supported with the supplied access_token type", - RestStatus.BAD_REQUEST - ); - } - // here we effectively assume it's an ES access token (from the {@code TokenService}) - yield new BearerToken(accessToken); - } - default -> throw new ElasticsearchSecurityException("the grant type [{}] is not supported", type); - }; - } - public ActionRequestValidationException validate(ActionRequestValidationException validationException) { if (type == null) { validationException = addValidationError("[grant_type] is required", validationException); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/action/role/BulkPutRoleRequestBuilder.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/action/role/BulkPutRoleRequestBuilder.java index ba199e183d4af..cda45a67e81c6 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/action/role/BulkPutRoleRequestBuilder.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/action/role/BulkPutRoleRequestBuilder.java @@ -44,7 +44,7 @@ public class BulkPutRoleRequestBuilder extends ActionRequestBuilder roles; - public BulkPutRolesRequest() {} + public BulkPutRolesRequest(List roles) { + this.roles = roles; + } public void setRoles(List roles) { this.roles = roles; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/action/role/QueryRoleResponse.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/action/role/QueryRoleResponse.java index 6bdc6c66c1835..8e9da10e449ad 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/action/role/QueryRoleResponse.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/action/role/QueryRoleResponse.java @@ -86,7 +86,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws // other details of the role descriptor (in the same object). assert Strings.isNullOrEmpty(roleDescriptor.getName()) == false; builder.field("name", roleDescriptor.getName()); - roleDescriptor.innerToXContent(builder, params, false, false); + roleDescriptor.innerToXContent(builder, params, false); if (sortValues != null && sortValues.length > 0) { builder.array("_sort", sortValues); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authz/RoleDescriptor.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authz/RoleDescriptor.java index 7bedab61bd43d..1a8839fa0fa4a 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authz/RoleDescriptor.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authz/RoleDescriptor.java @@ -417,13 +417,8 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws } public XContentBuilder toXContent(XContentBuilder builder, Params params, boolean docCreation) throws IOException { - return toXContent(builder, params, docCreation, false); - } - - public XContentBuilder toXContent(XContentBuilder builder, Params params, boolean docCreation, boolean includeMetadataFlattened) - throws IOException { builder.startObject(); - innerToXContent(builder, params, docCreation, includeMetadataFlattened); + innerToXContent(builder, params, docCreation); return builder.endObject(); } @@ -435,12 +430,10 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params, boolea * @param docCreation {@code true} if the x-content is being generated for creating a document * in the security index, {@code false} if the x-content being generated * is for API display purposes - * @param includeMetadataFlattened {@code true} if the metadataFlattened field should be included in doc * @return x-content builder * @throws IOException if there was an error writing the x-content to the builder */ - public XContentBuilder innerToXContent(XContentBuilder builder, Params params, boolean docCreation, boolean includeMetadataFlattened) - throws IOException { + public XContentBuilder innerToXContent(XContentBuilder builder, Params params, boolean docCreation) throws IOException { builder.array(Fields.CLUSTER.getPreferredName(), clusterPrivileges); if (configurableClusterPrivileges.length != 0) { builder.field(Fields.GLOBAL.getPreferredName()); @@ -452,9 +445,7 @@ public XContentBuilder innerToXContent(XContentBuilder builder, Params params, b builder.array(Fields.RUN_AS.getPreferredName(), runAs); } builder.field(Fields.METADATA.getPreferredName(), metadata); - if (includeMetadataFlattened) { - builder.field(Fields.METADATA_FLATTENED.getPreferredName(), metadata); - } + if (docCreation) { builder.field(Fields.TYPE.getPreferredName(), ROLE_TYPE); } else { @@ -1196,7 +1187,7 @@ private static ApplicationResourcePrivileges parseApplicationPrivilege(String ro public static final class RemoteIndicesPrivileges implements Writeable, ToXContentObject { - private static final RemoteIndicesPrivileges[] NONE = new RemoteIndicesPrivileges[0]; + public static final RemoteIndicesPrivileges[] NONE = new RemoteIndicesPrivileges[0]; private final IndicesPrivileges indicesPrivileges; private final String[] remoteClusters; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/transform/transforms/pivot/GeoTileGroupSource.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/transform/transforms/pivot/GeoTileGroupSource.java index 68109f429f461..6b4394f1c2b52 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/transform/transforms/pivot/GeoTileGroupSource.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/transform/transforms/pivot/GeoTileGroupSource.java @@ -10,7 +10,6 @@ import org.elasticsearch.common.geo.GeoBoundingBox; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.index.mapper.GeoShapeFieldMapper; import org.elasticsearch.search.aggregations.bucket.geogrid.GeoTileUtils; import org.elasticsearch.xcontent.ConstructingObjectParser; import org.elasticsearch.xcontent.ObjectParser; @@ -138,7 +137,7 @@ public int hashCode() { @Override public String getMappingType() { - return GeoShapeFieldMapper.CONTENT_TYPE; + return "geo_shape"; } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/CreateTrainedModelAssignmentActionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/CreateTrainedModelAssignmentActionRequestTests.java index 71a68a65b7977..39f646df0d582 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/CreateTrainedModelAssignmentActionRequestTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/CreateTrainedModelAssignmentActionRequestTests.java @@ -14,7 +14,7 @@ public class CreateTrainedModelAssignmentActionRequestTests extends AbstractWire @Override protected Request createTestInstance() { - return new Request(StartTrainedModelDeploymentTaskParamsTests.createRandom()); + return new Request(StartTrainedModelDeploymentTaskParamsTests.createRandom(), null); } @Override diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsStatsActionResponseTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsStatsActionResponseTests.java index 8c175c17fccc8..d60bbc6cc7713 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsStatsActionResponseTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/GetTrainedModelsStatsActionResponseTests.java @@ -156,6 +156,7 @@ protected Response mutateInstanceForVersion(Response instance, TransportVersion stats.getDeploymentStats().getModelId(), stats.getDeploymentStats().getThreadsPerAllocation(), stats.getDeploymentStats().getNumberOfAllocations(), + null, stats.getDeploymentStats().getQueueCapacity(), null, stats.getDeploymentStats().getStartTime(), @@ -228,6 +229,7 @@ protected Response mutateInstanceForVersion(Response instance, TransportVersion stats.getDeploymentStats().getModelId(), stats.getDeploymentStats().getThreadsPerAllocation(), stats.getDeploymentStats().getNumberOfAllocations(), + null, stats.getDeploymentStats().getQueueCapacity(), null, stats.getDeploymentStats().getStartTime(), @@ -300,6 +302,7 @@ protected Response mutateInstanceForVersion(Response instance, TransportVersion stats.getDeploymentStats().getModelId(), stats.getDeploymentStats().getThreadsPerAllocation(), stats.getDeploymentStats().getNumberOfAllocations(), + null, stats.getDeploymentStats().getQueueCapacity(), null, stats.getDeploymentStats().getStartTime(), @@ -372,6 +375,7 @@ protected Response mutateInstanceForVersion(Response instance, TransportVersion stats.getDeploymentStats().getModelId(), stats.getDeploymentStats().getThreadsPerAllocation(), stats.getDeploymentStats().getNumberOfAllocations(), + null, stats.getDeploymentStats().getQueueCapacity(), stats.getDeploymentStats().getCacheSize(), stats.getDeploymentStats().getStartTime(), @@ -445,6 +449,7 @@ protected Response mutateInstanceForVersion(Response instance, TransportVersion stats.getDeploymentStats().getModelId(), stats.getDeploymentStats().getThreadsPerAllocation(), stats.getDeploymentStats().getNumberOfAllocations(), + null, stats.getDeploymentStats().getQueueCapacity(), stats.getDeploymentStats().getCacheSize(), stats.getDeploymentStats().getStartTime(), @@ -518,6 +523,7 @@ protected Response mutateInstanceForVersion(Response instance, TransportVersion stats.getDeploymentStats().getModelId(), stats.getDeploymentStats().getThreadsPerAllocation(), stats.getDeploymentStats().getNumberOfAllocations(), + null, stats.getDeploymentStats().getQueueCapacity(), stats.getDeploymentStats().getCacheSize(), stats.getDeploymentStats().getStartTime(), @@ -591,6 +597,7 @@ protected Response mutateInstanceForVersion(Response instance, TransportVersion stats.getDeploymentStats().getModelId(), stats.getDeploymentStats().getThreadsPerAllocation(), stats.getDeploymentStats().getNumberOfAllocations(), + null, stats.getDeploymentStats().getQueueCapacity(), stats.getDeploymentStats().getCacheSize(), stats.getDeploymentStats().getStartTime(), diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentRequestTests.java index ad33a85d42e53..730d994fc5e35 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentRequestTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/action/StartTrainedModelDeploymentRequestTests.java @@ -71,7 +71,8 @@ public static Request createRandom() { } if (randomBoolean()) { request.setPriority(randomFrom(Priority.values()).toString()); - if (request.getNumberOfAllocations() > 1 || request.getThreadsPerAllocation() > 1) { + if ((request.getNumberOfAllocations() != null && request.getNumberOfAllocations() > 1) + || request.getThreadsPerAllocation() > 1) { request.setPriority(Priority.NORMAL.toString()); } } @@ -230,7 +231,8 @@ public void testDefaults() { Request request = new Request(randomAlphaOfLength(10), randomAlphaOfLength(10)); assertThat(request.getTimeout(), equalTo(TimeValue.timeValueSeconds(30))); assertThat(request.getWaitForState(), equalTo(AllocationStatus.State.STARTED)); - assertThat(request.getNumberOfAllocations(), equalTo(1)); + assertThat(request.getNumberOfAllocations(), nullValue()); + assertThat(request.computeNumberOfAllocations(), equalTo(1)); assertThat(request.getThreadsPerAllocation(), equalTo(1)); assertThat(request.getQueueCapacity(), equalTo(1024)); } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/assignment/AssignmentStatsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/assignment/AssignmentStatsTests.java index a1ab023a6935f..07c56b073cd00 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/assignment/AssignmentStatsTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/assignment/AssignmentStatsTests.java @@ -50,6 +50,7 @@ public static AssignmentStats randomDeploymentStats() { modelId, randomBoolean() ? null : randomIntBetween(1, 8), randomBoolean() ? null : randomIntBetween(1, 8), + null, randomBoolean() ? null : randomIntBetween(1, 10000), randomBoolean() ? null : ByteSizeValue.ofBytes(randomLongBetween(1, 10000000)), Instant.now(), @@ -102,6 +103,7 @@ public void testGetOverallInferenceStats() { modelId, randomBoolean() ? null : randomIntBetween(1, 8), randomBoolean() ? null : randomIntBetween(1, 8), + null, randomBoolean() ? null : randomIntBetween(1, 10000), randomBoolean() ? null : ByteSizeValue.ofBytes(randomLongBetween(1, 1000000)), Instant.now(), @@ -166,6 +168,7 @@ public void testGetOverallInferenceStatsWithNoNodes() { modelId, randomBoolean() ? null : randomIntBetween(1, 8), randomBoolean() ? null : randomIntBetween(1, 8), + null, randomBoolean() ? null : randomIntBetween(1, 10000), randomBoolean() ? null : ByteSizeValue.ofBytes(randomLongBetween(1, 1000000)), Instant.now(), @@ -187,6 +190,7 @@ public void testGetOverallInferenceStatsWithOnlyStoppedNodes() { modelId, randomBoolean() ? null : randomIntBetween(1, 8), randomBoolean() ? null : randomIntBetween(1, 8), + null, randomBoolean() ? null : randomIntBetween(1, 10000), randomBoolean() ? null : ByteSizeValue.ofBytes(randomLongBetween(1, 1000000)), Instant.now(), diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/assignment/TrainedModelAssignmentTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/assignment/TrainedModelAssignmentTests.java index 75706f3d6a9bf..6d70105dfedba 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/assignment/TrainedModelAssignmentTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/ml/inference/assignment/TrainedModelAssignmentTests.java @@ -39,7 +39,7 @@ public class TrainedModelAssignmentTests extends AbstractXContentSerializingTestCase { public static TrainedModelAssignment randomInstance() { - TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomParams()); + TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomParams(), null); List nodes = Stream.generate(() -> randomAlphaOfLength(10)).limit(randomInt(5)).toList(); for (String node : nodes) { builder.addRoutingEntry(node, RoutingInfoTests.randomInstance()); @@ -72,7 +72,7 @@ protected TrainedModelAssignment mutateInstance(TrainedModelAssignment instance) } public void testBuilderAddingExistingRoute() { - TrainedModelAssignment.Builder assignment = TrainedModelAssignment.Builder.empty(randomParams()); + TrainedModelAssignment.Builder assignment = TrainedModelAssignment.Builder.empty(randomParams(), null); String addingNode = "new-node"; assignment.addRoutingEntry(addingNode, RoutingInfoTests.randomInstance()); @@ -80,7 +80,7 @@ public void testBuilderAddingExistingRoute() { } public void testBuilderUpdatingMissingRoute() { - TrainedModelAssignment.Builder assignment = TrainedModelAssignment.Builder.empty(randomParams()); + TrainedModelAssignment.Builder assignment = TrainedModelAssignment.Builder.empty(randomParams(), null); String addingNode = "new-node"; expectThrows( ResourceNotFoundException.class, @@ -93,7 +93,7 @@ public void testGetStartedNodes() { String startedNode2 = "started-node-2"; String nodeInAnotherState1 = "another-state-node-1"; String nodeInAnotherState2 = "another-state-node-2"; - TrainedModelAssignment allocation = TrainedModelAssignment.Builder.empty(randomParams()) + TrainedModelAssignment allocation = TrainedModelAssignment.Builder.empty(randomParams(), null) .addRoutingEntry(startedNode1, RoutingInfoTests.randomInstance(RoutingState.STARTED)) .addRoutingEntry(startedNode2, RoutingInfoTests.randomInstance(RoutingState.STARTED)) .addRoutingEntry( @@ -114,20 +114,20 @@ public void testGetStartedNodes() { public void testCalculateAllocationStatus_GivenNoAllocations() { assertThat( - TrainedModelAssignment.Builder.empty(randomTaskParams(5)).build().calculateAllocationStatus(), + TrainedModelAssignment.Builder.empty(randomTaskParams(5), null).build().calculateAllocationStatus(), isPresentWith(new AllocationStatus(0, 5)) ); } public void testCalculateAllocationStatus_GivenStoppingAssignment() { - TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomTaskParams(5)); + TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomTaskParams(5), null); builder.addRoutingEntry("node-1", new RoutingInfo(1, 2, RoutingState.STARTED, "")); builder.addRoutingEntry("node-2", new RoutingInfo(2, 1, RoutingState.STARTED, "")); assertThat(builder.stopAssignment("test").build().calculateAllocationStatus(), isEmpty()); } public void testCalculateAllocationStatus_GivenPartiallyAllocated() { - TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomTaskParams(5)); + TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomTaskParams(5), null); builder.addRoutingEntry("node-1", new RoutingInfo(1, 2, RoutingState.STARTED, "")); builder.addRoutingEntry("node-2", new RoutingInfo(2, 1, RoutingState.STARTED, "")); builder.addRoutingEntry("node-3", new RoutingInfo(3, 3, RoutingState.STARTING, "")); @@ -135,28 +135,28 @@ public void testCalculateAllocationStatus_GivenPartiallyAllocated() { } public void testCalculateAllocationStatus_GivenFullyAllocated() { - TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomTaskParams(5)); + TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomTaskParams(5), null); builder.addRoutingEntry("node-1", new RoutingInfo(4, 4, RoutingState.STARTED, "")); builder.addRoutingEntry("node-2", new RoutingInfo(1, 1, RoutingState.STARTED, "")); assertThat(builder.build().calculateAllocationStatus(), isPresentWith(new AllocationStatus(5, 5))); } public void testCalculateAssignmentState_GivenNoStartedAssignments() { - TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomTaskParams(5)); + TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomTaskParams(5), null); builder.addRoutingEntry("node-1", new RoutingInfo(4, 4, RoutingState.STARTING, "")); builder.addRoutingEntry("node-2", new RoutingInfo(1, 1, RoutingState.STARTING, "")); assertThat(builder.calculateAssignmentState(), equalTo(AssignmentState.STARTING)); } public void testCalculateAssignmentState_GivenOneStartedAssignment() { - TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomTaskParams(5)); + TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomTaskParams(5), null); builder.addRoutingEntry("node-1", new RoutingInfo(4, 4, RoutingState.STARTING, "")); builder.addRoutingEntry("node-2", new RoutingInfo(1, 1, RoutingState.STARTED, "")); assertThat(builder.calculateAssignmentState(), equalTo(AssignmentState.STARTED)); } public void testCalculateAndSetAssignmentState_GivenStoppingAssignment() { - TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomTaskParams(5)); + TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomTaskParams(5), null); builder.addRoutingEntry("node-1", new RoutingInfo(4, 4, RoutingState.STARTED, "")); builder.addRoutingEntry("node-2", new RoutingInfo(1, 1, RoutingState.STARTED, "")); assertThat( @@ -166,7 +166,7 @@ public void testCalculateAndSetAssignmentState_GivenStoppingAssignment() { } public void testselectRandomStartedNodeWeighedOnAllocationsForNRequests_GivenNoStartedAllocations() { - TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomTaskParams(5)); + TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomTaskParams(5), null); builder.addRoutingEntry("node-1", new RoutingInfo(4, 4, RoutingState.STARTING, "")); builder.addRoutingEntry("node-2", new RoutingInfo(1, 1, RoutingState.STOPPED, "")); TrainedModelAssignment assignment = builder.build(); @@ -175,7 +175,7 @@ public void testselectRandomStartedNodeWeighedOnAllocationsForNRequests_GivenNoS } public void testselectRandomStartedNodeWeighedOnAllocationsForNRequests_GivenSingleStartedNode() { - TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomTaskParams(5)); + TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomTaskParams(5), null); builder.addRoutingEntry("node-1", new RoutingInfo(4, 4, RoutingState.STARTED, "")); TrainedModelAssignment assignment = builder.build(); @@ -185,7 +185,7 @@ public void testselectRandomStartedNodeWeighedOnAllocationsForNRequests_GivenSin } public void testselectRandomStartedNodeWeighedOnAllocationsForNRequests_GivenAShuttingDownRoute_ItReturnsNoNodes() { - TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomTaskParams(5)); + TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomTaskParams(5), null); builder.addRoutingEntry("node-1", new RoutingInfo(4, 4, RoutingState.STARTED, "")); TrainedModelAssignment assignment = builder.build(); @@ -195,7 +195,7 @@ public void testselectRandomStartedNodeWeighedOnAllocationsForNRequests_GivenASh } public void testselectRandomStartedNodeWeighedOnAllocationsForNRequests_GivenAShuttingDownRoute_ItReturnsNode1() { - TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomTaskParams(5)); + TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomTaskParams(5), null); builder.addRoutingEntry("node-1", new RoutingInfo(4, 4, RoutingState.STOPPING, "")); TrainedModelAssignment assignment = builder.build(); @@ -205,7 +205,7 @@ public void testselectRandomStartedNodeWeighedOnAllocationsForNRequests_GivenASh } public void testSingleRequestWith2Nodes() { - TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomTaskParams(5)); + TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomTaskParams(5), null); builder.addRoutingEntry("node-1", new RoutingInfo(1, 1, RoutingState.STARTED, "")); builder.addRoutingEntry("node-2", new RoutingInfo(1, 1, RoutingState.STARTED, "")); TrainedModelAssignment assignment = builder.build(); @@ -216,7 +216,7 @@ public void testSingleRequestWith2Nodes() { } public void testSelectRandomStartedNodeWeighedOnAllocationsForNRequests_GivenMultipleStartedNodes() { - TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomTaskParams(6)); + TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomTaskParams(6), null); builder.addRoutingEntry("node-1", new RoutingInfo(1, 1, RoutingState.STARTED, "")); builder.addRoutingEntry("node-2", new RoutingInfo(2, 2, RoutingState.STARTED, "")); builder.addRoutingEntry("node-3", new RoutingInfo(3, 3, RoutingState.STARTED, "")); @@ -239,7 +239,7 @@ public void testSelectRandomStartedNodeWeighedOnAllocationsForNRequests_GivenMul } public void testselectRandomStartedNodeWeighedOnAllocationsForNRequests_GivenMultipleStartedNodesWithZeroAllocations() { - TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomTaskParams(6)); + TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomTaskParams(6), null); builder.addRoutingEntry("node-1", new RoutingInfo(0, 0, RoutingState.STARTED, "")); builder.addRoutingEntry("node-2", new RoutingInfo(0, 0, RoutingState.STARTED, "")); builder.addRoutingEntry("node-3", new RoutingInfo(0, 0, RoutingState.STARTED, "")); @@ -257,7 +257,7 @@ public void testselectRandomStartedNodeWeighedOnAllocationsForNRequests_GivenMul } public void testIsSatisfied_GivenEnoughAllocations() { - TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomTaskParams(6)); + TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomTaskParams(6), null); builder.addRoutingEntry("node-1", new RoutingInfo(1, 1, RoutingState.STARTED, "")); builder.addRoutingEntry("node-2", new RoutingInfo(2, 2, RoutingState.STARTED, "")); builder.addRoutingEntry("node-3", new RoutingInfo(3, 3, RoutingState.STARTED, "")); @@ -266,7 +266,7 @@ public void testIsSatisfied_GivenEnoughAllocations() { } public void testIsSatisfied_GivenEnoughAllocations_ButOneNodeIsNotAssignable() { - TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomTaskParams(6)); + TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomTaskParams(6), null); builder.addRoutingEntry("node-1", new RoutingInfo(1, 1, RoutingState.STARTED, "")); builder.addRoutingEntry("node-2", new RoutingInfo(2, 2, RoutingState.STARTED, "")); builder.addRoutingEntry("node-3", new RoutingInfo(3, 3, RoutingState.STARTED, "")); @@ -275,7 +275,7 @@ public void testIsSatisfied_GivenEnoughAllocations_ButOneNodeIsNotAssignable() { } public void testIsSatisfied_GivenEnoughAllocations_ButOneNodeIsNeitherStartingNorStarted() { - TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomTaskParams(6)); + TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomTaskParams(6), null); builder.addRoutingEntry( "node-1", new RoutingInfo(1, 1, randomFrom(RoutingState.FAILED, RoutingState.STOPPING, RoutingState.STOPPED), "") @@ -287,7 +287,7 @@ public void testIsSatisfied_GivenEnoughAllocations_ButOneNodeIsNeitherStartingNo } public void testIsSatisfied_GivenNotEnoughAllocations() { - TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomTaskParams(7)); + TrainedModelAssignment.Builder builder = TrainedModelAssignment.Builder.empty(randomTaskParams(7), null); builder.addRoutingEntry("node-1", new RoutingInfo(1, 1, RoutingState.STARTED, "")); builder.addRoutingEntry("node-2", new RoutingInfo(2, 2, RoutingState.STARTED, "")); builder.addRoutingEntry("node-3", new RoutingInfo(3, 3, RoutingState.STARTED, "")); @@ -296,7 +296,7 @@ public void testIsSatisfied_GivenNotEnoughAllocations() { } public void testMaxAssignedAllocations() { - TrainedModelAssignment assignment = TrainedModelAssignment.Builder.empty(randomTaskParams(10)) + TrainedModelAssignment assignment = TrainedModelAssignment.Builder.empty(randomTaskParams(10), null) .addRoutingEntry("node-1", new RoutingInfo(1, 2, RoutingState.STARTED, "")) .addRoutingEntry("node-2", new RoutingInfo(2, 1, RoutingState.STARTED, "")) .addRoutingEntry("node-3", new RoutingInfo(3, 3, RoutingState.STARTING, "")) diff --git a/x-pack/plugin/core/template-resources/src/main/resources/metrics@settings.json b/x-pack/plugin/core/template-resources/src/main/resources/metrics@settings.json index 4f3fac1aed5ae..9960bd2e7fdac 100644 --- a/x-pack/plugin/core/template-resources/src/main/resources/metrics@settings.json +++ b/x-pack/plugin/core/template-resources/src/main/resources/metrics@settings.json @@ -10,9 +10,6 @@ "total_fields": { "ignore_dynamic_beyond_limit": true } - }, - "query": { - "default_field": ["message"] } } } diff --git a/x-pack/plugin/core/template-resources/src/main/resources/metrics@tsdb-settings.json b/x-pack/plugin/core/template-resources/src/main/resources/metrics@tsdb-settings.json index b0db168e8189d..cb0e2cbffb50b 100644 --- a/x-pack/plugin/core/template-resources/src/main/resources/metrics@tsdb-settings.json +++ b/x-pack/plugin/core/template-resources/src/main/resources/metrics@tsdb-settings.json @@ -9,9 +9,6 @@ "total_fields": { "ignore_dynamic_beyond_limit": true } - }, - "query": { - "default_field": ["message"] } } } diff --git a/x-pack/plugin/core/template-resources/src/main/resources/monitoring-beats-mb.json b/x-pack/plugin/core/template-resources/src/main/resources/monitoring-beats-mb.json index fab8ca451358f..7457dce805eca 100644 --- a/x-pack/plugin/core/template-resources/src/main/resources/monitoring-beats-mb.json +++ b/x-pack/plugin/core/template-resources/src/main/resources/monitoring-beats-mb.json @@ -1,5 +1,7 @@ { - "index_patterns": [".monitoring-beats-${xpack.stack.monitoring.template.version}-*"], + "index_patterns": [ + ".monitoring-beats-${xpack.stack.monitoring.template.version}-*" + ], "version": ${xpack.stack.monitoring.template.release.version}, "template": { "mappings": { @@ -198,6 +200,9 @@ "ratelimit": { "type": "long" }, + "timeout": { + "type": "long" + }, "toolarge": { "type": "long" }, @@ -212,16 +217,6 @@ } } }, - "request": { - "properties": { - "count": { - "type": "long" - } - } - }, - "unset": { - "type": "long" - }, "valid": { "properties": { "accepted": { @@ -239,151 +234,436 @@ } } } + }, + "unset": { + "type": "long" } } }, - "decoder": { + "agentcfg": { "properties": { - "deflate": { - "properties": { - "content-length": { - "type": "long" - }, - "count": { - "type": "long" - } - } - }, - "gzip": { - "properties": { - "content-length": { - "type": "long" - }, - "count": { - "type": "long" - } - } - }, - "missing-content-length": { + "elasticsearch": { "properties": { - "count": { - "type": "long" - } - } - }, - "reader": { - "properties": { - "count": { - "type": "long" - }, - "size": { - "type": "long" - } - } - }, - "uncompressed": { - "properties": { - "content-length": { - "type": "long" + "cache": { + "properties": { + "entries": { + "properties": { + "count": { + "type": "long" + } + } + }, + "refresh": { + "properties": { + "failures": { + "type": "long" + }, + "successes": { + "type": "long" + } + } + } + } }, - "count": { - "type": "long" + "fetch": { + "properties": { + "es": { + "type": "long" + }, + "fallback": { + "type": "long" + }, + "invalid": { + "type": "long" + }, + "unavailable": { + "type": "long" + } + } } } } } }, - "processor": { + "jaeger": { "properties": { - "error": { + "grpc": { "properties": { - "decoding": { + "collect": { "properties": { - "count": { - "type": "long" + "request": { + "properties": { + "count": { + "type": "long" + } + } }, - "errors": { - "type": "long" + "response": { + "properties": { + "count": { + "type": "long" + }, + "errors": { + "properties": { + "count": { + "type": "long" + }, + "ratelimit": { + "type": "long" + }, + "timeout": { + "type": "long" + }, + "unauthorized": { + "type": "long" + } + } + }, + "valid": { + "properties": { + "count": { + "type": "long" + } + } + } + } } } }, - "frames": { - "type": "long" - }, - "spans": { - "type": "long" - }, - "stacktraces": { - "type": "long" - }, - "transformations": { - "type": "long" - }, - "validation": { + "sampling": { "properties": { - "count": { - "type": "long" + "event": { + "properties": { + "received": { + "properties": { + "count": { + "type": "long" + } + } + } + } }, - "errors": { - "type": "long" + "request": { + "properties": { + "count": { + "type": "long" + } + } + }, + "response": { + "properties": { + "count": { + "type": "long" + }, + "errors": { + "properties": { + "count": { + "type": "long" + } + } + }, + "valid": { + "properties": { + "count": { + "type": "long" + } + } + } + } } } } } - }, - "metric": { + } + } + }, + "otlp": { + "properties": { + "grpc": { "properties": { - "decoding": { + "logs": { "properties": { - "count": { - "type": "long" + "request": { + "properties": { + "count": { + "type": "long" + } + } }, - "errors": { - "type": "long" + "response": { + "properties": { + "count": { + "type": "long" + }, + "errors": { + "properties": { + "count": { + "type": "long" + }, + "ratelimit": { + "type": "long" + }, + "timeout": { + "type": "long" + }, + "unauthorized": { + "type": "long" + } + } + }, + "valid": { + "properties": { + "count": { + "type": "long" + } + } + } + } } } }, - "transformations": { - "type": "long" + "metrics": { + "properties": { + "consumer": { + "properties": { + "unsupported_dropped": { + "type": "long" + } + } + }, + "request": { + "properties": { + "count": { + "type": "long" + } + } + }, + "response": { + "properties": { + "count": { + "type": "long" + }, + "errors": { + "properties": { + "count": { + "type": "long" + }, + "ratelimit": { + "type": "long" + }, + "timeout": { + "type": "long" + }, + "unauthorized": { + "type": "long" + } + } + }, + "valid": { + "properties": { + "count": { + "type": "long" + } + } + } + } + } + } }, - "validation": { + "traces": { "properties": { - "count": { - "type": "long" + "request": { + "properties": { + "count": { + "type": "long" + } + } }, - "errors": { - "type": "long" + "response": { + "properties": { + "count": { + "type": "long" + }, + "errors": { + "properties": { + "count": { + "type": "long" + }, + "ratelimit": { + "type": "long" + }, + "timeout": { + "type": "long" + }, + "unauthorized": { + "type": "long" + } + } + }, + "valid": { + "properties": { + "count": { + "type": "long" + } + } + } + } } } } } }, - "sourcemap": { + "http": { "properties": { - "counter": { - "type": "long" + "logs": { + "properties": { + "request": { + "properties": { + "count": { + "type": "long" + } + } + }, + "response": { + "properties": { + "count": { + "type": "long" + }, + "errors": { + "properties": { + "count": { + "type": "long" + }, + "ratelimit": { + "type": "long" + }, + "timeout": { + "type": "long" + }, + "unauthorized": { + "type": "long" + } + } + }, + "valid": { + "properties": { + "count": { + "type": "long" + } + } + } + } + } + } }, - "decoding": { + "metrics": { "properties": { - "count": { - "type": "long" + "consumer": { + "properties": { + "unsupported_dropped": { + "type": "long" + } + } }, - "errors": { - "type": "long" + "request": { + "properties": { + "count": { + "type": "long" + } + } + }, + "response": { + "properties": { + "count": { + "type": "long" + }, + "errors": { + "properties": { + "count": { + "type": "long" + }, + "ratelimit": { + "type": "long" + }, + "timeout": { + "type": "long" + }, + "unauthorized": { + "type": "long" + } + } + }, + "valid": { + "properties": { + "count": { + "type": "long" + } + } + } + } } } }, - "validation": { + "traces": { "properties": { - "count": { - "type": "long" + "request": { + "properties": { + "count": { + "type": "long" + } + } }, - "errors": { - "type": "long" + "response": { + "properties": { + "count": { + "type": "long" + }, + "errors": { + "properties": { + "count": { + "type": "long" + }, + "ratelimit": { + "type": "long" + }, + "timeout": { + "type": "long" + }, + "unauthorized": { + "type": "long" + } + } + }, + "valid": { + "properties": { + "count": { + "type": "long" + } + } + } + } } } } } + } + } + }, + "processor": { + "properties": { + "error": { + "properties": { + "transformations": { + "type": "long" + } + } + }, + "metric": { + "properties": { + "transformations": { + "type": "long" + } + } }, "span": { "properties": { @@ -392,60 +672,127 @@ } } }, - "transaction": { + "stream": { "properties": { - "decoding": { + "accepted": { + "type": "long" + }, + "errors": { "properties": { - "count": { + "invalid": { "type": "long" }, - "errors": { + "toolarge": { "type": "long" } } - }, - "frames": { - "type": "long" - }, - "spans": { + } + } + }, + "transaction": { + "properties": { + "transformations": { "type": "long" - }, - "stacktraces": { + } + } + } + } + }, + "root": { + "properties": { + "request": { + "properties": { + "count": { "type": "long" - }, - "transactions": { + } + } + }, + "response": { + "properties": { + "count": { "type": "long" }, - "transformations": { - "type": "long" + "errors": { + "properties": { + "closed": { + "type": "long" + }, + "count": { + "type": "long" + }, + "decode": { + "type": "long" + }, + "forbidden": { + "type": "long" + }, + "internal": { + "type": "long" + }, + "invalidquery": { + "type": "long" + }, + "method": { + "type": "long" + }, + "notfound": { + "type": "long" + }, + "queue": { + "type": "long" + }, + "ratelimit": { + "type": "long" + }, + "timeout": { + "type": "long" + }, + "toolarge": { + "type": "long" + }, + "unauthorized": { + "type": "long" + }, + "unavailable": { + "type": "long" + }, + "validate": { + "type": "long" + } + } }, - "validation": { + "valid": { "properties": { + "accepted": { + "type": "long" + }, "count": { "type": "long" }, - "errors": { + "notmodified": { + "type": "long" + }, + "ok": { "type": "long" } } } } + }, + "unset": { + "type": "long" + } + } + }, + "sampling": { + "properties": { + "transactions_dropped": { + "type": "long" } } }, "server": { "properties": { - "concurrent": { - "properties": { - "wait": { - "properties": { - "ms": { - "type": "long" - } - } - } - } - }, "request": { "properties": { "count": { @@ -478,21 +825,33 @@ "internal": { "type": "long" }, + "invalidquery": { + "type": "long" + }, "method": { "type": "long" }, + "notfound": { + "type": "long" + }, "queue": { "type": "long" }, "ratelimit": { "type": "long" }, + "timeout": { + "type": "long" + }, "toolarge": { "type": "long" }, "unauthorized": { "type": "long" }, + "unavailable": { + "type": "long" + }, "validate": { "type": "long" } @@ -506,12 +865,18 @@ "count": { "type": "long" }, + "notmodified": { + "type": "long" + }, "ok": { "type": "long" } } } } + }, + "unset": { + "type": "long" } } } @@ -918,6 +1283,37 @@ "type": "long" } } + }, + "output": { + "properties": { + "elasticsearch": { + "properties": { + "bulk_requests": { + "properties": { + "available": { + "type": "long" + }, + "completed": { + "type": "long" + } + } + }, + "indexers": { + "properties": { + "active": { + "type": "long" + }, + "created": { + "type": "long" + }, + "destroyed": { + "type": "long" + } + } + } + } + } + } } } }, @@ -1135,6 +1531,10 @@ "type": "alias", "path": "beat.stats.apm_server.acm.response.errors.ratelimit" }, + "timeout": { + "type": "alias", + "path": "beat.stats.apm_server.acm.response.errors.timeout" + }, "toolarge": { "type": "alias", "path": "beat.stats.apm_server.acm.response.errors.toolarge" @@ -1153,18 +1553,6 @@ } } }, - "request": { - "properties": { - "count": { - "type": "alias", - "path": "beat.stats.apm_server.acm.response.request.count" - } - } - }, - "unset": { - "type": "alias", - "path": "beat.stats.apm_server.acm.response.unset" - }, "valid": { "properties": { "accepted": { @@ -1179,9 +1567,485 @@ "type": "alias", "path": "beat.stats.apm_server.acm.response.valid.notmodified" }, - "ok": { - "type": "alias", - "path": "beat.stats.apm_server.acm.response.valid.ok" + "ok": { + "type": "alias", + "path": "beat.stats.apm_server.acm.response.valid.ok" + } + } + } + } + }, + "unset": { + "type": "alias", + "path": "beat.stats.apm_server.acm.unset" + } + } + }, + "agentcfg": { + "properties": { + "elasticsearch": { + "properties": { + "cache": { + "properties": { + "entries": { + "properties": { + "count": { + "type": "alias", + "path": "beat.stats.apm_server.agentcfg.elasticsearch.cache.entries.count" + } + } + }, + "refresh": { + "properties": { + "failures": { + "type": "alias", + "path": "beat.stats.apm_server.agentcfg.elasticsearch.cache.refresh.failures" + }, + "successes": { + "type": "alias", + "path": "beat.stats.apm_server.agentcfg.elasticsearch.cache.refresh.successes" + } + } + } + } + }, + "fetch": { + "properties": { + "es": { + "type": "alias", + "path": "beat.stats.apm_server.agentcfg.elasticsearch.fetch.es" + }, + "fallback": { + "type": "alias", + "path": "beat.stats.apm_server.agentcfg.elasticsearch.fetch.fallback" + }, + "invalid": { + "type": "alias", + "path": "beat.stats.apm_server.agentcfg.elasticsearch.fetch.invalid" + }, + "unavailable": { + "type": "alias", + "path": "beat.stats.apm_server.agentcfg.elasticsearch.fetch.unavailable" + } + } + } + } + } + } + }, + "jaeger": { + "properties": { + "grpc": { + "properties": { + "collect": { + "properties": { + "request": { + "properties": { + "count": { + "type": "alias", + "path": "beat.stats.apm_server.jaeger.grpc.collect.request.count" + } + } + }, + "response": { + "properties": { + "count": { + "type": "alias", + "path": "beat.stats.apm_server.jaeger.grpc.collect.response.count" + }, + "errors": { + "properties": { + "count": { + "type": "alias", + "path": "beat.stats.apm_server.jaeger.grpc.collect.response.errors.count" + }, + "ratelimit": { + "type": "alias", + "path": "beat.stats.apm_server.jaeger.grpc.collect.response.errors.ratelimit" + }, + "timeout": { + "type": "alias", + "path": "beat.stats.apm_server.jaeger.grpc.collect.response.errors.timeout" + }, + "unauthorized": { + "type": "alias", + "path": "beat.stats.apm_server.jaeger.grpc.collect.response.errors.unauthorized" + } + } + }, + "valid": { + "properties": { + "count": { + "type": "alias", + "path": "beat.stats.apm_server.jaeger.grpc.collect.response.valid.count" + } + } + } + } + } + } + }, + "sampling": { + "properties": { + "event": { + "properties": { + "received": { + "properties": { + "count": { + "type": "alias", + "path": "beat.stats.apm_server.jaeger.grpc.sampling.event.received.count" + } + } + } + } + }, + "request": { + "properties": { + "count": { + "type": "alias", + "path": "beat.stats.apm_server.jaeger.grpc.sampling.request.count" + } + } + }, + "response": { + "properties": { + "count": { + "type": "alias", + "path": "beat.stats.apm_server.jaeger.grpc.sampling.response.count" + }, + "errors": { + "properties": { + "count": { + "type": "alias", + "path": "beat.stats.apm_server.jaeger.grpc.sampling.response.errors.count" + } + } + }, + "valid": { + "properties": { + "count": { + "type": "alias", + "path": "beat.stats.apm_server.jaeger.grpc.sampling.response.valid.count" + } + } + } + } + } + } + } + } + } + } + }, + "otlp": { + "properties": { + "grpc": { + "properties": { + "logs": { + "properties": { + "request": { + "properties": { + "count": { + "type": "alias", + "path": "beat.stats.apm_server.otlp.grpc.logs.request.count" + } + } + }, + "response": { + "properties": { + "count": { + "type": "alias", + "path": "beat.stats.apm_server.otlp.grpc.logs.response.count" + }, + "errors": { + "properties": { + "count": { + "type": "alias", + "path": "beat.stats.apm_server.otlp.grpc.logs.response.errors.count" + }, + "ratelimit": { + "type": "alias", + "path": "beat.stats.apm_server.otlp.grpc.logs.response.errors.ratelimit" + }, + "timeout": { + "type": "alias", + "path": "beat.stats.apm_server.otlp.grpc.logs.response.errors.timeout" + }, + "unauthorized": { + "type": "alias", + "path": "beat.stats.apm_server.otlp.grpc.logs.response.errors.unauthorized" + } + } + }, + "valid": { + "properties": { + "count": { + "type": "alias", + "path": "beat.stats.apm_server.otlp.grpc.logs.response.valid.count" + } + } + } + } + } + } + }, + "metrics": { + "properties": { + "consumer": { + "properties": { + "unsupported_dropped": { + "type": "alias", + "path": "beat.stats.apm_server.otlp.grpc.metrics.consumer.unsupported_dropped" + } + } + }, + "request": { + "properties": { + "count": { + "type": "alias", + "path": "beat.stats.apm_server.otlp.grpc.metrics.request.count" + } + } + }, + "response": { + "properties": { + "count": { + "type": "alias", + "path": "beat.stats.apm_server.otlp.grpc.metrics.response.count" + }, + "errors": { + "properties": { + "count": { + "type": "alias", + "path": "beat.stats.apm_server.otlp.grpc.metrics.response.errors.count" + }, + "ratelimit": { + "type": "alias", + "path": "beat.stats.apm_server.otlp.grpc.metrics.response.errors.ratelimit" + }, + "timeout": { + "type": "alias", + "path": "beat.stats.apm_server.otlp.grpc.metrics.response.errors.timeout" + }, + "unauthorized": { + "type": "alias", + "path": "beat.stats.apm_server.otlp.grpc.metrics.response.errors.unauthorized" + } + } + }, + "valid": { + "properties": { + "count": { + "type": "alias", + "path": "beat.stats.apm_server.otlp.grpc.metrics.response.valid.count" + } + } + } + } + } + } + }, + "traces": { + "properties": { + "request": { + "properties": { + "count": { + "type": "alias", + "path": "beat.stats.apm_server.otlp.grpc.traces.request.count" + } + } + }, + "response": { + "properties": { + "count": { + "type": "alias", + "path": "beat.stats.apm_server.otlp.grpc.traces.response.count" + }, + "errors": { + "properties": { + "count": { + "type": "alias", + "path": "beat.stats.apm_server.otlp.grpc.traces.response.errors.count" + }, + "ratelimit": { + "type": "alias", + "path": "beat.stats.apm_server.otlp.grpc.traces.response.errors.ratelimit" + }, + "timeout": { + "type": "alias", + "path": "beat.stats.apm_server.otlp.grpc.traces.response.errors.timeout" + }, + "unauthorized": { + "type": "alias", + "path": "beat.stats.apm_server.otlp.grpc.traces.response.errors.unauthorized" + } + } + }, + "valid": { + "properties": { + "count": { + "type": "alias", + "path": "beat.stats.apm_server.otlp.grpc.traces.response.valid.count" + } + } + } + } + } + } + } + } + }, + "http": { + "properties": { + "logs": { + "properties": { + "request": { + "properties": { + "count": { + "type": "alias", + "path": "beat.stats.apm_server.otlp.http.logs.request.count" + } + } + }, + "response": { + "properties": { + "count": { + "type": "alias", + "path": "beat.stats.apm_server.otlp.http.logs.response.count" + }, + "errors": { + "properties": { + "count": { + "type": "alias", + "path": "beat.stats.apm_server.otlp.http.logs.response.errors.count" + }, + "ratelimit": { + "type": "alias", + "path": "beat.stats.apm_server.otlp.http.logs.response.errors.ratelimit" + }, + "timeout": { + "type": "alias", + "path": "beat.stats.apm_server.otlp.http.logs.response.errors.timeout" + }, + "unauthorized": { + "type": "alias", + "path": "beat.stats.apm_server.otlp.http.logs.response.errors.unauthorized" + } + } + }, + "valid": { + "properties": { + "count": { + "type": "alias", + "path": "beat.stats.apm_server.otlp.http.logs.response.valid.count" + } + } + } + } + } + } + }, + "metrics": { + "properties": { + "consumer": { + "properties": { + "unsupported_dropped": { + "type": "alias", + "path": "beat.stats.apm_server.otlp.http.metrics.consumer.unsupported_dropped" + } + } + }, + "request": { + "properties": { + "count": { + "type": "alias", + "path": "beat.stats.apm_server.otlp.http.metrics.request.count" + } + } + }, + "response": { + "properties": { + "count": { + "type": "alias", + "path": "beat.stats.apm_server.otlp.http.metrics.response.count" + }, + "errors": { + "properties": { + "count": { + "type": "alias", + "path": "beat.stats.apm_server.otlp.http.metrics.response.errors.count" + }, + "ratelimit": { + "type": "alias", + "path": "beat.stats.apm_server.otlp.http.metrics.response.errors.ratelimit" + }, + "timeout": { + "type": "alias", + "path": "beat.stats.apm_server.otlp.http.metrics.response.errors.timeout" + }, + "unauthorized": { + "type": "alias", + "path": "beat.stats.apm_server.otlp.http.metrics.response.errors.unauthorized" + } + } + }, + "valid": { + "properties": { + "count": { + "type": "alias", + "path": "beat.stats.apm_server.otlp.http.metrics.response.valid.count" + } + } + } + } + } + } + }, + "traces": { + "properties": { + "request": { + "properties": { + "count": { + "type": "alias", + "path": "beat.stats.apm_server.otlp.http.traces.request.count" + } + } + }, + "response": { + "properties": { + "count": { + "type": "alias", + "path": "beat.stats.apm_server.otlp.http.traces.response.count" + }, + "errors": { + "properties": { + "count": { + "type": "alias", + "path": "beat.stats.apm_server.otlp.http.traces.response.errors.count" + }, + "ratelimit": { + "type": "alias", + "path": "beat.stats.apm_server.otlp.http.traces.response.errors.ratelimit" + }, + "timeout": { + "type": "alias", + "path": "beat.stats.apm_server.otlp.http.traces.response.errors.timeout" + }, + "unauthorized": { + "type": "alias", + "path": "beat.stats.apm_server.otlp.http.traces.response.errors.unauthorized" + } + } + }, + "valid": { + "properties": { + "count": { + "type": "alias", + "path": "beat.stats.apm_server.otlp.http.traces.response.valid.count" + } + } + } + } } } } @@ -1189,248 +2053,180 @@ } } }, - "decoder": { + "processor": { "properties": { - "deflate": { + "error": { "properties": { - "content-length": { - "type": "alias", - "path": "beat.stats.apm_server.decoder.deflate.content-length" - }, - "count": { + "transformations": { "type": "alias", - "path": "beat.stats.apm_server.decoder.deflate.count" + "path": "beat.stats.apm_server.processor.error.transformations" } } }, - "gzip": { + "metric": { "properties": { - "content-length": { - "type": "alias", - "path": "beat.stats.apm_server.decoder.gzip.content-length" - }, - "count": { + "transformations": { "type": "alias", - "path": "beat.stats.apm_server.decoder.gzip.count" + "path": "beat.stats.apm_server.processor.metric.transformations" } } }, - "missing-content-length": { + "span": { "properties": { - "count": { + "transformations": { "type": "alias", - "path": "beat.stats.apm_server.decoder.missing-content-length.count" + "path": "beat.stats.apm_server.processor.span.transformations" } } }, - "reader": { + "stream": { "properties": { - "count": { + "accepted": { "type": "alias", - "path": "beat.stats.apm_server.decoder.reader.count" + "path": "beat.stats.apm_server.processor.stream.accepted" }, - "size": { - "type": "alias", - "path": "beat.stats.apm_server.decoder.reader.size" + "errors": { + "properties": { + "invalid": { + "type": "alias", + "path": "beat.stats.apm_server.processor.stream.errors.invalid" + }, + "toolarge": { + "type": "alias", + "path": "beat.stats.apm_server.processor.stream.errors.toolarge" + } + } } } }, - "uncompressed": { + "transaction": { "properties": { - "content-length": { - "type": "alias", - "path": "beat.stats.apm_server.decoder.uncompressed.content-length" - }, - "count": { + "transformations": { "type": "alias", - "path": "beat.stats.apm_server.decoder.uncompressed.count" + "path": "beat.stats.apm_server.processor.transaction.transformations" } } } } }, - "processor": { + "root": { "properties": { - "error": { + "request": { "properties": { - "decoding": { - "properties": { - "count": { - "type": "alias", - "path": "beat.stats.apm_server.processor.error.decoding.count" - }, - "errors": { - "type": "alias", - "path": "beat.stats.apm_server.processor.error.decoding.errors" - } - } - }, - "frames": { - "type": "alias", - "path": "beat.stats.apm_server.processor.error.frames" - }, - "spans": { - "type": "alias", - "path": "beat.stats.apm_server.processor.error.spans" - }, - "stacktraces": { + "count": { "type": "alias", - "path": "beat.stats.apm_server.processor.error.stacktraces" - }, - "transformations": { + "path": "beat.stats.apm_server.root.request.count" + } + } + }, + "response": { + "properties": { + "count": { "type": "alias", - "path": "beat.stats.apm_server.processor.error.transformations" + "path": "beat.stats.apm_server.root.response.count" }, - "validation": { + "errors": { "properties": { + "closed": { + "type": "alias", + "path": "beat.stats.apm_server.root.response.errors.closed" + }, "count": { "type": "alias", - "path": "beat.stats.apm_server.processor.error.validation.count" + "path": "beat.stats.apm_server.root.response.errors.count" }, - "errors": { + "decode": { "type": "alias", - "path": "beat.stats.apm_server.processor.error.validation.errors" - } - } - } - } - }, - "metric": { - "properties": { - "decoding": { - "properties": { - "count": { + "path": "beat.stats.apm_server.root.response.errors.decode" + }, + "forbidden": { "type": "alias", - "path": "beat.stats.apm_server.processor.metric.decoding.count" + "path": "beat.stats.apm_server.root.response.errors.forbidden" }, - "errors": { + "internal": { "type": "alias", - "path": "beat.stats.apm_server.processor.metric.decoding.errors" - } - } - }, - "transformations": { - "type": "alias", - "path": "beat.stats.apm_server.processor.metric.transformations" - }, - "validation": { - "properties": { - "count": { + "path": "beat.stats.apm_server.root.response.errors.internal" + }, + "invalidquery": { "type": "alias", - "path": "beat.stats.apm_server.processor.metric.validation.count" + "path": "beat.stats.apm_server.root.response.errors.invalidquery" }, - "errors": { + "method": { "type": "alias", - "path": "beat.stats.apm_server.processor.metric.validation.errors" - } - } - } - } - }, - "sourcemap": { - "properties": { - "counter": { - "type": "alias", - "path": "beat.stats.apm_server.processor.sourcemap.counter" - }, - "decoding": { - "properties": { - "count": { + "path": "beat.stats.apm_server.root.response.errors.method" + }, + "notfound": { "type": "alias", - "path": "beat.stats.apm_server.processor.sourcemap.decoding.count" + "path": "beat.stats.apm_server.root.response.errors.notfound" }, - "errors": { + "queue": { "type": "alias", - "path": "beat.stats.apm_server.processor.sourcemap.decoding.errors" - } - } - }, - "validation": { - "properties": { - "count": { + "path": "beat.stats.apm_server.root.response.errors.queue" + }, + "ratelimit": { "type": "alias", - "path": "beat.stats.apm_server.processor.sourcemap.validation.count" + "path": "beat.stats.apm_server.root.response.errors.ratelimit" }, - "errors": { + "timeout": { "type": "alias", - "path": "beat.stats.apm_server.processor.sourcemap.validation.errors" - } - } - } - } - }, - "span": { - "properties": { - "transformations": { - "type": "alias", - "path": "beat.stats.apm_server.processor.span.transformations" - } - } - }, - "transaction": { - "properties": { - "decoding": { - "properties": { - "count": { + "path": "beat.stats.apm_server.root.response.errors.timeout" + }, + "toolarge": { + "type": "alias", + "path": "beat.stats.apm_server.root.response.errors.toolarge" + }, + "unauthorized": { "type": "alias", - "path": "beat.stats.apm_server.processor.transaction.decoding.count" + "path": "beat.stats.apm_server.root.response.errors.unauthorized" }, - "errors": { + "unavailable": { + "type": "alias", + "path": "beat.stats.apm_server.root.response.errors.unavailable" + }, + "validate": { "type": "alias", - "path": "beat.stats.apm_server.processor.transaction.decoding.errors" + "path": "beat.stats.apm_server.root.response.errors.validate" } } }, - "frames": { - "type": "alias", - "path": "beat.stats.apm_server.processor.transaction.frames" - }, - "spans": { - "type": "alias", - "path": "beat.stats.apm_server.processor.transaction.spans" - }, - "stacktraces": { - "type": "alias", - "path": "beat.stats.apm_server.processor.transaction.stacktraces" - }, - "transactions": { - "type": "alias", - "path": "beat.stats.apm_server.processor.transaction.transactions" - }, - "transformations": { - "type": "alias", - "path": "beat.stats.apm_server.processor.transaction.transformations" - }, - "validation": { + "valid": { "properties": { + "accepted": { + "type": "alias", + "path": "beat.stats.apm_server.root.response.valid.accepted" + }, "count": { "type": "alias", - "path": "beat.stats.apm_server.processor.transaction.validation.count" + "path": "beat.stats.apm_server.root.response.valid.count" }, - "errors": { + "notmodified": { + "type": "alias", + "path": "beat.stats.apm_server.root.response.valid.notmodified" + }, + "ok": { "type": "alias", - "path": "beat.stats.apm_server.processor.transaction.validation.errors" + "path": "beat.stats.apm_server.root.response.valid.ok" } } } } + }, + "unset": { + "type": "alias", + "path": "beat.stats.apm_server.root.unset" + } + } + }, + "sampling": { + "properties": { + "transactions_dropped": { + "type": "alias", + "path": "beat.stats.apm_server.sampling.transactions_dropped" } } }, "server": { "properties": { - "concurrent": { - "properties": { - "wait": { - "properties": { - "ms": { - "type": "alias", - "path": "beat.stats.apm_server.server.concurrent.wait.ms" - } - } - } - } - }, "request": { "properties": { "count": { @@ -1471,10 +2267,18 @@ "type": "alias", "path": "beat.stats.apm_server.server.response.errors.internal" }, + "invalidquery": { + "type": "alias", + "path": "beat.stats.apm_server.server.response.errors.invalidquery" + }, "method": { "type": "alias", "path": "beat.stats.apm_server.server.response.errors.method" }, + "notfound": { + "type": "alias", + "path": "beat.stats.apm_server.server.response.errors.notfound" + }, "queue": { "type": "alias", "path": "beat.stats.apm_server.server.response.errors.queue" @@ -1483,6 +2287,10 @@ "type": "alias", "path": "beat.stats.apm_server.server.response.errors.ratelimit" }, + "timeout": { + "type": "alias", + "path": "beat.stats.apm_server.server.response.errors.timeout" + }, "toolarge": { "type": "alias", "path": "beat.stats.apm_server.server.response.errors.toolarge" @@ -1491,6 +2299,10 @@ "type": "alias", "path": "beat.stats.apm_server.server.response.errors.unauthorized" }, + "unavailable": { + "type": "alias", + "path": "beat.stats.apm_server.server.response.errors.unavailable" + }, "validate": { "type": "alias", "path": "beat.stats.apm_server.server.response.errors.validate" @@ -1507,6 +2319,10 @@ "type": "alias", "path": "beat.stats.apm_server.server.response.valid.count" }, + "notmodified": { + "type": "alias", + "path": "beat.stats.apm_server.server.response.valid.notmodified" + }, "ok": { "type": "alias", "path": "beat.stats.apm_server.server.response.valid.ok" @@ -1514,49 +2330,10 @@ } } } - } - } - }, - "sampling": { - "properties": { - "transactions_dropped": { - "type": "long" }, - "tail": { - "properties": { - "dynamic_service_groups": { - "type": "long" - }, - "storage": { - "properties": { - "lsm_size": { - "type": "long" - }, - "value_log_size": { - "type": "long" - } - } - }, - "events": { - "properties": { - "processed": { - "type": "long" - }, - "dropped": { - "type": "long" - }, - "stored": { - "type": "long" - }, - "sampled": { - "type": "long" - }, - "head_unsampled": { - "type": "long" - } - } - } - } + "unset": { + "type": "alias", + "path": "beat.stats.apm_server.server.unset" } } } @@ -1985,6 +2762,42 @@ } } } + }, + "output": { + "properties": { + "elasticsearch": { + "properties": { + "bulk_requests": { + "properties": { + "available": { + "type": "alias", + "path": "beat.stats.output.elasticsearch.bulk_requests.available" + }, + "completed": { + "type": "alias", + "path": "beat.stats.output.elasticsearch.bulk_requests.completed" + } + } + }, + "indexers": { + "properties": { + "active": { + "type": "alias", + "path": "beat.stats.output.elasticsearch.indexers.active" + }, + "created": { + "type": "alias", + "path": "beat.stats.output.elasticsearch.indexers.created" + }, + "destroyed": { + "type": "alias", + "path": "beat.stats.output.elasticsearch.indexers.destroyed" + } + } + } + } + } + } } } }, diff --git a/x-pack/plugin/core/template-resources/src/main/resources/monitoring-beats.json b/x-pack/plugin/core/template-resources/src/main/resources/monitoring-beats.json index 6dee05564cc10..d699317c29da3 100644 --- a/x-pack/plugin/core/template-resources/src/main/resources/monitoring-beats.json +++ b/x-pack/plugin/core/template-resources/src/main/resources/monitoring-beats.json @@ -346,17 +346,11 @@ "response": { "properties": { "count": { - "type": "long" + "type": "long" }, "errors": { "properties": { - "validate": { - "type": "long" - }, - "internal": { - "type": "long" - }, - "queue": { + "closed": { "type": "long" }, "count": { @@ -365,13 +359,13 @@ "decode": { "type": "long" }, - "toolarge": { + "forbidden": { "type": "long" }, - "unavailable": { + "internal": { "type": "long" }, - "forbidden": { + "invalidquery": { "type": "long" }, "method": { @@ -380,125 +374,454 @@ "notfound": { "type": "long" }, - "invalidquery": { + "queue": { "type": "long" }, "ratelimit": { "type": "long" }, - "closed": { + "timeout": { + "type": "long" + }, + "toolarge": { "type": "long" }, "unauthorized": { "type": "long" + }, + "unavailable": { + "type": "long" + }, + "validate": { + "type": "long" } } }, "valid": { "properties": { - "notmodified": { + "accepted": { "type": "long" }, "count": { "type": "long" }, - "ok": { + "notmodified": { "type": "long" }, - "accepted": { - "type": "long" - } - } - }, - "unset": { - "type": "long" - }, - "request": { - "properties": { - "count": { + "ok": { "type": "long" } } } } + }, + "unset": { + "type": "long" } } }, - "server": { + "agentcfg": { "properties": { - "request": { + "elasticsearch": { "properties": { - "count": { - "type": "long" - } - } - }, - "concurrent": { - "properties": { - "wait": { + "cache": { "properties": { - "ms": { - "type": "long" + "entries": { + "properties": { + "count": { + "type": "long" + } + } + }, + "refresh": { + "properties": { + "failures": { + "type": "long" + }, + "successes": { + "type": "long" + } + } } } - } - } - }, - "response": { - "properties": { - "count": { - "type": "long" }, - "errors": { + "fetch": { "properties": { - "count": { + "es": { "type": "long" }, - "toolarge": { + "fallback": { "type": "long" }, - "validate": { + "invalid": { "type": "long" }, - "ratelimit": { + "unavailable": { "type": "long" + } + } + } + } + } + } + }, + "jaeger": { + "properties": { + "grpc": { + "properties": { + "collect": { + "properties": { + "request": { + "properties": { + "count": { + "type": "long" + } + } }, - "queue": { - "type": "long" + "response": { + "properties": { + "count": { + "type": "long" + }, + "errors": { + "properties": { + "count": { + "type": "long" + }, + "ratelimit": { + "type": "long" + }, + "timeout": { + "type": "long" + }, + "unauthorized": { + "type": "long" + } + } + }, + "valid": { + "properties": { + "count": { + "type": "long" + } + } + } + } + } + } + }, + "sampling": { + "properties": { + "event": { + "properties": { + "received": { + "properties": { + "count": { + "type": "long" + } + } + } + } }, - "closed": { - "type": "long" + "request": { + "properties": { + "count": { + "type": "long" + } + } }, - "forbidden": { - "type": "long" + "response": { + "properties": { + "count": { + "type": "long" + }, + "errors": { + "properties": { + "count": { + "type": "long" + } + } + }, + "valid": { + "properties": { + "count": { + "type": "long" + } + } + } + } + } + } + } + } + } + } + }, + "otlp": { + "properties": { + "grpc": { + "properties": { + "logs": { + "properties": { + "request": { + "properties": { + "count": { + "type": "long" + } + } }, - "concurrency": { - "type": "long" + "response": { + "properties": { + "count": { + "type": "long" + }, + "errors": { + "properties": { + "count": { + "type": "long" + }, + "ratelimit": { + "type": "long" + }, + "timeout": { + "type": "long" + }, + "unauthorized": { + "type": "long" + } + } + }, + "valid": { + "properties": { + "count": { + "type": "long" + } + } + } + } + } + } + }, + "metrics": { + "properties": { + "consumer": { + "properties": { + "unsupported_dropped": { + "type": "long" + } + } }, - "unauthorized": { - "type": "long" + "request": { + "properties": { + "count": { + "type": "long" + } + } }, - "internal": { - "type": "long" + "response": { + "properties": { + "count": { + "type": "long" + }, + "errors": { + "properties": { + "count": { + "type": "long" + }, + "ratelimit": { + "type": "long" + }, + "timeout": { + "type": "long" + }, + "unauthorized": { + "type": "long" + } + } + }, + "valid": { + "properties": { + "count": { + "type": "long" + } + } + } + } + } + } + }, + "traces": { + "properties": { + "request": { + "properties": { + "count": { + "type": "long" + } + } }, - "decode": { - "type": "long" + "response": { + "properties": { + "count": { + "type": "long" + }, + "errors": { + "properties": { + "count": { + "type": "long" + }, + "ratelimit": { + "type": "long" + }, + "timeout": { + "type": "long" + }, + "unauthorized": { + "type": "long" + } + } + }, + "valid": { + "properties": { + "count": { + "type": "long" + } + } + } + } + } + } + } + } + }, + "http": { + "properties": { + "logs": { + "properties": { + "request": { + "properties": { + "count": { + "type": "long" + } + } }, - "method": { - "type": "long" + "response": { + "properties": { + "count": { + "type": "long" + }, + "errors": { + "properties": { + "count": { + "type": "long" + }, + "ratelimit": { + "type": "long" + }, + "timeout": { + "type": "long" + }, + "unauthorized": { + "type": "long" + } + } + }, + "valid": { + "properties": { + "count": { + "type": "long" + } + } + } + } } } }, - "valid": { + "metrics": { "properties": { - "ok": { - "type": "long" + "consumer": { + "properties": { + "unsupported_dropped": { + "type": "long" + } + } }, - "accepted": { - "type": "long" + "request": { + "properties": { + "count": { + "type": "long" + } + } }, - "count": { - "type": "long" + "response": { + "properties": { + "count": { + "type": "long" + }, + "errors": { + "properties": { + "count": { + "type": "long" + }, + "ratelimit": { + "type": "long" + }, + "timeout": { + "type": "long" + }, + "unauthorized": { + "type": "long" + } + } + }, + "valid": { + "properties": { + "count": { + "type": "long" + } + } + } + } + } + } + }, + "traces": { + "properties": { + "request": { + "properties": { + "count": { + "type": "long" + } + } + }, + "response": { + "properties": { + "count": { + "type": "long" + }, + "errors": { + "properties": { + "count": { + "type": "long" + }, + "ratelimit": { + "type": "long" + }, + "timeout": { + "type": "long" + }, + "unauthorized": { + "type": "long" + } + } + }, + "valid": { + "properties": { + "count": { + "type": "long" + } + } + } + } } } } @@ -506,195 +829,138 @@ } } }, - "decoder": { + "processor": { "properties": { - "deflate": { + "error": { "properties": { - "content-length": { - "type": "long" - }, - "count": { + "transformations": { "type": "long" } } }, - "gzip": { + "metric": { "properties": { - "content-length": { - "type": "long" - }, - "count": { + "transformations": { "type": "long" } } }, - "uncompressed": { + "span": { "properties": { - "content-length": { - "type": "long" - }, - "count": { + "transformations": { "type": "long" } } }, - "reader": { + "stream": { "properties": { - "size": { + "accepted": { "type": "long" }, - "count": { - "type": "long" + "errors": { + "properties": { + "invalid": { + "type": "long" + }, + "toolarge": { + "type": "long" + } + } } } }, - "missing-content-length": { + "transaction": { "properties": { - "count": { + "transformations": { "type": "long" } } } } - }, - "processor": { + "root": { "properties": { - "metric": { + "request": { "properties": { - "decoding": { - "properties": { - "errors": { - "type": "long" - }, - "count": { - "type": "long" - } - } - }, - "validation": { - "properties": { - "errors": { - "type": "long" - }, - "count": { - "type": "long" - } - } - }, - "transformations": { + "count": { "type": "long" } } }, - "sourcemap": { + "response": { "properties": { - "counter": { + "count": { "type": "long" }, - "decoding": { + "errors": { "properties": { - "errors": { + "closed": { "type": "long" }, "count": { "type": "long" - } - } - }, - "validation": { - "properties": { - "errors": { + }, + "decode": { "type": "long" }, - "count": { + "forbidden": { "type": "long" - } - } - } - } - }, - "transaction": { - "properties": { - "decoding": { - "properties": { - "errors": { + }, + "internal": { "type": "long" }, - "count": { + "invalidquery": { "type": "long" - } - } - }, - "validation": { - "properties": { - "errors": { + }, + "method": { "type": "long" }, - "count": { + "notfound": { "type": "long" - } - } - }, - "transformations": { - "type": "long" - }, - "transactions": { - "type": "long" - }, - "spans": { - "type": "long" - }, - "stacktraces": { - "type": "long" - }, - "frames": { - "type": "long" - } - } - }, - "error": { - "properties": { - "decoding": { - "properties": { - "errors": { + }, + "queue": { "type": "long" }, - "count": { + "ratelimit": { + "type": "long" + }, + "timeout": { + "type": "long" + }, + "toolarge": { + "type": "long" + }, + "unauthorized": { + "type": "long" + }, + "unavailable": { + "type": "long" + }, + "validate": { "type": "long" } } }, - "validation": { + "valid": { "properties": { - "errors": { + "accepted": { "type": "long" }, "count": { "type": "long" + }, + "notmodified": { + "type": "long" + }, + "ok": { + "type": "long" } } - }, - "transformations": { - "type": "long" - }, - "errors": { - "type": "long" - }, - "stacktraces": { - "type": "long" - }, - "frames": { - "type": "long" } } }, - "span": { - "properties": { - "transformations": { - "type": "long" - } - } + "unset": { + "type": "long" } } }, @@ -702,42 +968,95 @@ "properties": { "transactions_dropped": { "type": "long" + } + } + }, + "server": { + "properties": { + "request": { + "properties": { + "count": { + "type": "long" + } + } }, - "tail": { + "response": { "properties": { - "dynamic_service_groups": { + "count": { "type": "long" }, - "storage": { + "errors": { "properties": { - "lsm_size": { + "closed": { "type": "long" }, - "value_log_size": { + "concurrency": { + "type": "long" + }, + "count": { + "type": "long" + }, + "decode": { + "type": "long" + }, + "forbidden": { + "type": "long" + }, + "internal": { + "type": "long" + }, + "invalidquery": { + "type": "long" + }, + "method": { + "type": "long" + }, + "notfound": { + "type": "long" + }, + "queue": { + "type": "long" + }, + "ratelimit": { + "type": "long" + }, + "timeout": { + "type": "long" + }, + "toolarge": { + "type": "long" + }, + "unauthorized": { + "type": "long" + }, + "unavailable": { + "type": "long" + }, + "validate": { "type": "long" } } }, - "events": { + "valid": { "properties": { - "processed": { - "type": "long" - }, - "dropped": { + "accepted": { "type": "long" }, - "stored": { + "count": { "type": "long" }, - "sampled": { + "notmodified": { "type": "long" }, - "head_unsampled": { + "ok": { "type": "long" } } } } + }, + "unset": { + "type": "long" } } } @@ -893,6 +1212,37 @@ } } } + }, + "output": { + "properties": { + "elasticsearch": { + "properties": { + "bulk_requests": { + "properties": { + "available": { + "type": "long" + }, + "completed": { + "type": "long" + } + } + }, + "indexers": { + "properties": { + "active": { + "type": "long" + }, + "created": { + "type": "long" + }, + "destroyed": { + "type": "long" + } + } + } + } + } + } } } }, diff --git a/x-pack/plugin/ent-search/qa/rest/src/yamlRestTest/resources/rest-api-spec/test/entsearch/connector/10_connector_put.yml b/x-pack/plugin/ent-search/qa/rest/src/yamlRestTest/resources/rest-api-spec/test/entsearch/connector/10_connector_put.yml index 5cfb016e1b6df..b0f850d09f76d 100644 --- a/x-pack/plugin/ent-search/qa/rest/src/yamlRestTest/resources/rest-api-spec/test/entsearch/connector/10_connector_put.yml +++ b/x-pack/plugin/ent-search/qa/rest/src/yamlRestTest/resources/rest-api-spec/test/entsearch/connector/10_connector_put.yml @@ -76,6 +76,42 @@ setup: - match: { custom_scheduling: {} } - match: { filtering.0.domain: DEFAULT } + +--- +'Create Connector - Check for missing keys': + - do: + connector.put: + connector_id: test-connector + body: + index_name: search-test + name: my-connector + language: pl + is_native: false + service_type: super-connector + + - match: { result: 'created' } + + - do: + connector.get: + connector_id: test-connector + + - match: { id: test-connector } + - match: { index_name: search-test } + - match: { name: my-connector } + - match: { language: pl } + - match: { is_native: false } + - match: { service_type: super-connector } + + # check keys that are not populated upon connector creation + - is_false: api_key_id + - is_false: api_key_secret_id + - is_false: description + - is_false: error + - is_false: features + - is_false: last_seen + - is_false: sync_cursor + + --- 'Create Connector - Resource already exists': - do: diff --git a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/Connector.java b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/Connector.java index a9c488b024d49..46275bb623b7a 100644 --- a/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/Connector.java +++ b/x-pack/plugin/ent-search/src/main/java/org/elasticsearch/xpack/application/connector/Connector.java @@ -377,25 +377,61 @@ public void toInnerXContent(XContentBuilder builder, Params params) throws IOExc if (connectorId != null) { builder.field(ID_FIELD.getPreferredName(), connectorId); } - builder.field(API_KEY_ID_FIELD.getPreferredName(), apiKeyId); - builder.field(API_KEY_SECRET_ID_FIELD.getPreferredName(), apiKeySecretId); - builder.xContentValuesMap(CONFIGURATION_FIELD.getPreferredName(), configuration); - builder.xContentValuesMap(CUSTOM_SCHEDULING_FIELD.getPreferredName(), customScheduling); - builder.field(DESCRIPTION_FIELD.getPreferredName(), description); - builder.field(ERROR_FIELD.getPreferredName(), error); - builder.field(FEATURES_FIELD.getPreferredName(), features); - builder.xContentList(FILTERING_FIELD.getPreferredName(), filtering); - builder.field(INDEX_NAME_FIELD.getPreferredName(), indexName); + if (apiKeyId != null) { + builder.field(API_KEY_ID_FIELD.getPreferredName(), apiKeyId); + } + if (apiKeySecretId != null) { + builder.field(API_KEY_SECRET_ID_FIELD.getPreferredName(), apiKeySecretId); + } + if (configuration != null) { + builder.xContentValuesMap(CONFIGURATION_FIELD.getPreferredName(), configuration); + } + if (customScheduling != null) { + builder.xContentValuesMap(CUSTOM_SCHEDULING_FIELD.getPreferredName(), customScheduling); + } + if (description != null) { + builder.field(DESCRIPTION_FIELD.getPreferredName(), description); + } + if (error != null) { + builder.field(ERROR_FIELD.getPreferredName(), error); + } + if (features != null) { + builder.field(FEATURES_FIELD.getPreferredName(), features); + } + if (filtering != null) { + builder.xContentList(FILTERING_FIELD.getPreferredName(), filtering); + } + if (indexName != null) { + builder.field(INDEX_NAME_FIELD.getPreferredName(), indexName); + } builder.field(IS_NATIVE_FIELD.getPreferredName(), isNative); - builder.field(LANGUAGE_FIELD.getPreferredName(), language); - builder.field(LAST_SEEN_FIELD.getPreferredName(), lastSeen); - syncInfo.toXContent(builder, params); - builder.field(NAME_FIELD.getPreferredName(), name); - builder.field(PIPELINE_FIELD.getPreferredName(), pipeline); - builder.field(SCHEDULING_FIELD.getPreferredName(), scheduling); - builder.field(SERVICE_TYPE_FIELD.getPreferredName(), serviceType); - builder.field(SYNC_CURSOR_FIELD.getPreferredName(), syncCursor); - builder.field(STATUS_FIELD.getPreferredName(), status.toString()); + if (language != null) { + builder.field(LANGUAGE_FIELD.getPreferredName(), language); + } + if (lastSeen != null) { + builder.field(LAST_SEEN_FIELD.getPreferredName(), lastSeen); + } + if (syncInfo != null) { + syncInfo.toXContent(builder, params); + } + if (name != null) { + builder.field(NAME_FIELD.getPreferredName(), name); + } + if (pipeline != null) { + builder.field(PIPELINE_FIELD.getPreferredName(), pipeline); + } + if (scheduling != null) { + builder.field(SCHEDULING_FIELD.getPreferredName(), scheduling); + } + if (serviceType != null) { + builder.field(SERVICE_TYPE_FIELD.getPreferredName(), serviceType); + } + if (syncCursor != null) { + builder.field(SYNC_CURSOR_FIELD.getPreferredName(), syncCursor); + } + if (status != null) { + builder.field(STATUS_FIELD.getPreferredName(), status.toString()); + } builder.field(SYNC_NOW_FIELD.getPreferredName(), syncNow); } diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/analyzer/VerifierChecks.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/analyzer/VerifierChecks.java deleted file mode 100644 index 36ce187d8600c..0000000000000 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/analyzer/VerifierChecks.java +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.esql.core.analyzer; - -import org.elasticsearch.xpack.esql.core.common.Failure; -import org.elasticsearch.xpack.esql.core.expression.Expression; -import org.elasticsearch.xpack.esql.core.plan.logical.Filter; -import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan; - -import java.util.Set; - -import static org.elasticsearch.xpack.esql.core.common.Failure.fail; -import static org.elasticsearch.xpack.esql.core.type.DataType.BOOLEAN; - -public final class VerifierChecks { - - public static void checkFilterConditionType(LogicalPlan p, Set localFailures) { - if (p instanceof Filter) { - Expression condition = ((Filter) p).condition(); - if (condition.dataType() != BOOLEAN) { - localFailures.add(fail(condition, "Condition expression needs to be boolean, found [{}]", condition.dataType())); - } - } - } - -} diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/async/QlStatusResponse.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/async/QlStatusResponse.java deleted file mode 100644 index 8c28f08e8d882..0000000000000 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/async/QlStatusResponse.java +++ /dev/null @@ -1,200 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ -package org.elasticsearch.xpack.esql.core.async; - -import org.elasticsearch.ExceptionsHelper; -import org.elasticsearch.action.ActionResponse; -import org.elasticsearch.common.io.stream.StreamInput; -import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.common.io.stream.Writeable; -import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.xcontent.ToXContentObject; -import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xpack.core.async.StoredAsyncResponse; -import org.elasticsearch.xpack.core.search.action.SearchStatusResponse; - -import java.io.IOException; -import java.util.Objects; - -/** - * A response for *QL search status request - */ -public class QlStatusResponse extends ActionResponse implements SearchStatusResponse, ToXContentObject { - private final String id; - private final boolean isRunning; - private final boolean isPartial; - private final Long startTimeMillis; - private final long expirationTimeMillis; - private final RestStatus completionStatus; - - public interface AsyncStatus { - String id(); - - boolean isRunning(); - - boolean isPartial(); - } - - public QlStatusResponse( - String id, - boolean isRunning, - boolean isPartial, - Long startTimeMillis, - long expirationTimeMillis, - RestStatus completionStatus - ) { - this.id = id; - this.isRunning = isRunning; - this.isPartial = isPartial; - this.startTimeMillis = startTimeMillis; - this.expirationTimeMillis = expirationTimeMillis; - this.completionStatus = completionStatus; - } - - /** - * Get status from the stored Ql search response - * @param storedResponse - a response from a stored search - * @param expirationTimeMillis – expiration time in milliseconds - * @param id – encoded async search id - * @return a status response - */ - public static QlStatusResponse getStatusFromStoredSearch( - StoredAsyncResponse storedResponse, - long expirationTimeMillis, - String id - ) { - S searchResponse = storedResponse.getResponse(); - if (searchResponse != null) { - assert searchResponse.isRunning() == false : "Stored Ql search response must have a completed status!"; - return new QlStatusResponse( - searchResponse.id(), - false, - searchResponse.isPartial(), - null, // we don't store in the index the start time for completed response - expirationTimeMillis, - RestStatus.OK - ); - } else { - Exception exc = storedResponse.getException(); - assert exc != null : "Stored Ql response must either have a search response or an exception!"; - return new QlStatusResponse( - id, - false, - false, - null, // we don't store in the index the start time for completed response - expirationTimeMillis, - ExceptionsHelper.status(exc) - ); - } - } - - public QlStatusResponse(StreamInput in) throws IOException { - this.id = in.readString(); - this.isRunning = in.readBoolean(); - this.isPartial = in.readBoolean(); - this.startTimeMillis = in.readOptionalLong(); - this.expirationTimeMillis = in.readLong(); - this.completionStatus = (this.isRunning == false) ? RestStatus.readFrom(in) : null; - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeString(id); - out.writeBoolean(isRunning); - out.writeBoolean(isPartial); - out.writeOptionalLong(startTimeMillis); - out.writeLong(expirationTimeMillis); - if (isRunning == false) { - RestStatus.writeTo(out, completionStatus); - } - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - { - builder.field("id", id); - builder.field("is_running", isRunning); - builder.field("is_partial", isPartial); - if (startTimeMillis != null) { // start time is available only for a running eql search - builder.timeField("start_time_in_millis", "start_time", startTimeMillis); - } - builder.timeField("expiration_time_in_millis", "expiration_time", expirationTimeMillis); - if (isRunning == false) { // completion status is available only for a completed eql search - builder.field("completion_status", completionStatus.getStatus()); - } - } - builder.endObject(); - return builder; - } - - @Override - public boolean equals(Object obj) { - if (this == obj) return true; - if (obj == null || getClass() != obj.getClass()) return false; - QlStatusResponse other = (QlStatusResponse) obj; - return id.equals(other.id) - && isRunning == other.isRunning - && isPartial == other.isPartial - && Objects.equals(startTimeMillis, other.startTimeMillis) - && expirationTimeMillis == other.expirationTimeMillis - && Objects.equals(completionStatus, other.completionStatus); - } - - @Override - public int hashCode() { - return Objects.hash(id, isRunning, isPartial, startTimeMillis, expirationTimeMillis, completionStatus); - } - - /** - * Returns the id of the eql search status request. - */ - public String getId() { - return id; - } - - /** - * Returns {@code true} if the eql search is still running in the cluster, - * or {@code false} if the search has been completed. - */ - public boolean isRunning() { - return isRunning; - } - - /** - * Returns {@code true} if the eql search results are partial. - * This could be either because eql search hasn't finished yet, - * or if it finished and some shards have failed or timed out. - */ - public boolean isPartial() { - return isPartial; - } - - /** - * Returns a timestamp when the eql search task started, in milliseconds since epoch. - * For a completed eql search returns {@code null}, as we don't store start time for completed searches. - */ - public Long getStartTime() { - return startTimeMillis; - } - - /** - * Returns a timestamp when the eql search will be expired, in milliseconds since epoch. - */ - @Override - public long getExpirationTime() { - return expirationTimeMillis; - } - - /** - * For a completed eql search returns the completion status. - * For a still running eql search returns {@code null}. - */ - public RestStatus getCompletionStatus() { - return completionStatus; - } -} diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/parser/CaseChangingCharStream.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/parser/CaseChangingCharStream.java index f38daa472ddff..6248004d73dac 100644 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/parser/CaseChangingCharStream.java +++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/parser/CaseChangingCharStream.java @@ -18,27 +18,24 @@ /** * This class supports case-insensitive lexing by wrapping an existing - * {@link CharStream} and forcing the lexer to see either upper or - * lowercase characters. Grammar literals should then be either upper or - * lower case such as 'BEGIN' or 'begin'. The text of the character - * stream is unaffected. Example: input 'BeGiN' would match lexer rule - * 'BEGIN' if constructor parameter upper=true but getText() would return - * 'BeGiN'. + * {@link CharStream} and forcing the lexer to see lowercase characters + * Grammar literals should then be lower case such as {@code begin}. + * The text of the character stream is unaffected. + *

Example: input {@code BeGiN} would match lexer rule {@code begin} + * but {@link CharStream#getText} will return {@code BeGiN}. + *

*/ public class CaseChangingCharStream implements CharStream { private final CharStream stream; - private final boolean upper; /** * Constructs a new CaseChangingCharStream wrapping the given {@link CharStream} forcing * all characters to upper case or lower case. * @param stream The stream to wrap. - * @param upper If true force each symbol to upper case, otherwise force to lower. */ - public CaseChangingCharStream(CharStream stream, boolean upper) { + public CaseChangingCharStream(CharStream stream) { this.stream = stream; - this.upper = upper; } @Override @@ -57,7 +54,7 @@ public int LA(int i) { if (c <= 0) { return c; } - return upper ? Character.toUpperCase(c) : Character.toLowerCase(c); + return Character.toLowerCase(c); } @Override diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/plugin/AbstractTransportQlAsyncGetStatusAction.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/plugin/AbstractTransportQlAsyncGetStatusAction.java deleted file mode 100644 index cb21272758d1b..0000000000000 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/plugin/AbstractTransportQlAsyncGetStatusAction.java +++ /dev/null @@ -1,111 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ -package org.elasticsearch.xpack.esql.core.plugin; - -import org.elasticsearch.action.ActionListener; -import org.elasticsearch.action.ActionListenerResponseHandler; -import org.elasticsearch.action.ActionResponse; -import org.elasticsearch.action.support.ActionFilters; -import org.elasticsearch.action.support.HandledTransportAction; -import org.elasticsearch.client.internal.Client; -import org.elasticsearch.cluster.node.DiscoveryNode; -import org.elasticsearch.cluster.service.ClusterService; -import org.elasticsearch.common.io.stream.NamedWriteableRegistry; -import org.elasticsearch.common.io.stream.Writeable; -import org.elasticsearch.common.util.BigArrays; -import org.elasticsearch.common.util.concurrent.EsExecutors; -import org.elasticsearch.tasks.Task; -import org.elasticsearch.threadpool.ThreadPool; -import org.elasticsearch.transport.TransportService; -import org.elasticsearch.xpack.core.XPackPlugin; -import org.elasticsearch.xpack.core.async.AsyncExecutionId; -import org.elasticsearch.xpack.core.async.AsyncTaskIndexService; -import org.elasticsearch.xpack.core.async.GetAsyncStatusRequest; -import org.elasticsearch.xpack.core.async.StoredAsyncResponse; -import org.elasticsearch.xpack.core.async.StoredAsyncTask; -import org.elasticsearch.xpack.esql.core.async.QlStatusResponse; - -import java.util.Objects; - -import static org.elasticsearch.xpack.core.ClientHelper.ASYNC_SEARCH_ORIGIN; - -public abstract class AbstractTransportQlAsyncGetStatusAction< - Response extends ActionResponse & QlStatusResponse.AsyncStatus, - AsyncTask extends StoredAsyncTask> extends HandledTransportAction { - private final String actionName; - private final TransportService transportService; - private final ClusterService clusterService; - private final Class asyncTaskClass; - private final AsyncTaskIndexService> store; - - @SuppressWarnings("this-escape") - public AbstractTransportQlAsyncGetStatusAction( - String actionName, - TransportService transportService, - ActionFilters actionFilters, - ClusterService clusterService, - NamedWriteableRegistry registry, - Client client, - ThreadPool threadPool, - BigArrays bigArrays, - Class asyncTaskClass - ) { - super(actionName, transportService, actionFilters, GetAsyncStatusRequest::new, EsExecutors.DIRECT_EXECUTOR_SERVICE); - this.actionName = actionName; - this.transportService = transportService; - this.clusterService = clusterService; - this.asyncTaskClass = asyncTaskClass; - Writeable.Reader> reader = in -> new StoredAsyncResponse<>(responseReader(), in); - this.store = new AsyncTaskIndexService<>( - XPackPlugin.ASYNC_RESULTS_INDEX, - clusterService, - threadPool.getThreadContext(), - client, - ASYNC_SEARCH_ORIGIN, - reader, - registry, - bigArrays - ); - } - - @Override - protected void doExecute(Task task, GetAsyncStatusRequest request, ActionListener listener) { - AsyncExecutionId searchId = AsyncExecutionId.decode(request.getId()); - DiscoveryNode node = clusterService.state().nodes().get(searchId.getTaskId().getNodeId()); - DiscoveryNode localNode = clusterService.state().getNodes().getLocalNode(); - if (node == null || Objects.equals(node, localNode)) { - store.retrieveStatus( - request, - taskManager, - asyncTaskClass, - AbstractTransportQlAsyncGetStatusAction::getStatusResponse, - QlStatusResponse::getStatusFromStoredSearch, - listener - ); - } else { - transportService.sendRequest( - node, - actionName, - request, - new ActionListenerResponseHandler<>(listener, QlStatusResponse::new, EsExecutors.DIRECT_EXECUTOR_SERVICE) - ); - } - } - - private static QlStatusResponse getStatusResponse(StoredAsyncTask asyncTask) { - return new QlStatusResponse( - asyncTask.getExecutionId().getEncoded(), - true, - true, - asyncTask.getStartTime(), - asyncTask.getExpirationTimeMillis(), - null - ); - } - - protected abstract Writeable.Reader responseReader(); -} diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/plugin/TransportActionUtils.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/plugin/TransportActionUtils.java deleted file mode 100644 index 4d6fc9d1d18d5..0000000000000 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/plugin/TransportActionUtils.java +++ /dev/null @@ -1,81 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ -package org.elasticsearch.xpack.esql.core.plugin; - -import org.apache.logging.log4j.Logger; -import org.elasticsearch.action.search.SearchPhaseExecutionException; -import org.elasticsearch.action.search.VersionMismatchException; -import org.elasticsearch.cluster.node.DiscoveryNode; -import org.elasticsearch.cluster.service.ClusterService; -import org.elasticsearch.xpack.esql.core.util.Holder; - -import java.util.function.Consumer; - -public final class TransportActionUtils { - - /** - * Execute a *QL request and re-try it in case the first request failed with a {@code VersionMismatchException} - * - * @param clusterService The cluster service instance - * @param onFailure On-failure handler in case the request doesn't fail with a {@code VersionMismatchException} - * @param queryRunner *QL query execution code, typically a Plan Executor running the query - * @param retryRequest Re-trial logic - * @param log Log4j logger - */ - public static void executeRequestWithRetryAttempt( - ClusterService clusterService, - Consumer onFailure, - Consumer> queryRunner, - Consumer retryRequest, - Logger log - ) { - - Holder retrySecondTime = new Holder(false); - queryRunner.accept(e -> { - // the search request likely ran on nodes with different versions of ES - // we will retry on a node with an older version that should generate a backwards compatible _search request - if (e instanceof SearchPhaseExecutionException - && ((SearchPhaseExecutionException) e).getCause() instanceof VersionMismatchException) { - if (log.isDebugEnabled()) { - log.debug("Caught exception type [{}] with cause [{}].", e.getClass().getName(), e.getCause()); - } - DiscoveryNode localNode = clusterService.state().nodes().getLocalNode(); - DiscoveryNode candidateNode = null; - for (DiscoveryNode node : clusterService.state().nodes()) { - // find the first node that's older than the current node - if (node != localNode && node.getVersion().before(localNode.getVersion())) { - candidateNode = node; - break; - } - } - if (candidateNode != null) { - if (log.isDebugEnabled()) { - log.debug( - "Candidate node to resend the request to: address [{}], id [{}], name [{}], version [{}]", - candidateNode.getAddress(), - candidateNode.getId(), - candidateNode.getName(), - candidateNode.getVersion() - ); - } - // re-send the request to the older node - retryRequest.accept(candidateNode); - } else { - retrySecondTime.set(true); - } - } else { - onFailure.accept(e); - } - }); - if (retrySecondTime.get()) { - if (log.isDebugEnabled()) { - log.debug("No candidate node found, likely all were upgraded in the meantime. Re-trying the original request."); - } - queryRunner.accept(onFailure); - } - } -} diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/util/ActionListeners.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/util/ActionListeners.java deleted file mode 100644 index 025f9c2b6fd7a..0000000000000 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/util/ActionListeners.java +++ /dev/null @@ -1,26 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.esql.core.util; - -import org.elasticsearch.action.ActionListener; -import org.elasticsearch.core.CheckedConsumer; -import org.elasticsearch.core.CheckedFunction; - -import java.util.function.Consumer; - -public class ActionListeners { - - private ActionListeners() {} - - /** - * Combination of {@link ActionListener#wrap(CheckedConsumer, Consumer)} and {@link ActionListener#map} - */ - public static ActionListener map(ActionListener delegate, CheckedFunction fn) { - return delegate.delegateFailureAndWrap((l, r) -> l.onResponse(fn.apply(r))); - } -} diff --git a/x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/action/QlStatusResponseTests.java b/x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/action/QlStatusResponseTests.java deleted file mode 100644 index e38755b703913..0000000000000 --- a/x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/action/QlStatusResponseTests.java +++ /dev/null @@ -1,83 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ -package org.elasticsearch.xpack.esql.core.action; - -import org.elasticsearch.common.Strings; -import org.elasticsearch.common.io.stream.Writeable; -import org.elasticsearch.common.xcontent.XContentHelper; -import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.test.AbstractWireSerializingTestCase; -import org.elasticsearch.xcontent.ToXContent; -import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xcontent.XContentType; -import org.elasticsearch.xpack.esql.core.async.QlStatusResponse; - -import java.io.IOException; -import java.util.Date; - -import static org.elasticsearch.xpack.core.async.GetAsyncResultRequestTests.randomSearchId; - -public class QlStatusResponseTests extends AbstractWireSerializingTestCase { - - @Override - protected QlStatusResponse createTestInstance() { - String id = randomSearchId(); - boolean isRunning = randomBoolean(); - boolean isPartial = isRunning ? randomBoolean() : false; - long randomDate = (new Date(randomLongBetween(0, 3000000000000L))).getTime(); - Long startTimeMillis = randomBoolean() ? null : randomDate; - long expirationTimeMillis = startTimeMillis == null ? randomDate : startTimeMillis + 3600000L; - RestStatus completionStatus = isRunning ? null : randomBoolean() ? RestStatus.OK : RestStatus.SERVICE_UNAVAILABLE; - return new QlStatusResponse(id, isRunning, isPartial, startTimeMillis, expirationTimeMillis, completionStatus); - } - - @Override - protected Writeable.Reader instanceReader() { - return QlStatusResponse::new; - } - - @Override - protected QlStatusResponse mutateInstance(QlStatusResponse instance) { - // return a response with the opposite running status - boolean isRunning = instance.isRunning() == false; - boolean isPartial = isRunning ? randomBoolean() : false; - RestStatus completionStatus = isRunning ? null : randomBoolean() ? RestStatus.OK : RestStatus.SERVICE_UNAVAILABLE; - return new QlStatusResponse( - instance.getId(), - isRunning, - isPartial, - instance.getStartTime(), - instance.getExpirationTime(), - completionStatus - ); - } - - public void testToXContent() throws IOException { - QlStatusResponse response = createTestInstance(); - try (XContentBuilder builder = XContentBuilder.builder(XContentType.JSON.xContent())) { - Object[] args = new Object[] { - response.getId(), - response.isRunning(), - response.isPartial(), - response.getStartTime() != null ? "\"start_time_in_millis\" : " + response.getStartTime() + "," : "", - response.getExpirationTime(), - response.getCompletionStatus() != null ? ", \"completion_status\" : " + response.getCompletionStatus().getStatus() : "" }; - String expectedJson = Strings.format(""" - { - "id" : "%s", - "is_running" : %s, - "is_partial" : %s, - %s - "expiration_time_in_millis" : %s - %s - } - """, args); - response.toXContent(builder, ToXContent.EMPTY_PARAMS); - assertEquals(XContentHelper.stripWhitespace(expectedJson), Strings.toString(builder)); - } - } -} diff --git a/x-pack/plugin/esql-core/test-fixtures/src/main/java/org/elasticsearch/xpack/esql/core/CsvSpecReader.java b/x-pack/plugin/esql-core/test-fixtures/src/main/java/org/elasticsearch/xpack/esql/core/CsvSpecReader.java index a1f524e525eee..8e5a228af00d6 100644 --- a/x-pack/plugin/esql-core/test-fixtures/src/main/java/org/elasticsearch/xpack/esql/core/CsvSpecReader.java +++ b/x-pack/plugin/esql-core/test-fixtures/src/main/java/org/elasticsearch/xpack/esql/core/CsvSpecReader.java @@ -15,7 +15,6 @@ import static org.hamcrest.CoreMatchers.is; import static org.junit.Assert.assertThat; -import static org.junit.Assert.assertTrue; public final class CsvSpecReader { @@ -113,34 +112,16 @@ public static class CsvTestCase { public boolean ignoreOrder; public List requiredCapabilities = List.of(); - // The emulated-specific warnings must always trail the non-emulated ones, if these are present. Otherwise, the closing bracket - // would need to be changed to a less common sequence (like `]#` maybe). - private static final String EMULATED_PREFIX = "#[emulated:"; - /** * Returns the warning headers expected to be added by the test. To declare such a header, use the `warning:definition` format * in the CSV test declaration. The `definition` can use the `EMULATED_PREFIX` string to specify the format of the warning run on * emulated physical operators, if this differs from the format returned by SingleValueQuery. - * @param forEmulated if true, the tests are run on emulated physical operators; if false, the test case is for queries executed - * on a "full stack" ESQL, having data loaded from Lucene. * @return the list of headers that are expected to be returned part of the response. */ - public List expectedWarnings(boolean forEmulated) { + public List expectedWarnings() { List warnings = new ArrayList<>(expectedWarnings.size()); for (String warning : expectedWarnings) { - int idx = warning.toLowerCase(Locale.ROOT).indexOf(EMULATED_PREFIX); - if (idx >= 0) { - assertTrue("Invalid warning spec: closing delimiter (]) missing: `" + warning + "`", warning.endsWith("]")); - if (forEmulated) { - if (idx + EMULATED_PREFIX.length() < warning.length() - 1) { - warnings.add(warning.substring(idx + EMULATED_PREFIX.length(), warning.length() - 1)); - } - } else if (idx > 0) { - warnings.add(warning.substring(0, idx)); - } // else: no warnings expected for non-emulated - } else { - warnings.add(warning); - } + warnings.add(warning); } return warnings; } diff --git a/x-pack/plugin/esql/compute/build.gradle b/x-pack/plugin/esql/compute/build.gradle index 3e61b9bc5e51c..e5816d0b7c78b 100644 --- a/x-pack/plugin/esql/compute/build.gradle +++ b/x-pack/plugin/esql/compute/build.gradle @@ -400,6 +400,11 @@ tasks.named('stringTemplates').configure { it.outputFile = "org/elasticsearch/compute/data/BooleanVectorFixedBuilder.java" } File stateInputFile = new File("${projectDir}/src/main/java/org/elasticsearch/compute/aggregation/X-State.java.st") + template { + it.properties = booleanProperties + it.inputFile = stateInputFile + it.outputFile = "org/elasticsearch/compute/aggregation/BooleanState.java" + } template { it.properties = intProperties it.inputFile = stateInputFile @@ -453,6 +458,11 @@ tasks.named('stringTemplates').configure { it.outputFile = "org/elasticsearch/compute/data/BooleanLookup.java" } File arrayStateInputFile = new File("${projectDir}/src/main/java/org/elasticsearch/compute/aggregation/X-ArrayState.java.st") + template { + it.properties = booleanProperties + it.inputFile = arrayStateInputFile + it.outputFile = "org/elasticsearch/compute/aggregation/BooleanArrayState.java" + } template { it.properties = intProperties it.inputFile = arrayStateInputFile diff --git a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/AggregatorImplementer.java b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/AggregatorImplementer.java index 1127d4b4ccb72..b3d32a82cc7a9 100644 --- a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/AggregatorImplementer.java +++ b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/AggregatorImplementer.java @@ -445,6 +445,8 @@ String intermediateStateRowAccess() { private String primitiveStateMethod() { switch (stateType.toString()) { + case "org.elasticsearch.compute.aggregation.BooleanState": + return "booleanValue"; case "org.elasticsearch.compute.aggregation.IntState": return "intValue"; case "org.elasticsearch.compute.aggregation.LongState": @@ -494,6 +496,9 @@ private MethodSpec evaluateFinal() { private void primitiveStateToResult(MethodSpec.Builder builder) { switch (stateType.toString()) { + case "org.elasticsearch.compute.aggregation.BooleanState": + builder.addStatement("blocks[offset] = driverContext.blockFactory().newConstantBooleanBlockWith(state.booleanValue(), 1)"); + return; case "org.elasticsearch.compute.aggregation.IntState": builder.addStatement("blocks[offset] = driverContext.blockFactory().newConstantIntBlockWith(state.intValue(), 1)"); return; @@ -531,8 +536,9 @@ private MethodSpec close() { private boolean hasPrimitiveState() { return switch (stateType.toString()) { - case "org.elasticsearch.compute.aggregation.IntState", "org.elasticsearch.compute.aggregation.LongState", - "org.elasticsearch.compute.aggregation.DoubleState", "org.elasticsearch.compute.aggregation.FloatState" -> true; + case "org.elasticsearch.compute.aggregation.BooleanState", "org.elasticsearch.compute.aggregation.IntState", + "org.elasticsearch.compute.aggregation.LongState", "org.elasticsearch.compute.aggregation.DoubleState", + "org.elasticsearch.compute.aggregation.FloatState" -> true; default -> false; }; } diff --git a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/GroupingAggregatorImplementer.java b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/GroupingAggregatorImplementer.java index c9cdcfe42fddd..79df41f304c06 100644 --- a/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/GroupingAggregatorImplementer.java +++ b/x-pack/plugin/esql/compute/gen/src/main/java/org/elasticsearch/compute/gen/GroupingAggregatorImplementer.java @@ -584,8 +584,9 @@ private MethodSpec close() { private boolean hasPrimitiveState() { return switch (stateType.toString()) { - case "org.elasticsearch.compute.aggregation.IntArrayState", "org.elasticsearch.compute.aggregation.LongArrayState", - "org.elasticsearch.compute.aggregation.DoubleArrayState", "org.elasticsearch.compute.aggregation.FloatArrayState" -> true; + case "org.elasticsearch.compute.aggregation.BooleanArrayState", "org.elasticsearch.compute.aggregation.IntArrayState", + "org.elasticsearch.compute.aggregation.LongArrayState", "org.elasticsearch.compute.aggregation.DoubleArrayState", + "org.elasticsearch.compute.aggregation.FloatArrayState" -> true; default -> false; }; } diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/BooleanArrayState.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/BooleanArrayState.java new file mode 100644 index 0000000000000..79f4a88d403c6 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/BooleanArrayState.java @@ -0,0 +1,121 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.aggregation; + +import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.common.util.BitArray; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BooleanBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.operator.DriverContext; +import org.elasticsearch.core.Releasables; + +/** + * Aggregator state for an array of booleans. It is created in a mode where it + * won't track the {@code groupId}s that are sent to it and it is the + * responsibility of the caller to only fetch values for {@code groupId}s + * that it has sent using the {@code selected} parameter when building the + * results. This is fine when there are no {@code null} values in the input + * data. But once there are null values in the input data it is + * much more convenient to only send non-null values and + * the tracking built into the grouping code can't track that. In that case + * call {@link #enableGroupIdTracking} to transition the state into a mode + * where it'll track which {@code groupIds} have been written. + *

+ * This class is generated. Do not edit it. + *

+ */ +final class BooleanArrayState extends AbstractArrayState implements GroupingAggregatorState { + private final boolean init; + + private BitArray values; + private int size; + + BooleanArrayState(BigArrays bigArrays, boolean init) { + super(bigArrays); + this.values = new BitArray(1, bigArrays); + this.size = 1; + this.values.set(0, init); + this.init = init; + } + + boolean get(int groupId) { + return values.get(groupId); + } + + boolean getOrDefault(int groupId) { + return groupId < values.size() ? values.get(groupId) : init; + } + + void set(int groupId, boolean value) { + ensureCapacity(groupId); + values.set(groupId, value); + trackGroupId(groupId); + } + + Block toValuesBlock(org.elasticsearch.compute.data.IntVector selected, DriverContext driverContext) { + if (false == trackingGroupIds()) { + try (var builder = driverContext.blockFactory().newBooleanVectorFixedBuilder(selected.getPositionCount())) { + for (int i = 0; i < selected.getPositionCount(); i++) { + builder.appendBoolean(i, values.get(selected.getInt(i))); + } + return builder.build().asBlock(); + } + } + try (BooleanBlock.Builder builder = driverContext.blockFactory().newBooleanBlockBuilder(selected.getPositionCount())) { + for (int i = 0; i < selected.getPositionCount(); i++) { + int group = selected.getInt(i); + if (hasValue(group)) { + builder.appendBoolean(values.get(group)); + } else { + builder.appendNull(); + } + } + return builder.build(); + } + } + + private void ensureCapacity(int groupId) { + if (groupId >= size) { + values.fill(size, groupId + 1, init); + size = groupId + 1; + } + } + + /** Extracts an intermediate view of the contents of this state. */ + @Override + public void toIntermediate( + Block[] blocks, + int offset, + IntVector selected, + org.elasticsearch.compute.operator.DriverContext driverContext + ) { + assert blocks.length >= offset + 2; + try ( + var valuesBuilder = driverContext.blockFactory().newBooleanBlockBuilder(selected.getPositionCount()); + var hasValueBuilder = driverContext.blockFactory().newBooleanVectorFixedBuilder(selected.getPositionCount()) + ) { + for (int i = 0; i < selected.getPositionCount(); i++) { + int group = selected.getInt(i); + if (group < values.size()) { + valuesBuilder.appendBoolean(values.get(group)); + } else { + valuesBuilder.appendBoolean(false); // TODO can we just use null? + } + hasValueBuilder.appendBoolean(i, hasValue(group)); + } + blocks[offset + 0] = valuesBuilder.build(); + blocks[offset + 1] = hasValueBuilder.build().asBlock(); + } + } + + @Override + public void close() { + Releasables.close(values, super::close); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/BooleanState.java b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/BooleanState.java new file mode 100644 index 0000000000000..7d225c7c06a72 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated-src/org/elasticsearch/compute/aggregation/BooleanState.java @@ -0,0 +1,55 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.aggregation; + +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * Aggregator state for a single boolean. + * This class is generated. Do not edit it. + */ +final class BooleanState implements AggregatorState { + private boolean value; + private boolean seen; + + BooleanState() { + this(false); + } + + BooleanState(boolean init) { + this.value = init; + } + + boolean booleanValue() { + return value; + } + + void booleanValue(boolean value) { + this.value = value; + } + + boolean seen() { + return seen; + } + + void seen(boolean seen) { + this.seen = seen; + } + + /** Extracts an intermediate view of the contents of this state. */ + @Override + public void toIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + assert blocks.length >= offset + 2; + blocks[offset + 0] = driverContext.blockFactory().newConstantBooleanBlockWith(value, 1); + blocks[offset + 1] = driverContext.blockFactory().newConstantBooleanBlockWith(seen, 1); + } + + @Override + public void close() {} +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxBooleanAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxBooleanAggregatorFunction.java new file mode 100644 index 0000000000000..2ffbcfc2d9458 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxBooleanAggregatorFunction.java @@ -0,0 +1,136 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.lang.StringBuilder; +import java.util.List; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BooleanBlock; +import org.elasticsearch.compute.data.BooleanVector; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link AggregatorFunction} implementation for {@link MaxBooleanAggregator}. + * This class is generated. Do not edit it. + */ +public final class MaxBooleanAggregatorFunction implements AggregatorFunction { + private static final List INTERMEDIATE_STATE_DESC = List.of( + new IntermediateStateDesc("max", ElementType.BOOLEAN), + new IntermediateStateDesc("seen", ElementType.BOOLEAN) ); + + private final DriverContext driverContext; + + private final BooleanState state; + + private final List channels; + + public MaxBooleanAggregatorFunction(DriverContext driverContext, List channels, + BooleanState state) { + this.driverContext = driverContext; + this.channels = channels; + this.state = state; + } + + public static MaxBooleanAggregatorFunction create(DriverContext driverContext, + List channels) { + return new MaxBooleanAggregatorFunction(driverContext, channels, new BooleanState(MaxBooleanAggregator.init())); + } + + public static List intermediateStateDesc() { + return INTERMEDIATE_STATE_DESC; + } + + @Override + public int intermediateBlockCount() { + return INTERMEDIATE_STATE_DESC.size(); + } + + @Override + public void addRawInput(Page page) { + BooleanBlock block = page.getBlock(channels.get(0)); + BooleanVector vector = block.asVector(); + if (vector != null) { + addRawVector(vector); + } else { + addRawBlock(block); + } + } + + private void addRawVector(BooleanVector vector) { + state.seen(true); + for (int i = 0; i < vector.getPositionCount(); i++) { + state.booleanValue(MaxBooleanAggregator.combine(state.booleanValue(), vector.getBoolean(i))); + } + } + + private void addRawBlock(BooleanBlock block) { + for (int p = 0; p < block.getPositionCount(); p++) { + if (block.isNull(p)) { + continue; + } + state.seen(true); + int start = block.getFirstValueIndex(p); + int end = start + block.getValueCount(p); + for (int i = start; i < end; i++) { + state.booleanValue(MaxBooleanAggregator.combine(state.booleanValue(), block.getBoolean(i))); + } + } + } + + @Override + public void addIntermediateInput(Page page) { + assert channels.size() == intermediateBlockCount(); + assert page.getBlockCount() >= channels.get(0) + intermediateStateDesc().size(); + Block maxUncast = page.getBlock(channels.get(0)); + if (maxUncast.areAllValuesNull()) { + return; + } + BooleanVector max = ((BooleanBlock) maxUncast).asVector(); + assert max.getPositionCount() == 1; + Block seenUncast = page.getBlock(channels.get(1)); + if (seenUncast.areAllValuesNull()) { + return; + } + BooleanVector seen = ((BooleanBlock) seenUncast).asVector(); + assert seen.getPositionCount() == 1; + if (seen.getBoolean(0)) { + state.booleanValue(MaxBooleanAggregator.combine(state.booleanValue(), max.getBoolean(0))); + state.seen(true); + } + } + + @Override + public void evaluateIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + state.toIntermediate(blocks, offset, driverContext); + } + + @Override + public void evaluateFinal(Block[] blocks, int offset, DriverContext driverContext) { + if (state.seen() == false) { + blocks[offset] = driverContext.blockFactory().newConstantNullBlock(1); + return; + } + blocks[offset] = driverContext.blockFactory().newConstantBooleanBlockWith(state.booleanValue(), 1); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(getClass().getSimpleName()).append("["); + sb.append("channels=").append(channels); + sb.append("]"); + return sb.toString(); + } + + @Override + public void close() { + state.close(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxBooleanAggregatorFunctionSupplier.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxBooleanAggregatorFunctionSupplier.java new file mode 100644 index 0000000000000..e5bbf63ddee07 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxBooleanAggregatorFunctionSupplier.java @@ -0,0 +1,38 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.util.List; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link AggregatorFunctionSupplier} implementation for {@link MaxBooleanAggregator}. + * This class is generated. Do not edit it. + */ +public final class MaxBooleanAggregatorFunctionSupplier implements AggregatorFunctionSupplier { + private final List channels; + + public MaxBooleanAggregatorFunctionSupplier(List channels) { + this.channels = channels; + } + + @Override + public MaxBooleanAggregatorFunction aggregator(DriverContext driverContext) { + return MaxBooleanAggregatorFunction.create(driverContext, channels); + } + + @Override + public MaxBooleanGroupingAggregatorFunction groupingAggregator(DriverContext driverContext) { + return MaxBooleanGroupingAggregatorFunction.create(channels, driverContext); + } + + @Override + public String describe() { + return "max of booleans"; + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxBooleanGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxBooleanGroupingAggregatorFunction.java new file mode 100644 index 0000000000000..b72ff8354cb12 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MaxBooleanGroupingAggregatorFunction.java @@ -0,0 +1,204 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.lang.StringBuilder; +import java.util.List; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BooleanBlock; +import org.elasticsearch.compute.data.BooleanVector; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link GroupingAggregatorFunction} implementation for {@link MaxBooleanAggregator}. + * This class is generated. Do not edit it. + */ +public final class MaxBooleanGroupingAggregatorFunction implements GroupingAggregatorFunction { + private static final List INTERMEDIATE_STATE_DESC = List.of( + new IntermediateStateDesc("max", ElementType.BOOLEAN), + new IntermediateStateDesc("seen", ElementType.BOOLEAN) ); + + private final BooleanArrayState state; + + private final List channels; + + private final DriverContext driverContext; + + public MaxBooleanGroupingAggregatorFunction(List channels, BooleanArrayState state, + DriverContext driverContext) { + this.channels = channels; + this.state = state; + this.driverContext = driverContext; + } + + public static MaxBooleanGroupingAggregatorFunction create(List channels, + DriverContext driverContext) { + return new MaxBooleanGroupingAggregatorFunction(channels, new BooleanArrayState(driverContext.bigArrays(), MaxBooleanAggregator.init()), driverContext); + } + + public static List intermediateStateDesc() { + return INTERMEDIATE_STATE_DESC; + } + + @Override + public int intermediateBlockCount() { + return INTERMEDIATE_STATE_DESC.size(); + } + + @Override + public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds, + Page page) { + BooleanBlock valuesBlock = page.getBlock(channels.get(0)); + BooleanVector valuesVector = valuesBlock.asVector(); + if (valuesVector == null) { + if (valuesBlock.mayHaveNulls()) { + state.enableGroupIdTracking(seenGroupIds); + } + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntBlock groupIds) { + addRawInput(positionOffset, groupIds, valuesBlock); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, valuesBlock); + } + }; + } + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntBlock groupIds) { + addRawInput(positionOffset, groupIds, valuesVector); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, valuesVector); + } + }; + } + + private void addRawInput(int positionOffset, IntVector groups, BooleanBlock values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int groupId = Math.toIntExact(groups.getInt(groupPosition)); + if (values.isNull(groupPosition + positionOffset)) { + continue; + } + int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); + for (int v = valuesStart; v < valuesEnd; v++) { + state.set(groupId, MaxBooleanAggregator.combine(state.getOrDefault(groupId), values.getBoolean(v))); + } + } + } + + private void addRawInput(int positionOffset, IntVector groups, BooleanVector values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int groupId = Math.toIntExact(groups.getInt(groupPosition)); + state.set(groupId, MaxBooleanAggregator.combine(state.getOrDefault(groupId), values.getBoolean(groupPosition + positionOffset))); + } + } + + private void addRawInput(int positionOffset, IntBlock groups, BooleanBlock values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = Math.toIntExact(groups.getInt(g)); + if (values.isNull(groupPosition + positionOffset)) { + continue; + } + int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); + for (int v = valuesStart; v < valuesEnd; v++) { + state.set(groupId, MaxBooleanAggregator.combine(state.getOrDefault(groupId), values.getBoolean(v))); + } + } + } + } + + private void addRawInput(int positionOffset, IntBlock groups, BooleanVector values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = Math.toIntExact(groups.getInt(g)); + state.set(groupId, MaxBooleanAggregator.combine(state.getOrDefault(groupId), values.getBoolean(groupPosition + positionOffset))); + } + } + } + + @Override + public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block maxUncast = page.getBlock(channels.get(0)); + if (maxUncast.areAllValuesNull()) { + return; + } + BooleanVector max = ((BooleanBlock) maxUncast).asVector(); + Block seenUncast = page.getBlock(channels.get(1)); + if (seenUncast.areAllValuesNull()) { + return; + } + BooleanVector seen = ((BooleanBlock) seenUncast).asVector(); + assert max.getPositionCount() == seen.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int groupId = Math.toIntExact(groups.getInt(groupPosition)); + MaxBooleanAggregator.combineIntermediate(state, groupId, max.getBoolean(groupPosition + positionOffset), seen.getBoolean(groupPosition + positionOffset)); + } + } + + @Override + public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { + if (input.getClass() != getClass()) { + throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); + } + BooleanArrayState inState = ((MaxBooleanGroupingAggregatorFunction) input).state; + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + if (inState.hasValue(position)) { + state.set(groupId, MaxBooleanAggregator.combine(state.getOrDefault(groupId), inState.get(position))); + } + } + + @Override + public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { + state.toIntermediate(blocks, offset, selected, driverContext); + } + + @Override + public void evaluateFinal(Block[] blocks, int offset, IntVector selected, + DriverContext driverContext) { + blocks[offset] = state.toValuesBlock(selected, driverContext); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(getClass().getSimpleName()).append("["); + sb.append("channels=").append(channels); + sb.append("]"); + return sb.toString(); + } + + @Override + public void close() { + state.close(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinBooleanAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinBooleanAggregatorFunction.java new file mode 100644 index 0000000000000..101a6c7f9169a --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinBooleanAggregatorFunction.java @@ -0,0 +1,136 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.lang.StringBuilder; +import java.util.List; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BooleanBlock; +import org.elasticsearch.compute.data.BooleanVector; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link AggregatorFunction} implementation for {@link MinBooleanAggregator}. + * This class is generated. Do not edit it. + */ +public final class MinBooleanAggregatorFunction implements AggregatorFunction { + private static final List INTERMEDIATE_STATE_DESC = List.of( + new IntermediateStateDesc("min", ElementType.BOOLEAN), + new IntermediateStateDesc("seen", ElementType.BOOLEAN) ); + + private final DriverContext driverContext; + + private final BooleanState state; + + private final List channels; + + public MinBooleanAggregatorFunction(DriverContext driverContext, List channels, + BooleanState state) { + this.driverContext = driverContext; + this.channels = channels; + this.state = state; + } + + public static MinBooleanAggregatorFunction create(DriverContext driverContext, + List channels) { + return new MinBooleanAggregatorFunction(driverContext, channels, new BooleanState(MinBooleanAggregator.init())); + } + + public static List intermediateStateDesc() { + return INTERMEDIATE_STATE_DESC; + } + + @Override + public int intermediateBlockCount() { + return INTERMEDIATE_STATE_DESC.size(); + } + + @Override + public void addRawInput(Page page) { + BooleanBlock block = page.getBlock(channels.get(0)); + BooleanVector vector = block.asVector(); + if (vector != null) { + addRawVector(vector); + } else { + addRawBlock(block); + } + } + + private void addRawVector(BooleanVector vector) { + state.seen(true); + for (int i = 0; i < vector.getPositionCount(); i++) { + state.booleanValue(MinBooleanAggregator.combine(state.booleanValue(), vector.getBoolean(i))); + } + } + + private void addRawBlock(BooleanBlock block) { + for (int p = 0; p < block.getPositionCount(); p++) { + if (block.isNull(p)) { + continue; + } + state.seen(true); + int start = block.getFirstValueIndex(p); + int end = start + block.getValueCount(p); + for (int i = start; i < end; i++) { + state.booleanValue(MinBooleanAggregator.combine(state.booleanValue(), block.getBoolean(i))); + } + } + } + + @Override + public void addIntermediateInput(Page page) { + assert channels.size() == intermediateBlockCount(); + assert page.getBlockCount() >= channels.get(0) + intermediateStateDesc().size(); + Block minUncast = page.getBlock(channels.get(0)); + if (minUncast.areAllValuesNull()) { + return; + } + BooleanVector min = ((BooleanBlock) minUncast).asVector(); + assert min.getPositionCount() == 1; + Block seenUncast = page.getBlock(channels.get(1)); + if (seenUncast.areAllValuesNull()) { + return; + } + BooleanVector seen = ((BooleanBlock) seenUncast).asVector(); + assert seen.getPositionCount() == 1; + if (seen.getBoolean(0)) { + state.booleanValue(MinBooleanAggregator.combine(state.booleanValue(), min.getBoolean(0))); + state.seen(true); + } + } + + @Override + public void evaluateIntermediate(Block[] blocks, int offset, DriverContext driverContext) { + state.toIntermediate(blocks, offset, driverContext); + } + + @Override + public void evaluateFinal(Block[] blocks, int offset, DriverContext driverContext) { + if (state.seen() == false) { + blocks[offset] = driverContext.blockFactory().newConstantNullBlock(1); + return; + } + blocks[offset] = driverContext.blockFactory().newConstantBooleanBlockWith(state.booleanValue(), 1); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(getClass().getSimpleName()).append("["); + sb.append("channels=").append(channels); + sb.append("]"); + return sb.toString(); + } + + @Override + public void close() { + state.close(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinBooleanAggregatorFunctionSupplier.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinBooleanAggregatorFunctionSupplier.java new file mode 100644 index 0000000000000..f66dc6e67e0fd --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinBooleanAggregatorFunctionSupplier.java @@ -0,0 +1,38 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.util.List; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link AggregatorFunctionSupplier} implementation for {@link MinBooleanAggregator}. + * This class is generated. Do not edit it. + */ +public final class MinBooleanAggregatorFunctionSupplier implements AggregatorFunctionSupplier { + private final List channels; + + public MinBooleanAggregatorFunctionSupplier(List channels) { + this.channels = channels; + } + + @Override + public MinBooleanAggregatorFunction aggregator(DriverContext driverContext) { + return MinBooleanAggregatorFunction.create(driverContext, channels); + } + + @Override + public MinBooleanGroupingAggregatorFunction groupingAggregator(DriverContext driverContext) { + return MinBooleanGroupingAggregatorFunction.create(channels, driverContext); + } + + @Override + public String describe() { + return "min of booleans"; + } +} diff --git a/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinBooleanGroupingAggregatorFunction.java b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinBooleanGroupingAggregatorFunction.java new file mode 100644 index 0000000000000..6175cad3924e2 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/generated/org/elasticsearch/compute/aggregation/MinBooleanGroupingAggregatorFunction.java @@ -0,0 +1,206 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License +// 2.0; you may not use this file except in compliance with the Elastic License +// 2.0. +package org.elasticsearch.compute.aggregation; + +import java.lang.Integer; +import java.lang.Override; +import java.lang.String; +import java.lang.StringBuilder; +import java.util.List; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BooleanBlock; +import org.elasticsearch.compute.data.BooleanVector; +import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.data.IntBlock; +import org.elasticsearch.compute.data.IntVector; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverContext; + +/** + * {@link GroupingAggregatorFunction} implementation for {@link MinBooleanAggregator}. + * This class is generated. Do not edit it. + */ +public final class MinBooleanGroupingAggregatorFunction implements GroupingAggregatorFunction { + private static final List INTERMEDIATE_STATE_DESC = List.of( + new IntermediateStateDesc("min", ElementType.BOOLEAN), + new IntermediateStateDesc("seen", ElementType.BOOLEAN) ); + + private final BooleanArrayState state; + + private final List channels; + + private final DriverContext driverContext; + + public MinBooleanGroupingAggregatorFunction(List channels, BooleanArrayState state, + DriverContext driverContext) { + this.channels = channels; + this.state = state; + this.driverContext = driverContext; + } + + public static MinBooleanGroupingAggregatorFunction create(List channels, + DriverContext driverContext) { + return new MinBooleanGroupingAggregatorFunction(channels, new BooleanArrayState(driverContext.bigArrays(), MinBooleanAggregator.init()), driverContext); + } + + public static List intermediateStateDesc() { + return INTERMEDIATE_STATE_DESC; + } + + @Override + public int intermediateBlockCount() { + return INTERMEDIATE_STATE_DESC.size(); + } + + @Override + public GroupingAggregatorFunction.AddInput prepareProcessPage(SeenGroupIds seenGroupIds, + Page page) { + BooleanBlock valuesBlock = page.getBlock(channels.get(0)); + BooleanVector valuesVector = valuesBlock.asVector(); + if (valuesVector == null) { + if (valuesBlock.mayHaveNulls()) { + state.enableGroupIdTracking(seenGroupIds); + } + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntBlock groupIds) { + addRawInput(positionOffset, groupIds, valuesBlock); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, valuesBlock); + } + }; + } + return new GroupingAggregatorFunction.AddInput() { + @Override + public void add(int positionOffset, IntBlock groupIds) { + addRawInput(positionOffset, groupIds, valuesVector); + } + + @Override + public void add(int positionOffset, IntVector groupIds) { + addRawInput(positionOffset, groupIds, valuesVector); + } + }; + } + + private void addRawInput(int positionOffset, IntVector groups, BooleanBlock values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int groupId = Math.toIntExact(groups.getInt(groupPosition)); + if (values.isNull(groupPosition + positionOffset)) { + continue; + } + int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); + for (int v = valuesStart; v < valuesEnd; v++) { + state.set(groupId, MinBooleanAggregator.combine(state.getOrDefault(groupId), values.getBoolean(v))); + } + } + } + + private void addRawInput(int positionOffset, IntVector groups, BooleanVector values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int groupId = Math.toIntExact(groups.getInt(groupPosition)); + state.set(groupId, MinBooleanAggregator.combine(state.getOrDefault(groupId), values.getBoolean(groupPosition + positionOffset))); + } + } + + private void addRawInput(int positionOffset, IntBlock groups, BooleanBlock values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = Math.toIntExact(groups.getInt(g)); + if (values.isNull(groupPosition + positionOffset)) { + continue; + } + int valuesStart = values.getFirstValueIndex(groupPosition + positionOffset); + int valuesEnd = valuesStart + values.getValueCount(groupPosition + positionOffset); + for (int v = valuesStart; v < valuesEnd; v++) { + state.set(groupId, MinBooleanAggregator.combine(state.getOrDefault(groupId), values.getBoolean(v))); + } + } + } + } + + private void addRawInput(int positionOffset, IntBlock groups, BooleanVector values) { + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + if (groups.isNull(groupPosition)) { + continue; + } + int groupStart = groups.getFirstValueIndex(groupPosition); + int groupEnd = groupStart + groups.getValueCount(groupPosition); + for (int g = groupStart; g < groupEnd; g++) { + int groupId = Math.toIntExact(groups.getInt(g)); + state.set(groupId, MinBooleanAggregator.combine(state.getOrDefault(groupId), values.getBoolean(groupPosition + positionOffset))); + } + } + } + + @Override + public void addIntermediateInput(int positionOffset, IntVector groups, Page page) { + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + assert channels.size() == intermediateBlockCount(); + Block minUncast = page.getBlock(channels.get(0)); + if (minUncast.areAllValuesNull()) { + return; + } + BooleanVector min = ((BooleanBlock) minUncast).asVector(); + Block seenUncast = page.getBlock(channels.get(1)); + if (seenUncast.areAllValuesNull()) { + return; + } + BooleanVector seen = ((BooleanBlock) seenUncast).asVector(); + assert min.getPositionCount() == seen.getPositionCount(); + for (int groupPosition = 0; groupPosition < groups.getPositionCount(); groupPosition++) { + int groupId = Math.toIntExact(groups.getInt(groupPosition)); + if (seen.getBoolean(groupPosition + positionOffset)) { + state.set(groupId, MinBooleanAggregator.combine(state.getOrDefault(groupId), min.getBoolean(groupPosition + positionOffset))); + } + } + } + + @Override + public void addIntermediateRowInput(int groupId, GroupingAggregatorFunction input, int position) { + if (input.getClass() != getClass()) { + throw new IllegalArgumentException("expected " + getClass() + "; got " + input.getClass()); + } + BooleanArrayState inState = ((MinBooleanGroupingAggregatorFunction) input).state; + state.enableGroupIdTracking(new SeenGroupIds.Empty()); + if (inState.hasValue(position)) { + state.set(groupId, MinBooleanAggregator.combine(state.getOrDefault(groupId), inState.get(position))); + } + } + + @Override + public void evaluateIntermediate(Block[] blocks, int offset, IntVector selected) { + state.toIntermediate(blocks, offset, selected, driverContext); + } + + @Override + public void evaluateFinal(Block[] blocks, int offset, IntVector selected, + DriverContext driverContext) { + blocks[offset] = state.toValuesBlock(selected, driverContext); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(getClass().getSimpleName()).append("["); + sb.append("channels=").append(channels); + sb.append("]"); + return sb.toString(); + } + + @Override + public void close() { + state.close(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/MaxBooleanAggregator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/MaxBooleanAggregator.java new file mode 100644 index 0000000000000..79d0cd4d7492f --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/MaxBooleanAggregator.java @@ -0,0 +1,25 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.aggregation; + +import org.elasticsearch.compute.ann.Aggregator; +import org.elasticsearch.compute.ann.GroupingAggregator; +import org.elasticsearch.compute.ann.IntermediateState; + +@Aggregator({ @IntermediateState(name = "max", type = "BOOLEAN"), @IntermediateState(name = "seen", type = "BOOLEAN") }) +@GroupingAggregator +class MaxBooleanAggregator { + + public static boolean init() { + return false; + } + + public static boolean combine(boolean current, boolean v) { + return current || v; + } +} diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/MaxDoubleAggregator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/MaxDoubleAggregator.java index ee6555c4af67d..f0804278e5002 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/MaxDoubleAggregator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/MaxDoubleAggregator.java @@ -16,7 +16,7 @@ class MaxDoubleAggregator { public static double init() { - return Double.MIN_VALUE; + return -Double.MAX_VALUE; } public static double combine(double current, double v) { diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/MinBooleanAggregator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/MinBooleanAggregator.java new file mode 100644 index 0000000000000..372a5d988688f --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/MinBooleanAggregator.java @@ -0,0 +1,25 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.aggregation; + +import org.elasticsearch.compute.ann.Aggregator; +import org.elasticsearch.compute.ann.GroupingAggregator; +import org.elasticsearch.compute.ann.IntermediateState; + +@Aggregator({ @IntermediateState(name = "min", type = "BOOLEAN"), @IntermediateState(name = "seen", type = "BOOLEAN") }) +@GroupingAggregator +class MinBooleanAggregator { + + public static boolean init() { + return true; + } + + public static boolean combine(boolean current, boolean v) { + return current && v; + } +} diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ArrayState.java.st b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ArrayState.java.st index 18686928f14a8..10dbd9f423725 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ArrayState.java.st +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-ArrayState.java.st @@ -8,7 +8,11 @@ package org.elasticsearch.compute.aggregation; import org.elasticsearch.common.util.BigArrays; +$if(boolean)$ +import org.elasticsearch.common.util.BitArray; +$else$ import org.elasticsearch.common.util.$Type$Array; +$endif$ import org.elasticsearch.compute.data.Block; $if(long)$ import org.elasticsearch.compute.data.IntVector; @@ -17,7 +21,7 @@ import org.elasticsearch.compute.data.$Type$Block; $if(int)$ import org.elasticsearch.compute.data.$Type$Vector; $endif$ -$if(double||float)$ +$if(boolean||double||float)$ import org.elasticsearch.compute.data.IntVector; $endif$ import org.elasticsearch.compute.operator.DriverContext; @@ -41,11 +45,22 @@ import org.elasticsearch.core.Releasables; final class $Type$ArrayState extends AbstractArrayState implements GroupingAggregatorState { private final $type$ init; +$if(boolean)$ + private BitArray values; + private int size; + +$else$ private $Type$Array values; +$endif$ $Type$ArrayState(BigArrays bigArrays, $type$ init) { super(bigArrays); +$if(boolean)$ + this.values = new BitArray(1, bigArrays); + this.size = 1; +$else$ this.values = bigArrays.new$Type$Array(1, false); +$endif$ this.values.set(0, init); this.init = init; } @@ -95,11 +110,18 @@ $endif$ } private void ensureCapacity(int groupId) { +$if(boolean)$ + if (groupId >= size) { + values.fill(size, groupId + 1, init); + size = groupId + 1; + } +$else$ if (groupId >= values.size()) { long prevSize = values.size(); values = bigArrays.grow(values, groupId + 1); values.fill(prevSize, values.size(), init); } +$endif$ } /** Extracts an intermediate view of the contents of this state. */ @@ -120,7 +142,7 @@ $endif$ if (group < values.size()) { valuesBuilder.append$Type$(values.get(group)); } else { - valuesBuilder.append$Type$(0); // TODO can we just use null? + valuesBuilder.append$Type$($if(boolean)$false$else$0$endif$); // TODO can we just use null? } hasValueBuilder.appendBoolean(i, hasValue(group)); } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-State.java.st b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-State.java.st index 427d1a0c312cc..2d2d706c9454f 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-State.java.st +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/aggregation/X-State.java.st @@ -19,7 +19,11 @@ final class $Type$State implements AggregatorState { private boolean seen; $Type$State() { +$if(boolean)$ + this(false); +$else$ this(0); +$endif$ } $Type$State($type$ init) { diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/AsyncOperator.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/AsyncOperator.java index 061cefc86bed0..0fed88370a144 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/AsyncOperator.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/AsyncOperator.java @@ -21,13 +21,11 @@ import org.elasticsearch.core.TimeValue; import org.elasticsearch.index.seqno.LocalCheckpointTracker; import org.elasticsearch.index.seqno.SequenceNumbers; -import org.elasticsearch.tasks.TaskCancelledException; import org.elasticsearch.xcontent.XContentBuilder; import java.io.IOException; import java.util.Map; import java.util.Objects; -import java.util.concurrent.atomic.AtomicReference; import java.util.concurrent.atomic.LongAdder; /** @@ -40,7 +38,7 @@ public abstract class AsyncOperator implements Operator { private volatile SubscribableListener blockedFuture; private final Map buffers = ConcurrentCollections.newConcurrentMap(); - private final AtomicReference failure = new AtomicReference<>(); + private final FailureCollector failureCollector = new FailureCollector(); private final DriverContext driverContext; private final int maxOutstandingRequests; @@ -77,7 +75,7 @@ public boolean needsInput() { @Override public void addInput(Page input) { - if (failure.get() != null) { + if (failureCollector.hasFailure()) { input.releaseBlocks(); return; } @@ -90,7 +88,7 @@ public void addInput(Page input) { onSeqNoCompleted(seqNo); }, e -> { releasePageOnAnyThread(input); - onFailure(e); + failureCollector.unwrapAndCollect(e); onSeqNoCompleted(seqNo); }); final long startNanos = System.nanoTime(); @@ -121,31 +119,12 @@ private void releasePageOnAnyThread(Page page) { protected abstract void doClose(); - private void onFailure(Exception e) { - failure.getAndUpdate(first -> { - if (first == null) { - return e; - } - // ignore subsequent TaskCancelledException exceptions as they don't provide useful info. - if (ExceptionsHelper.unwrap(e, TaskCancelledException.class) != null) { - return first; - } - if (ExceptionsHelper.unwrap(first, TaskCancelledException.class) != null) { - return e; - } - if (ExceptionsHelper.unwrapCause(first) != ExceptionsHelper.unwrapCause(e)) { - first.addSuppressed(e); - } - return first; - }); - } - private void onSeqNoCompleted(long seqNo) { checkpoint.markSeqNoAsProcessed(seqNo); if (checkpoint.getPersistedCheckpoint() < checkpoint.getProcessedCheckpoint()) { notifyIfBlocked(); } - if (closed || failure.get() != null) { + if (closed || failureCollector.hasFailure()) { discardPages(); } } @@ -164,7 +143,7 @@ private void notifyIfBlocked() { } private void checkFailure() { - Exception e = failure.get(); + Exception e = failureCollector.getFailure(); if (e != null) { discardPages(); throw ExceptionsHelper.convertToElastic(e); diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/DriverRunner.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/DriverRunner.java index 5de017fbd279e..b427a36566f11 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/DriverRunner.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/DriverRunner.java @@ -7,14 +7,11 @@ package org.elasticsearch.compute.operator; -import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.action.ActionListener; import org.elasticsearch.common.util.concurrent.CountDown; import org.elasticsearch.common.util.concurrent.ThreadContext; -import org.elasticsearch.tasks.TaskCancelledException; import java.util.List; -import java.util.concurrent.atomic.AtomicReference; /** * Run a set of drivers to completion. @@ -35,8 +32,8 @@ public DriverRunner(ThreadContext threadContext) { * Run all drivers to completion asynchronously. */ public void runToCompletion(List drivers, ActionListener listener) { - AtomicReference failure = new AtomicReference<>(); var responseHeadersCollector = new ResponseHeadersCollector(threadContext); + var failure = new FailureCollector(); CountDown counter = new CountDown(drivers.size()); for (int i = 0; i < drivers.size(); i++) { Driver driver = drivers.get(i); @@ -48,23 +45,7 @@ public void onResponse(Void unused) { @Override public void onFailure(Exception e) { - failure.getAndUpdate(first -> { - if (first == null) { - return e; - } - if (ExceptionsHelper.unwrap(e, TaskCancelledException.class) != null) { - return first; - } else { - if (ExceptionsHelper.unwrap(first, TaskCancelledException.class) != null) { - return e; - } else { - if (first != e) { - first.addSuppressed(e); - } - return first; - } - } - }); + failure.unwrapAndCollect(e); for (Driver d : drivers) { if (driver != d) { d.cancel("Driver [" + driver.sessionId() + "] was cancelled or failed"); @@ -77,7 +58,7 @@ private void done() { responseHeadersCollector.collect(); if (counter.countDown()) { responseHeadersCollector.finish(); - Exception error = failure.get(); + Exception error = failure.getFailure(); if (error != null) { listener.onFailure(error); } else { diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/FailureCollector.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/FailureCollector.java new file mode 100644 index 0000000000000..99edab038af31 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/FailureCollector.java @@ -0,0 +1,112 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.operator; + +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.ExceptionsHelper; +import org.elasticsearch.common.util.concurrent.ConcurrentCollections; +import org.elasticsearch.tasks.TaskCancelledException; +import org.elasticsearch.transport.TransportException; + +import java.util.List; +import java.util.Queue; +import java.util.concurrent.atomic.AtomicInteger; + +/** + * {@code FailureCollector} is responsible for collecting exceptions that occur in the compute engine. + * The collected exceptions are categorized into task-cancelled and non-task-cancelled exceptions. + * To limit memory usage, this class collects only the first 10 exceptions in each category by default. + * When returning the accumulated failure to the caller, this class prefers non-task-cancelled exceptions + * over task-cancelled ones as they are more useful for diagnosing issues. + */ +public final class FailureCollector { + private final Queue cancelledExceptions = ConcurrentCollections.newQueue(); + private final AtomicInteger cancelledExceptionsCount = new AtomicInteger(); + + private final Queue nonCancelledExceptions = ConcurrentCollections.newQueue(); + private final AtomicInteger nonCancelledExceptionsCount = new AtomicInteger(); + + private final int maxExceptions; + private volatile boolean hasFailure = false; + private Exception finalFailure = null; + + public FailureCollector() { + this(10); + } + + public FailureCollector(int maxExceptions) { + if (maxExceptions <= 0) { + throw new IllegalArgumentException("maxExceptions must be at least one"); + } + this.maxExceptions = maxExceptions; + } + + public void unwrapAndCollect(Exception originEx) { + final Exception e = originEx instanceof TransportException + ? (originEx.getCause() instanceof Exception cause ? cause : new ElasticsearchException(originEx.getCause())) + : originEx; + if (ExceptionsHelper.unwrap(e, TaskCancelledException.class) != null) { + if (cancelledExceptionsCount.incrementAndGet() <= maxExceptions) { + cancelledExceptions.add(e); + } + } else { + if (nonCancelledExceptionsCount.incrementAndGet() <= maxExceptions) { + nonCancelledExceptions.add(e); + } + } + hasFailure = true; + } + + /** + * @return {@code true} if any failure has been collected, {@code false} otherwise + */ + public boolean hasFailure() { + return hasFailure; + } + + /** + * Returns the accumulated failure, preferring non-task-cancelled exceptions over task-cancelled ones. + * Once this method builds the failure, incoming failures are discarded. + * + * @return the accumulated failure, or {@code null} if no failure has been collected + */ + public Exception getFailure() { + if (hasFailure == false) { + return null; + } + synchronized (this) { + if (finalFailure == null) { + finalFailure = buildFailure(); + } + return finalFailure; + } + } + + private Exception buildFailure() { + assert hasFailure; + assert Thread.holdsLock(this); + int total = 0; + Exception first = null; + for (var exceptions : List.of(nonCancelledExceptions, cancelledExceptions)) { + for (Exception e : exceptions) { + if (first == null) { + first = e; + total++; + } else if (first != e) { + first.addSuppressed(e); + total++; + } + if (total >= maxExceptions) { + return first; + } + } + } + assert first != null; + return first; + } +} diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/ExchangeService.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/ExchangeService.java index f647f4fba0225..a365a655370a2 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/ExchangeService.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/ExchangeService.java @@ -250,21 +250,20 @@ public boolean isForceExecution() { protected void doRun() { assert Transports.assertNotTransportThread("reaping inactive exchanges can be expensive"); assert ThreadPool.assertNotScheduleThread("reaping inactive exchanges can be expensive"); + logger.debug("start removing inactive sinks"); final long nowInMillis = threadPool.relativeTimeInMillis(); for (Map.Entry e : sinks.entrySet()) { ExchangeSinkHandler sink = e.getValue(); if (sink.hasData() && sink.hasListeners()) { continue; } - long elapsed = nowInMillis - sink.lastUpdatedTimeInMillis(); - if (elapsed > keepAlive.millis()) { + long elapsedInMillis = nowInMillis - sink.lastUpdatedTimeInMillis(); + if (elapsedInMillis > keepAlive.millis()) { + TimeValue elapsedTime = TimeValue.timeValueMillis(elapsedInMillis); + logger.debug("removed sink {} inactive for {}", e.getKey(), elapsedTime); finishSinkHandler( e.getKey(), - new ElasticsearchTimeoutException( - "Exchange sink {} has been inactive for {}", - e.getKey(), - TimeValue.timeValueMillis(elapsed) - ) + new ElasticsearchTimeoutException("Exchange sink {} has been inactive for {}", e.getKey(), elapsedTime) ); } } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/ExchangeSourceHandler.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/ExchangeSourceHandler.java index adce8d8a88407..77b535949eb9d 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/ExchangeSourceHandler.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/ExchangeSourceHandler.java @@ -7,21 +7,18 @@ package org.elasticsearch.compute.operator.exchange; -import org.elasticsearch.ElasticsearchException; import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.RefCountingListener; import org.elasticsearch.action.support.SubscribableListener; import org.elasticsearch.common.util.concurrent.AbstractRunnable; import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.FailureCollector; import org.elasticsearch.core.Releasable; -import org.elasticsearch.tasks.TaskCancelledException; -import org.elasticsearch.transport.TransportException; import java.util.List; import java.util.concurrent.Executor; import java.util.concurrent.atomic.AtomicInteger; -import java.util.concurrent.atomic.AtomicReference; /** * An {@link ExchangeSourceHandler} asynchronously fetches pages and status from multiple {@link RemoteSink}s @@ -37,7 +34,7 @@ public final class ExchangeSourceHandler { private final PendingInstances outstandingSinks; private final PendingInstances outstandingSources; - private final AtomicReference failure = new AtomicReference<>(); + private final FailureCollector failure = new FailureCollector(); public ExchangeSourceHandler(int maxBufferSize, Executor fetchExecutor) { this.buffer = new ExchangeBuffer(maxBufferSize); @@ -54,7 +51,7 @@ private class ExchangeSourceImpl implements ExchangeSource { } private void checkFailure() { - Exception e = failure.get(); + Exception e = failure.getFailure(); if (e != null) { throw ExceptionsHelper.convertToElastic(e); } @@ -172,7 +169,7 @@ void fetchPage() { while (loopControl.isRunning()) { loopControl.exiting(); // finish other sinks if one of them failed or source no longer need pages. - boolean toFinishSinks = buffer.noMoreInputs() || failure.get() != null; + boolean toFinishSinks = buffer.noMoreInputs() || failure.hasFailure(); remoteSink.fetchPageAsync(toFinishSinks, ActionListener.wrap(resp -> { Page page = resp.takePage(); if (page != null) { @@ -199,26 +196,8 @@ void fetchPage() { loopControl.exited(); } - void onSinkFailed(Exception originEx) { - final Exception e = originEx instanceof TransportException - ? (originEx.getCause() instanceof Exception cause ? cause : new ElasticsearchException(originEx.getCause())) - : originEx; - failure.getAndUpdate(first -> { - if (first == null) { - return e; - } - // ignore subsequent TaskCancelledException exceptions as they don't provide useful info. - if (ExceptionsHelper.unwrap(e, TaskCancelledException.class) != null) { - return first; - } - if (ExceptionsHelper.unwrap(first, TaskCancelledException.class) != null) { - return e; - } - if (ExceptionsHelper.unwrapCause(first) != ExceptionsHelper.unwrapCause(e)) { - first.addSuppressed(e); - } - return first; - }); + void onSinkFailed(Exception e) { + failure.unwrapAndCollect(e); buffer.waitForReading().onResponse(null); // resume the Driver if it is being blocked on reading onSinkComplete(); } diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/MaxBooleanAggregatorFunctionTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/MaxBooleanAggregatorFunctionTests.java new file mode 100644 index 0000000000000..11119aade12ff --- /dev/null +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/MaxBooleanAggregatorFunctionTests.java @@ -0,0 +1,43 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.aggregation; + +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.BooleanBlock; +import org.elasticsearch.compute.operator.SequenceBooleanBlockSourceOperator; +import org.elasticsearch.compute.operator.SourceOperator; + +import java.util.Comparator; +import java.util.List; +import java.util.stream.IntStream; + +import static org.hamcrest.Matchers.equalTo; + +public class MaxBooleanAggregatorFunctionTests extends AggregatorFunctionTestCase { + @Override + protected SourceOperator simpleInput(BlockFactory blockFactory, int size) { + return new SequenceBooleanBlockSourceOperator(blockFactory, IntStream.range(0, size).mapToObj(l -> randomBoolean()).toList()); + } + + @Override + protected AggregatorFunctionSupplier aggregatorFunction(List inputChannels) { + return new MaxBooleanAggregatorFunctionSupplier(inputChannels); + } + + @Override + protected String expectedDescriptionOfAggregator() { + return "max of booleans"; + } + + @Override + public void assertSimpleOutput(List input, Block result) { + Boolean max = input.stream().flatMap(b -> allBooleans(b)).max(Comparator.naturalOrder()).get(); + assertThat(((BooleanBlock) result).getBoolean(0), equalTo(max)); + } +} diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/MinBooleanAggregatorFunctionTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/MinBooleanAggregatorFunctionTests.java new file mode 100644 index 0000000000000..74cdca31da34b --- /dev/null +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/aggregation/MinBooleanAggregatorFunctionTests.java @@ -0,0 +1,43 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.aggregation; + +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.BlockFactory; +import org.elasticsearch.compute.data.BooleanBlock; +import org.elasticsearch.compute.operator.SequenceBooleanBlockSourceOperator; +import org.elasticsearch.compute.operator.SourceOperator; + +import java.util.Comparator; +import java.util.List; +import java.util.stream.IntStream; + +import static org.hamcrest.Matchers.equalTo; + +public class MinBooleanAggregatorFunctionTests extends AggregatorFunctionTestCase { + @Override + protected SourceOperator simpleInput(BlockFactory blockFactory, int size) { + return new SequenceBooleanBlockSourceOperator(blockFactory, IntStream.range(0, size).mapToObj(l -> randomBoolean()).toList()); + } + + @Override + protected AggregatorFunctionSupplier aggregatorFunction(List inputChannels) { + return new MinBooleanAggregatorFunctionSupplier(inputChannels); + } + + @Override + protected String expectedDescriptionOfAggregator() { + return "min of booleans"; + } + + @Override + public void assertSimpleOutput(List input, Block result) { + Boolean min = input.stream().flatMap(b -> allBooleans(b)).min(Comparator.naturalOrder()).get(); + assertThat(((BooleanBlock) result).getBoolean(0), equalTo(min)); + } +} diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/ValueSourceReaderTypeConversionTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/ValueSourceReaderTypeConversionTests.java index 66bcf2a57e393..09f63e9fa45bb 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/ValueSourceReaderTypeConversionTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/lucene/ValueSourceReaderTypeConversionTests.java @@ -1687,12 +1687,13 @@ public StoredFieldsSpec rowStrideStoredFieldSpec() { @Override public boolean supportsOrdinals() { - return delegate.supportsOrdinals(); + // Fields with mismatching types cannot use ordinals for uniqueness determination, but must convert the values first + return false; } @Override - public SortedSetDocValues ordinals(LeafReaderContext context) throws IOException { - return delegate.ordinals(context); + public SortedSetDocValues ordinals(LeafReaderContext context) { + throw new IllegalArgumentException("Ordinals are not supported for type conversion"); } @Override diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/FailureCollectorTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/FailureCollectorTests.java new file mode 100644 index 0000000000000..d5fa0a1eaecc9 --- /dev/null +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/FailureCollectorTests.java @@ -0,0 +1,90 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute.operator; + +import org.elasticsearch.common.Randomness; +import org.elasticsearch.common.breaker.CircuitBreaker; +import org.elasticsearch.common.breaker.CircuitBreakingException; +import org.elasticsearch.tasks.TaskCancelledException; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.transport.RemoteTransportException; +import org.hamcrest.Matchers; + +import java.io.IOException; +import java.util.List; +import java.util.Queue; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.CyclicBarrier; +import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import java.util.stream.Stream; + +import static org.hamcrest.Matchers.lessThan; + +public class FailureCollectorTests extends ESTestCase { + + public void testCollect() throws Exception { + int maxExceptions = between(1, 100); + FailureCollector collector = new FailureCollector(maxExceptions); + List cancelledExceptions = List.of( + new TaskCancelledException("user request"), + new TaskCancelledException("cross "), + new TaskCancelledException("on failure") + ); + List nonCancelledExceptions = List.of( + new IOException("i/o simulated"), + new IOException("disk broken"), + new CircuitBreakingException("low memory", CircuitBreaker.Durability.TRANSIENT), + new CircuitBreakingException("over limit", CircuitBreaker.Durability.TRANSIENT) + ); + List failures = Stream.concat( + IntStream.range(0, between(1, 500)).mapToObj(n -> randomFrom(cancelledExceptions)), + IntStream.range(0, between(1, 500)).mapToObj(n -> randomFrom(nonCancelledExceptions)) + ).collect(Collectors.toList()); + Randomness.shuffle(failures); + Queue queue = new ConcurrentLinkedQueue<>(failures); + Thread[] threads = new Thread[between(1, 4)]; + CyclicBarrier carrier = new CyclicBarrier(threads.length); + for (int i = 0; i < threads.length; i++) { + threads[i] = new Thread(() -> { + try { + carrier.await(10, TimeUnit.SECONDS); + } catch (Exception e) { + throw new AssertionError(e); + } + Exception ex; + while ((ex = queue.poll()) != null) { + if (randomBoolean()) { + collector.unwrapAndCollect(ex); + } else { + collector.unwrapAndCollect(new RemoteTransportException("disconnect", ex)); + } + if (randomBoolean()) { + assertTrue(collector.hasFailure()); + } + } + }); + threads[i].start(); + } + for (Thread thread : threads) { + thread.join(); + } + assertTrue(collector.hasFailure()); + Exception failure = collector.getFailure(); + assertNotNull(failure); + assertThat(failure, Matchers.in(nonCancelledExceptions)); + assertThat(failure.getSuppressed().length, lessThan(maxExceptions)); + } + + public void testEmpty() { + FailureCollector collector = new FailureCollector(5); + assertFalse(collector.hasFailure()); + assertNull(collector.getFailure()); + } +} diff --git a/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/EsqlSpecTestCase.java b/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/EsqlSpecTestCase.java index e25eb84023867..e650f0815f964 100644 --- a/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/EsqlSpecTestCase.java +++ b/x-pack/plugin/esql/qa/server/src/main/java/org/elasticsearch/xpack/esql/qa/rest/EsqlSpecTestCase.java @@ -204,11 +204,7 @@ protected final void doTest() throws Throwable { builder.tables(tables()); } - Map answer = runEsql( - builder.query(testCase.query), - testCase.expectedWarnings(false), - testCase.expectedWarningsRegex() - ); + Map answer = runEsql(builder.query(testCase.query), testCase.expectedWarnings(), testCase.expectedWarningsRegex()); var expectedColumnsWithValues = loadCsvSpecValues(testCase.expectedResults); diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/CsvTestUtils.java b/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/CsvTestUtils.java index d88d7f9b9448f..3b3e12978ae04 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/CsvTestUtils.java +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/CsvTestUtils.java @@ -10,7 +10,6 @@ import org.apache.lucene.sandbox.document.HalfFloatPoint; import org.apache.lucene.util.BytesRef; import org.elasticsearch.Version; -import org.elasticsearch.common.Strings; import org.elasticsearch.common.breaker.NoopCircuitBreaker; import org.elasticsearch.common.network.InetAddresses; import org.elasticsearch.common.time.DateFormatters; @@ -332,15 +331,15 @@ public static ExpectedResults loadCsvSpecValues(String csv) { columnTypes = new ArrayList<>(header.length); for (String c : header) { - String[] nameWithType = Strings.split(c, ":"); - if (nameWithType == null || nameWithType.length != 2) { + String[] nameWithType = escapeTypecast(c).split(":"); + if (nameWithType.length != 2) { throw new IllegalArgumentException("Invalid CSV header " + c); } - String typeName = nameWithType[1].trim(); - if (typeName.length() == 0) { - throw new IllegalArgumentException("A type is always expected in the csv file; found " + nameWithType); + String typeName = unescapeTypecast(nameWithType[1]).trim(); + if (typeName.isEmpty()) { + throw new IllegalArgumentException("A type is always expected in the csv file; found " + Arrays.toString(nameWithType)); } - String name = nameWithType[0].trim(); + String name = unescapeTypecast(nameWithType[0]).trim(); columnNames.add(name); Type type = Type.asType(typeName); if (type == null) { @@ -398,6 +397,16 @@ public static ExpectedResults loadCsvSpecValues(String csv) { } } + private static final String TYPECAST_SPACER = "__TYPECAST__"; + + private static String escapeTypecast(String typecast) { + return typecast.replace("::", TYPECAST_SPACER); + } + + private static String unescapeTypecast(String typecast) { + return typecast.replace(TYPECAST_SPACER, "::"); + } + public enum Type { INTEGER(Integer::parseInt, Integer.class), LONG(Long::parseLong, Long.class), diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/EsqlTestUtils.java b/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/EsqlTestUtils.java index d7e067658267f..2bf3baf845010 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/EsqlTestUtils.java +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/EsqlTestUtils.java @@ -33,7 +33,6 @@ import org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.expression.predicate.Range; import org.elasticsearch.xpack.esql.core.index.EsIndex; -import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.core.session.Configuration; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; @@ -49,6 +48,7 @@ import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.NotEquals; import org.elasticsearch.xpack.esql.plan.logical.Enrich; import org.elasticsearch.xpack.esql.plan.logical.EsRelation; +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.plan.logical.local.LocalRelation; import org.elasticsearch.xpack.esql.plan.logical.local.LocalSupplier; import org.elasticsearch.xpack.esql.plugin.EsqlPlugin; diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/ip.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/ip.csv-spec index 54d5484bb4172..697b1c899d65e 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/ip.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/ip.csv-spec @@ -285,8 +285,8 @@ str1:keyword |str2:keyword |ip1:ip |ip2:ip pushDownIP from hosts | where ip1 == to_ip("::1") | keep card, host, ip0, ip1; ignoreOrder:true -warning:#[Emulated:Line 1:20: evaluation of [ip1 == to_ip(\"::1\")] failed, treating result as null. Only first 20 failures recorded.] -warning:#[Emulated:Line 1:20: java.lang.IllegalArgumentException: single-value function encountered multi-value] +warningRegex:evaluation of \[ip1 == to_ip\(\\\"::1\\\"\)\] failed, treating result as null. Only first 20 failures recorded. +warningRegex:java.lang.IllegalArgumentException: single-value function encountered multi-value card:keyword |host:keyword |ip0:ip |ip1:ip eth1 |alpha |::1 |::1 diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/meta.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/meta.csv-spec index 925b2fb9e5533..e7fa027ff1d6e 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/meta.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/meta.csv-spec @@ -1,4 +1,4 @@ -metaFunctionsSynopsis#[skip:-8.14.99] +metaFunctionsSynopsis#[skip:-8.15.99] meta functions | keep synopsis; synopsis:keyword @@ -38,10 +38,10 @@ double e() "double log(?base:integer|unsigned_long|long|double, number:integer|unsigned_long|long|double)" "double log10(number:double|integer|long|unsigned_long)" "keyword|text ltrim(string:keyword|text)" -"double|integer|long|date max(number:double|integer|long|date)" +"boolean|double|integer|long|date max(field:boolean|double|integer|long|date)" "double|integer|long median(number:double|integer|long)" "double|integer|long median_absolute_deviation(number:double|integer|long)" -"double|integer|long|date min(number:double|integer|long|date)" +"boolean|double|integer|long|date min(field:boolean|double|integer|long|date)" "boolean|cartesian_point|cartesian_shape|date|double|geo_point|geo_shape|integer|ip|keyword|long|text|version mv_append(field1:boolean|cartesian_point|cartesian_shape|date|double|geo_point|geo_shape|integer|ip|keyword|long|text|version, field2:boolean|cartesian_point|cartesian_shape|date|double|geo_point|geo_shape|integer|ip|keyword|long|text|version)" "double mv_avg(number:double|integer|long|unsigned_long)" "keyword mv_concat(string:text|keyword, delim:text|keyword)" @@ -116,7 +116,7 @@ double tau() "double weighted_avg(number:double|integer|long, weight:double|integer|long)" ; -metaFunctionsArgs#[skip:-8.14.99] +metaFunctionsArgs#[skip:-8.15.99] META functions | EVAL name = SUBSTRING(name, 0, 14) | KEEP name, argNames, argTypes, argDescriptions; @@ -158,10 +158,10 @@ locate |[string, substring, start] |["keyword|text", "keyword|te log |[base, number] |["integer|unsigned_long|long|double", "integer|unsigned_long|long|double"] |["Base of logarithm. If `null`\, the function returns `null`. If not provided\, this function returns the natural logarithm (base e) of a value.", "Numeric expression. If `null`\, the function returns `null`."] log10 |number |"double|integer|long|unsigned_long" |Numeric expression. If `null`, the function returns `null`. ltrim |string |"keyword|text" |String expression. If `null`, the function returns `null`. -max |number |"double|integer|long|date" |[""] +max |field |"boolean|double|integer|long|date" |[""] median |number |"double|integer|long" |[""] median_absolut|number |"double|integer|long" |[""] -min |number |"double|integer|long|date" |[""] +min |field |"boolean|double|integer|long|date" |[""] mv_append |[field1, field2] |["boolean|cartesian_point|cartesian_shape|date|double|geo_point|geo_shape|integer|ip|keyword|long|text|version", "boolean|cartesian_point|cartesian_shape|date|double|geo_point|geo_shape|integer|ip|keyword|long|text|version"] | ["", ""] mv_avg |number |"double|integer|long|unsigned_long" |Multivalue expression. mv_concat |[string, delim] |["text|keyword", "text|keyword"] |[Multivalue expression., Delimiter.] @@ -236,7 +236,7 @@ values |field |"boolean|date|double|integer weighted_avg |[number, weight] |["double|integer|long", "double|integer|long"] |[A numeric value., A numeric weight.] ; -metaFunctionsDescription#[skip:-8.14.99] +metaFunctionsDescription#[skip:-8.15.99] META functions | EVAL name = SUBSTRING(name, 0, 14) | KEEP name, description @@ -279,10 +279,10 @@ locate |Returns an integer that indicates the position of a keyword subst log |Returns the logarithm of a value to a base. The input can be any numeric value, the return value is always a double. Logs of zero, negative numbers, and base of one return `null` as well as a warning. log10 |Returns the logarithm of a value to base 10. The input can be any numeric value, the return value is always a double. Logs of 0 and negative numbers return `null` as well as a warning. ltrim |Removes leading whitespaces from a string. -max |The maximum value of a numeric field. +max |The maximum value of a field. median |The value that is greater than half of all values and less than half of all values. median_absolut|The median absolute deviation, a measure of variability. -min |The minimum value of a numeric field. +min |The minimum value of a field. mv_append |Concatenates values of two multi-value fields. mv_avg |Converts a multivalued field into a single valued field containing the average of all of the values. mv_concat |Converts a multivalued string expression into a single valued column containing the concatenation of all values separated by a delimiter. @@ -357,7 +357,7 @@ values |Collect values for a field. weighted_avg |The weighted average of a numeric field. ; -metaFunctionsRemaining#[skip:-8.14.99] +metaFunctionsRemaining#[skip:-8.15.99] META functions | EVAL name = SUBSTRING(name, 0, 14) | KEEP name, * @@ -401,10 +401,10 @@ locate |integer log |double |[true, false] |false |false log10 |double |false |false |false ltrim |"keyword|text" |false |false |false -max |"double|integer|long|date" |false |false |true +max |"boolean|double|integer|long|date" |false |false |true median |"double|integer|long" |false |false |true median_absolut|"double|integer|long" |false |false |true -min |"double|integer|long|date" |false |false |true +min |"boolean|double|integer|long|date" |false |false |true mv_append |"boolean|cartesian_point|cartesian_shape|date|double|geo_point|geo_shape|integer|ip|keyword|long|text|version" |[false, false] |false |false mv_avg |double |false |false |false mv_concat |keyword |[false, false] |false |false @@ -479,7 +479,7 @@ values |"boolean|date|double|integer|ip|keyword|long|text|version" weighted_avg |"double" |[false, false] |false |true ; -metaFunctionsFiltered#[skip:-8.14.99] +metaFunctionsFiltered#[skip:-8.15.99] META FUNCTIONS | WHERE STARTS_WITH(name, "sin") ; @@ -489,9 +489,7 @@ sin |"double sin(angle:double|integer|long|unsigned_long)" |angle sinh |"double sinh(angle:double|integer|long|unsigned_long)" |angle |"double|integer|long|unsigned_long" | "An angle, in radians. If `null`, the function returns `null`." | double | "Returns the {wikipedia}/Hyperbolic_functions[hyperbolic sine] of an angle." | false | false | false ; - -// see https://github.com/elastic/elasticsearch/issues/102120 -countFunctions#[skip:-8.14.99, reason:BIN added] +countFunctions#[skip:-8.15.99] meta functions | stats a = count(*), b = count(*), c = count(*) | mv_expand c; a:long | b:long | c:long diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats.csv-spec index e4fc0580e4ba2..2d306cd8fd2a0 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/stats.csv-spec @@ -31,6 +31,44 @@ MIN(languages):integer // end::min-result[] ; +maxOfBoolean +required_capability: agg_max_min_boolean_support +from employees | stats s = max(still_hired); + +s:boolean +true +; + +maxOfBooleanExpression +required_capability: agg_max_min_boolean_support +from employees +| eval x = salary is not null +| where emp_no > 10050 +| stats a = max(salary is not null), b = max(x), c = max(case(salary is null, true, false)), d = max(is_rehired); + +a:boolean | b:boolean | c:boolean | d:boolean +true | true | false | true +; + +minOfBooleanExpression +required_capability: agg_max_min_boolean_support +from employees +| eval x = salary is not null +| where emp_no > 10050 +| stats a = min(salary is not null), b = min(x), c = min(case(salary is null, true, false)), d = min(is_rehired); + +a:boolean | b:boolean | c:boolean | d:boolean +true | true | false | false +; + +minOfBoolean +required_capability: agg_max_min_boolean_support +from employees | stats s = min(still_hired); + +s:boolean +false +; + maxOfShort // short becomes int until https://github.com/elastic/elasticsearch-internal/issues/724 from employees | stats l = max(languages.short); diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/union_types.csv-spec b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/union_types.csv-spec index ee8c4be385e0f..349f968666132 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/resources/union_types.csv-spec +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/resources/union_types.csv-spec @@ -45,8 +45,10 @@ FROM sample_data_ts_long ; singleIndexIpStats +required_capability: casting_operator + FROM sample_data -| EVAL client_ip = TO_IP(client_ip) +| EVAL client_ip = client_ip::ip | STATS count=count(*) BY client_ip | SORT count DESC, client_ip ASC | KEEP count, client_ip @@ -60,8 +62,10 @@ count:long | client_ip:ip ; singleIndexIpStringStats +required_capability: casting_operator + FROM sample_data_str -| EVAL client_ip = TO_IP(client_ip) +| EVAL client_ip = client_ip::ip | STATS count=count(*) BY client_ip | SORT count DESC, client_ip ASC | KEEP count, client_ip @@ -74,12 +78,28 @@ count:long | client_ip:ip 1 | 172.21.2.162 ; +singleIndexIpStringStatsInline +required_capability: casting_operator + +FROM sample_data_str +| STATS count=count(*) BY client_ip::ip +| STATS mc=count(count) BY count +| SORT mc DESC, count ASC +| KEEP mc, count +; + +mc:l | count:l +3 | 1 +1 | 4 +; + multiIndexIpString required_capability: union_types required_capability: metadata_fields +required_capability: casting_operator FROM sample_data, sample_data_str METADATA _index -| EVAL client_ip = TO_IP(client_ip) +| EVAL client_ip = client_ip::ip | KEEP _index, @timestamp, client_ip, event_duration, message | SORT _index ASC, @timestamp DESC ; @@ -104,9 +124,10 @@ sample_data_str | 2023-10-23T12:15:03.360Z | 172.21.2.162 | 3450233 multiIndexIpStringRename required_capability: union_types required_capability: metadata_fields +required_capability: casting_operator FROM sample_data, sample_data_str METADATA _index -| EVAL host_ip = TO_IP(client_ip) +| EVAL host_ip = client_ip::ip | KEEP _index, @timestamp, host_ip, event_duration, message | SORT _index ASC, @timestamp DESC ; @@ -191,9 +212,10 @@ sample_data_str | 2023-10-23T12:15:03.360Z | 3450233 | Connected multiIndexIpStringStats required_capability: union_types +required_capability: casting_operator FROM sample_data, sample_data_str -| EVAL client_ip = TO_IP(client_ip) +| EVAL client_ip = client_ip::ip | STATS count=count(*) BY client_ip | SORT count DESC, client_ip ASC | KEEP count, client_ip @@ -208,9 +230,10 @@ count:long | client_ip:ip multiIndexIpStringRenameStats required_capability: union_types +required_capability: casting_operator FROM sample_data, sample_data_str -| EVAL host_ip = TO_IP(client_ip) +| EVAL host_ip = client_ip::ip | STATS count=count(*) BY host_ip | SORT count DESC, host_ip ASC | KEEP count, host_ip @@ -240,6 +263,24 @@ count:long | host_ip:keyword 2 | 172.21.2.162 ; +multiIndexIpStringStatsDrop +required_capability: union_types +required_capability: union_types_agg_cast +required_capability: casting_operator + +FROM sample_data, sample_data_str +| STATS count=count(*) BY client_ip::ip +| KEEP count +| SORT count DESC +; + +count:long +8 +2 +2 +2 +; + multiIndexIpStringStatsInline required_capability: union_types required_capability: union_types_inline_fix @@ -257,6 +298,39 @@ count:long | client_ip:ip 2 | 172.21.2.162 ; +multiIndexIpStringStatsInline2 +required_capability: union_types +required_capability: union_types_agg_cast +required_capability: casting_operator + +FROM sample_data, sample_data_str +| STATS count=count(*) BY client_ip::ip +| SORT count DESC, `client_ip::ip` ASC +; + +count:long | client_ip::ip:ip +8 | 172.21.3.15 +2 | 172.21.0.5 +2 | 172.21.2.113 +2 | 172.21.2.162 +; + +multiIndexIpStringStatsInline3 +required_capability: union_types +required_capability: union_types_agg_cast +required_capability: casting_operator + +FROM sample_data, sample_data_str +| STATS count=count(*) BY client_ip::ip +| STATS mc=count(count) BY count +| SORT mc DESC, count ASC +; + +mc:l | count:l +3 | 2 +1 | 8 +; + multiIndexWhereIpStringStats required_capability: union_types @@ -385,6 +459,76 @@ count:long | @timestamp:date 4 | 2023-10-23T12:00:00.000Z ; +multiIndexTsLongStatsDrop +required_capability: union_types +required_capability: union_types_agg_cast +required_capability: casting_operator + +FROM sample_data, sample_data_ts_long +| STATS count=count(*) BY @timestamp::datetime +| KEEP count +; + +count:long +2 +2 +2 +2 +2 +2 +2 +; + +multiIndexTsLongStatsInline2 +required_capability: union_types +required_capability: union_types_agg_cast +required_capability: casting_operator + +FROM sample_data, sample_data_ts_long +| STATS count=count(*) BY @timestamp::datetime +| SORT count DESC, `@timestamp::datetime` DESC +; + +count:long | @timestamp::datetime:datetime +2 | 2023-10-23T13:55:01.543Z +2 | 2023-10-23T13:53:55.832Z +2 | 2023-10-23T13:52:55.015Z +2 | 2023-10-23T13:51:54.732Z +2 | 2023-10-23T13:33:34.937Z +2 | 2023-10-23T12:27:28.948Z +2 | 2023-10-23T12:15:03.360Z +; + +multiIndexTsLongStatsInline3 +required_capability: union_types +required_capability: union_types_agg_cast +required_capability: casting_operator + +FROM sample_data, sample_data_ts_long +| STATS count=count(*) BY @timestamp::datetime +| STATS mc=count(count) BY count +| SORT mc DESC, count ASC +; + +mc:l | count:l +7 | 2 +; + +multiIndexTsLongStatsStats +required_capability: union_types +required_capability: union_types_agg_cast + +FROM sample_data, sample_data_ts_long +| EVAL ts = TO_STRING(@timestamp) +| STATS count = COUNT(*) BY ts +| STATS mc = COUNT(count) BY count +| SORT mc DESC, count ASC +; + +mc:l | count:l +14 | 1 +; + multiIndexTsLongRenameStats required_capability: union_types @@ -717,3 +861,37 @@ null | null | 8268153 | Connection error | samp null | null | 8268153 | Connection error | sample_data_str | 2023-10-23T13:52:55.015Z | 2023-10-23T13:52:55.015Z | 1698069175015 | 172.21.3.15 | 172.21.3.15 null | null | 8268153 | Connection error | sample_data_ts_long | 2023-10-23T13:52:55.015Z | 1698069175015 | 1698069175015 | 172.21.3.15 | 172.21.3.15 ; + +multiIndexMultiColumnTypesRenameAndKeep +required_capability: union_types +required_capability: metadata_fields + +FROM sample_data* METADATA _index +| WHERE event_duration > 8000000 +| EVAL ts = TO_DATETIME(@timestamp), ts_str = TO_STRING(@timestamp), ts_l = TO_LONG(@timestamp), ip = TO_IP(client_ip), ip_str = TO_STRING(client_ip) +| KEEP _index, ts, ts_str, ts_l, ip, ip_str, event_duration +| SORT _index ASC, ts DESC +; + +_index:keyword | ts:date | ts_str:keyword | ts_l:long | ip:ip | ip_str:k | event_duration:long +sample_data | 2023-10-23T13:52:55.015Z | 2023-10-23T13:52:55.015Z | 1698069175015 | 172.21.3.15 | 172.21.3.15 | 8268153 +sample_data_str | 2023-10-23T13:52:55.015Z | 2023-10-23T13:52:55.015Z | 1698069175015 | 172.21.3.15 | 172.21.3.15 | 8268153 +sample_data_ts_long | 2023-10-23T13:52:55.015Z | 1698069175015 | 1698069175015 | 172.21.3.15 | 172.21.3.15 | 8268153 +; + +multiIndexMultiColumnTypesRenameAndDrop +required_capability: union_types +required_capability: metadata_fields + +FROM sample_data* METADATA _index +| WHERE event_duration > 8000000 +| EVAL ts = TO_DATETIME(@timestamp), ts_str = TO_STRING(@timestamp), ts_l = TO_LONG(@timestamp), ip = TO_IP(client_ip), ip_str = TO_STRING(client_ip) +| DROP @timestamp, client_ip, message +| SORT _index ASC, ts DESC +; + +event_duration:long | _index:keyword | ts:date | ts_str:keyword | ts_l:long | ip:ip | ip_str:k +8268153 | sample_data | 2023-10-23T13:52:55.015Z | 2023-10-23T13:52:55.015Z | 1698069175015 | 172.21.3.15 | 172.21.3.15 +8268153 | sample_data_str | 2023-10-23T13:52:55.015Z | 2023-10-23T13:52:55.015Z | 1698069175015 | 172.21.3.15 | 172.21.3.15 +8268153 | sample_data_ts_long | 2023-10-23T13:52:55.015Z | 1698069175015 | 1698069175015 | 172.21.3.15 | 172.21.3.15 +; diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/AbstractEsqlIntegTestCase.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/AbstractEsqlIntegTestCase.java index 22e3de8499bc1..84738f733f86b 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/AbstractEsqlIntegTestCase.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/AbstractEsqlIntegTestCase.java @@ -11,6 +11,7 @@ import org.elasticsearch.ElasticsearchTimeoutException; import org.elasticsearch.action.admin.cluster.node.tasks.list.TransportListTasksAction; import org.elasticsearch.common.breaker.CircuitBreaker; +import org.elasticsearch.common.component.Lifecycle; import org.elasticsearch.common.settings.Setting; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.unit.ByteSizeValue; @@ -44,7 +45,11 @@ public void ensureExchangesAreReleased() throws Exception { for (String node : internalCluster().getNodeNames()) { TransportEsqlQueryAction esqlQueryAction = internalCluster().getInstance(TransportEsqlQueryAction.class, node); ExchangeService exchangeService = esqlQueryAction.exchangeService(); - assertBusy(() -> assertTrue("Leftover exchanges " + exchangeService + " on node " + node, exchangeService.isEmpty())); + assertBusy(() -> { + if (exchangeService.lifecycleState() == Lifecycle.State.STARTED) { + assertTrue("Leftover exchanges " + exchangeService + " on node " + node, exchangeService.isEmpty()); + } + }); } } diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/AsyncEsqlQueryActionIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/AsyncEsqlQueryActionIT.java index da9aa96876fd7..f85de51101af5 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/AsyncEsqlQueryActionIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/AsyncEsqlQueryActionIT.java @@ -54,7 +54,7 @@ protected Collection> nodePlugins() { @Override protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) { return Settings.builder() - .put(ExchangeService.INACTIVE_SINKS_INTERVAL_SETTING, TimeValue.timeValueMillis(between(500, 2000))) + .put(ExchangeService.INACTIVE_SINKS_INTERVAL_SETTING, TimeValue.timeValueMillis(between(3000, 4000))) .build(); } diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/CrossClustersCancellationIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/CrossClustersCancellationIT.java index 800067fef8b1c..df6a1e00b0212 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/CrossClustersCancellationIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/CrossClustersCancellationIT.java @@ -68,7 +68,7 @@ public List> getSettings() { return List.of( Setting.timeSetting( ExchangeService.INACTIVE_SINKS_INTERVAL_SETTING, - TimeValue.timeValueMillis(between(1000, 3000)), + TimeValue.timeValueMillis(between(3000, 4000)), Setting.Property.NodeScope ) ); diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EnrichIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EnrichIT.java index 5be816712cf20..cdfa6eb2d03f3 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EnrichIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EnrichIT.java @@ -111,7 +111,7 @@ protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) { HierarchyCircuitBreakerService.REQUEST_CIRCUIT_BREAKER_TYPE_SETTING.getKey(), HierarchyCircuitBreakerService.REQUEST_CIRCUIT_BREAKER_TYPE_SETTING.getDefault(Settings.EMPTY) ) - .put(ExchangeService.INACTIVE_SINKS_INTERVAL_SETTING, TimeValue.timeValueMillis(between(500, 2000))) + .put(ExchangeService.INACTIVE_SINKS_INTERVAL_SETTING, TimeValue.timeValueMillis(between(3000, 4000))) .put(BlockFactory.LOCAL_BREAKER_OVER_RESERVED_SIZE_SETTING, ByteSizeValue.ofBytes(between(0, 256))) .put(BlockFactory.LOCAL_BREAKER_OVER_RESERVED_MAX_SIZE_SETTING, ByteSizeValue.ofBytes(between(0, 1024))) // allow reading pages from network can trip the circuit breaker diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EsqlActionBreakerIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EsqlActionBreakerIT.java index 089cb4a9a5084..37833d8aed2d3 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EsqlActionBreakerIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EsqlActionBreakerIT.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.esql.action; -import org.apache.lucene.tests.util.LuceneTestCase; import org.elasticsearch.ElasticsearchException; import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.action.DocWriteResponse; @@ -35,7 +34,6 @@ import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.instanceOf; -@LuceneTestCase.AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/105543") @TestLogging(value = "org.elasticsearch.xpack.esql:TRACE", reason = "debug") public class EsqlActionBreakerIT extends EsqlActionIT { @@ -72,7 +70,7 @@ protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) { HierarchyCircuitBreakerService.REQUEST_CIRCUIT_BREAKER_TYPE_SETTING.getKey(), HierarchyCircuitBreakerService.REQUEST_CIRCUIT_BREAKER_TYPE_SETTING.getDefault(Settings.EMPTY) ) - .put(ExchangeService.INACTIVE_SINKS_INTERVAL_SETTING, TimeValue.timeValueMillis(between(500, 2000))) + .put(ExchangeService.INACTIVE_SINKS_INTERVAL_SETTING, TimeValue.timeValueMillis(between(3000, 4000))) .put(BlockFactory.LOCAL_BREAKER_OVER_RESERVED_SIZE_SETTING, ByteSizeValue.ofBytes(between(0, 256))) .put(BlockFactory.LOCAL_BREAKER_OVER_RESERVED_MAX_SIZE_SETTING, ByteSizeValue.ofBytes(between(0, 1024))) // allow reading pages from network can trip the circuit breaker diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EsqlActionTaskIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EsqlActionTaskIT.java index 9778756176574..cde4f10ef556c 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EsqlActionTaskIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EsqlActionTaskIT.java @@ -59,6 +59,7 @@ import static org.hamcrest.Matchers.greaterThan; import static org.hamcrest.Matchers.greaterThanOrEqualTo; import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.in; import static org.hamcrest.Matchers.lessThanOrEqualTo; import static org.hamcrest.Matchers.not; @@ -325,7 +326,7 @@ private void assertCancelled(ActionFuture response) throws Ex */ assertThat( cancelException.getMessage(), - either(equalTo("test cancel")).or(equalTo("task cancelled")).or(equalTo("request cancelled test cancel")) + in(List.of("test cancel", "task cancelled", "request cancelled test cancel", "parent task was cancelled [test cancel]")) ); assertBusy( () -> assertThat( diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EsqlDisruptionIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EsqlDisruptionIT.java index df1b2c9f00f49..e9eada5def0dc 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EsqlDisruptionIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EsqlDisruptionIT.java @@ -52,7 +52,7 @@ protected Settings nodeSettings(int nodeOrdinal, Settings otherSettings) { Settings settings = Settings.builder() .put(super.nodeSettings(nodeOrdinal, otherSettings)) .put(DEFAULT_SETTINGS) - .put(ExchangeService.INACTIVE_SINKS_INTERVAL_SETTING, TimeValue.timeValueMillis(between(1000, 2000))) + .put(ExchangeService.INACTIVE_SINKS_INTERVAL_SETTING, TimeValue.timeValueMillis(between(3000, 4000))) .build(); logger.info("settings {}", settings); return settings; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/VerificationException.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/VerificationException.java index 99e4a57757e38..8443b8d99d04a 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/VerificationException.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/VerificationException.java @@ -7,8 +7,8 @@ package org.elasticsearch.xpack.esql; -import org.elasticsearch.xpack.esql.core.common.Failure; -import org.elasticsearch.xpack.esql.core.common.Failures; +import org.elasticsearch.xpack.esql.common.Failure; +import org.elasticsearch.xpack.esql.common.Failures; import java.util.Collection; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java index 07362311d37a5..3353a9352a4bb 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/action/EsqlCapabilities.java @@ -49,6 +49,11 @@ public enum Cap { */ AGG_TOP, + /** + * Support for booleans in aggregations {@code MAX} and {@code MIN}. + */ + AGG_MAX_MIN_BOOLEAN_SUPPORT, + /** * Optimization for ST_CENTROID changed some results in cartesian data. #108713 */ @@ -106,7 +111,23 @@ public enum Cap { /** * Support for WEIGHTED_AVG function. */ - AGG_WEIGHTED_AVG; + AGG_WEIGHTED_AVG, + + /** + * Fix for union-types when aggregating over an inline conversion with casting operator. Done in #110476. + */ + UNION_TYPES_AGG_CAST, + + /** + * Fix to GROK validation in case of multiple fields with same name and different types + * https://github.com/elastic/elasticsearch/issues/110533 + */ + GROK_VALIDATION, + + /** + * Fix for union-types when aggregating over an inline conversion with conversion function. Done in #110652. + */ + UNION_TYPES_INLINE_FIX; private final boolean snapshotOnly; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java index 4fcd37faa311a..add1f74cc3f04 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Analyzer.java @@ -14,11 +14,10 @@ import org.elasticsearch.xpack.esql.Column; import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException; import org.elasticsearch.xpack.esql.VerificationException; -import org.elasticsearch.xpack.esql.core.analyzer.AnalyzerRules; -import org.elasticsearch.xpack.esql.core.analyzer.AnalyzerRules.BaseAnalyzerRule; -import org.elasticsearch.xpack.esql.core.analyzer.AnalyzerRules.ParameterizedAnalyzerRule; +import org.elasticsearch.xpack.esql.analysis.AnalyzerRules.BaseAnalyzerRule; +import org.elasticsearch.xpack.esql.analysis.AnalyzerRules.ParameterizedAnalyzerRule; +import org.elasticsearch.xpack.esql.common.Failure; import org.elasticsearch.xpack.esql.core.capabilities.Resolvables; -import org.elasticsearch.xpack.esql.core.common.Failure; import org.elasticsearch.xpack.esql.core.expression.Alias; import org.elasticsearch.xpack.esql.core.expression.Attribute; import org.elasticsearch.xpack.esql.core.expression.AttributeMap; @@ -38,8 +37,6 @@ import org.elasticsearch.xpack.esql.core.expression.predicate.operator.comparison.BinaryComparison; import org.elasticsearch.xpack.esql.core.index.EsIndex; import org.elasticsearch.xpack.esql.core.plan.TableIdentifier; -import org.elasticsearch.xpack.esql.core.plan.logical.Limit; -import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.core.rule.ParameterizedRule; import org.elasticsearch.xpack.esql.core.rule.ParameterizedRuleExecutor; import org.elasticsearch.xpack.esql.core.rule.RuleExecutor; @@ -71,6 +68,8 @@ import org.elasticsearch.xpack.esql.plan.logical.EsqlUnresolvedRelation; import org.elasticsearch.xpack.esql.plan.logical.Eval; import org.elasticsearch.xpack.esql.plan.logical.Keep; +import org.elasticsearch.xpack.esql.plan.logical.Limit; +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.plan.logical.Lookup; import org.elasticsearch.xpack.esql.plan.logical.MvExpand; import org.elasticsearch.xpack.esql.plan.logical.Project; @@ -1087,6 +1086,23 @@ protected LogicalPlan doRule(LogicalPlan plan) { return plan; } + // In ResolveRefs the aggregates are resolved from the groupings, which might have an unresolved MultiTypeEsField. + // Now that we have resolved those, we need to re-resolve the aggregates. + if (plan instanceof EsqlAggregate agg) { + // If the union-types resolution occurred in a child of the aggregate, we need to check the groupings + plan = agg.transformExpressionsOnly(FieldAttribute.class, UnresolveUnionTypes::checkUnresolved); + + // Aggregates where the grouping key comes from a union-type field need to be resolved against the grouping key + Map resolved = new HashMap<>(); + for (Expression e : agg.groupings()) { + Attribute attr = Expressions.attribute(e); + if (attr != null && attr.resolved()) { + resolved.put(attr, e); + } + } + plan = plan.transformExpressionsOnly(UnresolvedAttribute.class, ua -> resolveAttribute(ua, resolved)); + } + // Otherwise drop the converted attributes after the alias function, as they are only needed for this function, and // the original version of the attribute should still be seen as unconverted. plan = dropConvertedAttributes(plan, unionFieldAttributes); @@ -1110,6 +1126,15 @@ protected LogicalPlan doRule(LogicalPlan plan) { return plan; } + private Expression resolveAttribute(UnresolvedAttribute ua, Map resolved) { + var named = resolveAgainstList(ua, resolved.keySet()); + return switch (named.size()) { + case 0 -> ua; + case 1 -> named.get(0).equals(ua) ? ua : resolved.get(named.get(0)); + default -> ua.withUnresolvedMessage("Resolved [" + ua + "] unexpectedly to multiple attributes " + named); + }; + } + private LogicalPlan dropConvertedAttributes(LogicalPlan plan, List unionFieldAttributes) { List projections = new ArrayList<>(plan.output()); for (var e : unionFieldAttributes) { @@ -1201,9 +1226,8 @@ protected LogicalPlan rule(LogicalPlan plan) { return plan.transformExpressionsOnly(FieldAttribute.class, UnresolveUnionTypes::checkUnresolved); } - private static Attribute checkUnresolved(FieldAttribute fa) { - var field = fa.field(); - if (field instanceof InvalidMappedField imf) { + static Attribute checkUnresolved(FieldAttribute fa) { + if (fa.field() instanceof InvalidMappedField imf) { String unresolvedMessage = "Cannot use field [" + fa.name() + "] due to ambiguities being " + imf.errorMessage(); return new UnresolvedAttribute(fa.source(), fa.name(), fa.qualifier(), fa.id(), unresolvedMessage, null); } diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/analyzer/AnalyzerRules.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/AnalyzerRules.java similarity index 97% rename from x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/analyzer/AnalyzerRules.java rename to x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/AnalyzerRules.java index ce188511fe7bc..3314129fae405 100644 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/analyzer/AnalyzerRules.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/AnalyzerRules.java @@ -5,13 +5,13 @@ * 2.0. */ -package org.elasticsearch.xpack.esql.core.analyzer; +package org.elasticsearch.xpack.esql.analysis; import org.elasticsearch.xpack.esql.core.expression.Attribute; import org.elasticsearch.xpack.esql.core.expression.UnresolvedAttribute; -import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.core.rule.ParameterizedRule; import org.elasticsearch.xpack.esql.core.rule.Rule; +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; import java.util.ArrayList; import java.util.Collection; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/PreAnalyzer.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/PreAnalyzer.java index 7c37d5b8392c5..790142bef6a86 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/PreAnalyzer.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/PreAnalyzer.java @@ -8,9 +8,9 @@ package org.elasticsearch.xpack.esql.analysis; import org.elasticsearch.xpack.esql.core.analyzer.TableInfo; -import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.plan.logical.Enrich; import org.elasticsearch.xpack.esql.plan.logical.EsqlUnresolvedRelation; +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; import java.util.ArrayList; import java.util.List; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Verifier.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Verifier.java index 514a53b0933e9..a4e0d99b0d3fc 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Verifier.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/analysis/Verifier.java @@ -7,8 +7,8 @@ package org.elasticsearch.xpack.esql.analysis; +import org.elasticsearch.xpack.esql.common.Failure; import org.elasticsearch.xpack.esql.core.capabilities.Unresolvable; -import org.elasticsearch.xpack.esql.core.common.Failure; import org.elasticsearch.xpack.esql.core.expression.Alias; import org.elasticsearch.xpack.esql.core.expression.Attribute; import org.elasticsearch.xpack.esql.core.expression.AttributeMap; @@ -20,10 +20,6 @@ import org.elasticsearch.xpack.esql.core.expression.TypeResolutions; import org.elasticsearch.xpack.esql.core.expression.predicate.BinaryOperator; import org.elasticsearch.xpack.esql.core.expression.predicate.operator.comparison.BinaryComparison; -import org.elasticsearch.xpack.esql.core.plan.logical.Limit; -import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan; -import org.elasticsearch.xpack.esql.core.plan.logical.OrderBy; -import org.elasticsearch.xpack.esql.core.plan.logical.UnaryPlan; import org.elasticsearch.xpack.esql.core.type.DataType; import org.elasticsearch.xpack.esql.expression.function.UnsupportedAttribute; import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction; @@ -35,10 +31,15 @@ import org.elasticsearch.xpack.esql.plan.logical.Aggregate; import org.elasticsearch.xpack.esql.plan.logical.Enrich; import org.elasticsearch.xpack.esql.plan.logical.Eval; +import org.elasticsearch.xpack.esql.plan.logical.Filter; +import org.elasticsearch.xpack.esql.plan.logical.Limit; +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.plan.logical.Lookup; +import org.elasticsearch.xpack.esql.plan.logical.OrderBy; import org.elasticsearch.xpack.esql.plan.logical.Project; import org.elasticsearch.xpack.esql.plan.logical.RegexExtract; import org.elasticsearch.xpack.esql.plan.logical.Row; +import org.elasticsearch.xpack.esql.plan.logical.UnaryPlan; import org.elasticsearch.xpack.esql.stats.FeatureMetric; import org.elasticsearch.xpack.esql.stats.Metrics; import org.elasticsearch.xpack.esql.type.EsqlDataTypes; @@ -52,9 +53,9 @@ import java.util.function.Consumer; import java.util.stream.Stream; -import static org.elasticsearch.xpack.esql.core.analyzer.VerifierChecks.checkFilterConditionType; -import static org.elasticsearch.xpack.esql.core.common.Failure.fail; +import static org.elasticsearch.xpack.esql.common.Failure.fail; import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.FIRST; +import static org.elasticsearch.xpack.esql.core.type.DataType.BOOLEAN; public class Verifier { @@ -177,6 +178,15 @@ else if (p instanceof Lookup lookup) { return failures; } + private static void checkFilterConditionType(LogicalPlan p, Set localFailures) { + if (p instanceof Filter f) { + Expression condition = f.condition(); + if (condition.dataType() != BOOLEAN) { + localFailures.add(fail(condition, "Condition expression needs to be boolean, found [{}]", condition.dataType())); + } + } + } + private static void checkAggregate(LogicalPlan p, Set failures) { if (p instanceof Aggregate agg) { List groupings = agg.groupings(); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/capabilities/Validatable.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/capabilities/Validatable.java index 4d30f32af5f15..f6733fa3f175c 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/capabilities/Validatable.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/capabilities/Validatable.java @@ -7,7 +7,7 @@ package org.elasticsearch.xpack.esql.capabilities; -import org.elasticsearch.xpack.esql.core.common.Failures; +import org.elasticsearch.xpack.esql.common.Failures; /** * Interface implemented by expressions that require validation post logical optimization, diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/common/Failure.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/common/Failure.java similarity index 97% rename from x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/common/Failure.java rename to x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/common/Failure.java index 719ae7ffbd1ca..e5d0fb7ba0b3d 100644 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/common/Failure.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/common/Failure.java @@ -5,7 +5,7 @@ * 2.0. */ -package org.elasticsearch.xpack.esql.core.common; +package org.elasticsearch.xpack.esql.common; import org.elasticsearch.xpack.esql.core.tree.Location; import org.elasticsearch.xpack.esql.core.tree.Node; diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/common/Failures.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/common/Failures.java similarity index 96% rename from x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/common/Failures.java rename to x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/common/Failures.java index c06fe94c9a338..fd25cb427d95b 100644 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/common/Failures.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/common/Failures.java @@ -5,7 +5,7 @@ * 2.0. */ -package org.elasticsearch.xpack.esql.core.common; +package org.elasticsearch.xpack.esql.common; import java.util.Collection; import java.util.LinkedHashSet; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/enrich/EnrichLookupService.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/enrich/EnrichLookupService.java index 87c558fe5bd1e..2425fa24b17c2 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/enrich/EnrichLookupService.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/enrich/EnrichLookupService.java @@ -31,7 +31,6 @@ import org.elasticsearch.compute.data.BlockStreamInput; import org.elasticsearch.compute.data.BytesRefBlock; import org.elasticsearch.compute.data.ElementType; -import org.elasticsearch.compute.data.IntBlock; import org.elasticsearch.compute.data.IntVector; import org.elasticsearch.compute.data.LocalCircuitBreaker; import org.elasticsearch.compute.data.OrdinalBytesRefBlock; @@ -43,6 +42,7 @@ import org.elasticsearch.compute.operator.OutputOperator; import org.elasticsearch.core.AbstractRefCounted; import org.elasticsearch.core.RefCounted; +import org.elasticsearch.core.Releasable; import org.elasticsearch.core.Releasables; import org.elasticsearch.index.mapper.BlockLoader; import org.elasticsearch.index.mapper.MappedFieldType; @@ -247,30 +247,53 @@ private void doLookup( ActionListener listener ) { Block inputBlock = inputPage.getBlock(0); - final IntBlock selectedPositions; - final OrdinalBytesRefBlock ordinalsBytesRefBlock; - if (inputBlock instanceof BytesRefBlock bytesRefBlock && (ordinalsBytesRefBlock = bytesRefBlock.asOrdinals()) != null) { - inputBlock = ordinalsBytesRefBlock.getDictionaryVector().asBlock(); - selectedPositions = ordinalsBytesRefBlock.getOrdinalsBlock(); - selectedPositions.mustIncRef(); - } else { - selectedPositions = IntVector.range(0, inputBlock.getPositionCount(), blockFactory).asBlock(); + if (inputBlock.areAllValuesNull()) { + listener.onResponse(createNullResponse(inputPage.getPositionCount(), extractFields)); + return; } - LocalCircuitBreaker localBreaker = null; + final List releasables = new ArrayList<>(6); + boolean started = false; try { - if (inputBlock.areAllValuesNull()) { - listener.onResponse(createNullResponse(inputPage.getPositionCount(), extractFields)); - return; - } - ShardSearchRequest shardSearchRequest = new ShardSearchRequest(shardId, 0, AliasFilter.EMPTY); - SearchContext searchContext = searchService.createSearchContext(shardSearchRequest, SearchService.NO_TIMEOUT); - listener = ActionListener.runBefore(listener, searchContext::close); - localBreaker = new LocalCircuitBreaker( + final ShardSearchRequest shardSearchRequest = new ShardSearchRequest(shardId, 0, AliasFilter.EMPTY); + final SearchContext searchContext = searchService.createSearchContext(shardSearchRequest, SearchService.NO_TIMEOUT); + releasables.add(searchContext); + final LocalCircuitBreaker localBreaker = new LocalCircuitBreaker( blockFactory.breaker(), localBreakerSettings.overReservedBytes(), localBreakerSettings.maxOverReservedBytes() ); - DriverContext driverContext = new DriverContext(bigArrays, blockFactory.newChildFactory(localBreaker)); + releasables.add(localBreaker); + final DriverContext driverContext = new DriverContext(bigArrays, blockFactory.newChildFactory(localBreaker)); + final ElementType[] mergingTypes = new ElementType[extractFields.size()]; + for (int i = 0; i < extractFields.size(); i++) { + mergingTypes[i] = PlannerUtils.toElementType(extractFields.get(i).dataType()); + } + final int[] mergingChannels = IntStream.range(0, extractFields.size()).map(i -> i + 2).toArray(); + final MergePositionsOperator mergePositionsOperator; + final OrdinalBytesRefBlock ordinalsBytesRefBlock; + if (inputBlock instanceof BytesRefBlock bytesRefBlock && (ordinalsBytesRefBlock = bytesRefBlock.asOrdinals()) != null) { + inputBlock = ordinalsBytesRefBlock.getDictionaryVector().asBlock(); + var selectedPositions = ordinalsBytesRefBlock.getOrdinalsBlock(); + mergePositionsOperator = new MergePositionsOperator( + 1, + mergingChannels, + mergingTypes, + selectedPositions, + driverContext.blockFactory() + ); + + } else { + try (var selectedPositions = IntVector.range(0, inputBlock.getPositionCount(), blockFactory).asBlock()) { + mergePositionsOperator = new MergePositionsOperator( + 1, + mergingChannels, + mergingTypes, + selectedPositions, + driverContext.blockFactory() + ); + } + } + releasables.add(mergePositionsOperator); SearchExecutionContext searchExecutionContext = searchContext.getSearchExecutionContext(); MappedFieldType fieldType = searchExecutionContext.getFieldType(matchField); var queryList = switch (matchType) { @@ -284,57 +307,13 @@ private void doLookup( queryList, searchExecutionContext.getIndexReader() ); - List intermediateOperators = new ArrayList<>(extractFields.size() + 2); - final ElementType[] mergingTypes = new ElementType[extractFields.size()]; - // load the fields - List fields = new ArrayList<>(extractFields.size()); - for (int i = 0; i < extractFields.size(); i++) { - NamedExpression extractField = extractFields.get(i); - final ElementType elementType = PlannerUtils.toElementType(extractField.dataType()); - mergingTypes[i] = elementType; - EsPhysicalOperationProviders.ShardContext ctx = new EsPhysicalOperationProviders.DefaultShardContext( - 0, - searchContext.getSearchExecutionContext(), - searchContext.request().getAliasFilter() - ); - BlockLoader loader = ctx.blockLoader( - extractField instanceof Alias a ? ((NamedExpression) a.child()).name() : extractField.name(), - extractField.dataType() == DataType.UNSUPPORTED, - MappedFieldType.FieldExtractPreference.NONE - ); - fields.add( - new ValuesSourceReaderOperator.FieldInfo( - extractField.name(), - PlannerUtils.toElementType(extractField.dataType()), - shardIdx -> { - if (shardIdx != 0) { - throw new IllegalStateException("only one shard"); - } - return loader; - } - ) - ); - } - intermediateOperators.add( - new ValuesSourceReaderOperator( - driverContext.blockFactory(), - fields, - List.of( - new ValuesSourceReaderOperator.ShardContext( - searchContext.searcher().getIndexReader(), - searchContext::newSourceLoader - ) - ), - 0 - ) - ); - // merging field-values by position - final int[] mergingChannels = IntStream.range(0, extractFields.size()).map(i -> i + 2).toArray(); - intermediateOperators.add( - new MergePositionsOperator(1, mergingChannels, mergingTypes, selectedPositions, driverContext.blockFactory()) - ); + releasables.add(queryOperator); + var extractFieldsOperator = extractFieldsOperator(searchContext, driverContext, extractFields); + releasables.add(extractFieldsOperator); + AtomicReference result = new AtomicReference<>(); OutputOperator outputOperator = new OutputOperator(List.of(), Function.identity(), result::set); + releasables.add(outputOperator); Driver driver = new Driver( "enrich-lookup:" + sessionId, System.currentTimeMillis(), @@ -350,18 +329,16 @@ private void doLookup( inputPage.getPositionCount() ), queryOperator, - intermediateOperators, + List.of(extractFieldsOperator, mergePositionsOperator), outputOperator, Driver.DEFAULT_STATUS_INTERVAL, - localBreaker + Releasables.wrap(searchContext, localBreaker) ); task.addListener(() -> { String reason = Objects.requireNonNullElse(task.getReasonCancelled(), "task was cancelled"); driver.cancel(reason); }); - var threadContext = transportService.getThreadPool().getThreadContext(); - localBreaker = null; Driver.start(threadContext, executor, driver, Driver.DEFAULT_MAX_ITERATIONS, listener.map(ignored -> { Page out = result.get(); if (out == null) { @@ -369,11 +346,52 @@ private void doLookup( } return out; })); + started = true; } catch (Exception e) { listener.onFailure(e); } finally { - Releasables.close(selectedPositions, localBreaker); + if (started == false) { + Releasables.close(releasables); + } + } + } + + private static Operator extractFieldsOperator( + SearchContext searchContext, + DriverContext driverContext, + List extractFields + ) { + EsPhysicalOperationProviders.ShardContext shardContext = new EsPhysicalOperationProviders.DefaultShardContext( + 0, + searchContext.getSearchExecutionContext(), + searchContext.request().getAliasFilter() + ); + List fields = new ArrayList<>(extractFields.size()); + for (NamedExpression extractField : extractFields) { + BlockLoader loader = shardContext.blockLoader( + extractField instanceof Alias a ? ((NamedExpression) a.child()).name() : extractField.name(), + extractField.dataType() == DataType.UNSUPPORTED, + MappedFieldType.FieldExtractPreference.NONE + ); + fields.add( + new ValuesSourceReaderOperator.FieldInfo( + extractField.name(), + PlannerUtils.toElementType(extractField.dataType()), + shardIdx -> { + if (shardIdx != 0) { + throw new IllegalStateException("only one shard"); + } + return loader; + } + ) + ); } + return new ValuesSourceReaderOperator( + driverContext.blockFactory(), + fields, + List.of(new ValuesSourceReaderOperator.ShardContext(searchContext.searcher().getIndexReader(), searchContext::newSourceLoader)), + 0 + ); } private Page createNullResponse(int positionCount, List extractFields) { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/execution/PlanExecutor.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/execution/PlanExecutor.java index df67f4609c33e..4e07c3084ab7b 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/execution/PlanExecutor.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/execution/PlanExecutor.java @@ -20,9 +20,12 @@ import org.elasticsearch.xpack.esql.session.EsqlConfiguration; import org.elasticsearch.xpack.esql.session.EsqlSession; import org.elasticsearch.xpack.esql.session.IndexResolver; +import org.elasticsearch.xpack.esql.session.Result; import org.elasticsearch.xpack.esql.stats.Metrics; import org.elasticsearch.xpack.esql.stats.QueryMetric; +import java.util.function.BiConsumer; + import static org.elasticsearch.action.ActionListener.wrap; public class PlanExecutor { @@ -48,7 +51,8 @@ public void esql( String sessionId, EsqlConfiguration cfg, EnrichPolicyResolver enrichPolicyResolver, - ActionListener listener + BiConsumer> runPhase, + ActionListener listener ) { final var session = new EsqlSession( sessionId, @@ -63,7 +67,7 @@ public void esql( ); QueryMetric clientId = QueryMetric.fromString("rest"); metrics.total(clientId); - session.execute(request, wrap(listener::onResponse, ex -> { + session.execute(request, runPhase, wrap(listener::onResponse, ex -> { // TODO when we decide if we will differentiate Kibana from REST, this String value will likely come from the request metrics.failed(clientId); listener.onFailure(ex); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/Validations.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/Validations.java index dffa723a1f3dd..ffcc26cb6f188 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/Validations.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/Validations.java @@ -7,7 +7,7 @@ package org.elasticsearch.xpack.esql.expression; -import org.elasticsearch.xpack.esql.core.common.Failure; +import org.elasticsearch.xpack.esql.common.Failure; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.Expression.TypeResolution; import org.elasticsearch.xpack.esql.core.expression.TypeResolutions; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Avg.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Avg.java index cb70b73117397..b5c0b8e5ffdc8 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Avg.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Avg.java @@ -14,6 +14,7 @@ import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; import org.elasticsearch.xpack.esql.expression.SurrogateExpression; +import org.elasticsearch.xpack.esql.expression.function.Example; import org.elasticsearch.xpack.esql.expression.function.FunctionInfo; import org.elasticsearch.xpack.esql.expression.function.Param; import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvAvg; @@ -28,7 +29,20 @@ public class Avg extends AggregateFunction implements SurrogateExpression { public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "Avg", Avg::new); - @FunctionInfo(returnType = "double", description = "The average of a numeric field.", isAggregation = true) + @FunctionInfo( + returnType = "double", + description = "The average of a numeric field.", + isAggregation = true, + examples = { + @Example(file = "stats", tag = "avg"), + @Example( + description = "The expression can use inline functions. For example, to calculate the average " + + "over a multivalued column, first use `MV_AVG` to average the multiple values per row, " + + "and use the result with the `AVG` function", + file = "stats", + tag = "docsStatsAvgNestedExpression" + ) } + ) public Avg(Source source, @Param(name = "number", type = { "double", "integer", "long" }) Expression field) { super(source, field); } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Max.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Max.java index 97a6f6b4b5e1f..98748fad681c2 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Max.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Max.java @@ -10,30 +10,46 @@ import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.compute.aggregation.AggregatorFunctionSupplier; +import org.elasticsearch.compute.aggregation.MaxBooleanAggregatorFunctionSupplier; import org.elasticsearch.compute.aggregation.MaxDoubleAggregatorFunctionSupplier; import org.elasticsearch.compute.aggregation.MaxIntAggregatorFunctionSupplier; import org.elasticsearch.compute.aggregation.MaxLongAggregatorFunctionSupplier; +import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException; import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.TypeResolutions; import org.elasticsearch.xpack.esql.core.tree.NodeInfo; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; import org.elasticsearch.xpack.esql.expression.SurrogateExpression; +import org.elasticsearch.xpack.esql.expression.function.Example; import org.elasticsearch.xpack.esql.expression.function.FunctionInfo; import org.elasticsearch.xpack.esql.expression.function.Param; import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvMax; +import org.elasticsearch.xpack.esql.planner.ToAggregator; import java.io.IOException; import java.util.List; -public class Max extends NumericAggregate implements SurrogateExpression { +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.DEFAULT; + +public class Max extends AggregateFunction implements ToAggregator, SurrogateExpression { public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "Max", Max::new); @FunctionInfo( - returnType = { "double", "integer", "long", "date" }, - description = "The maximum value of a numeric field.", - isAggregation = true + returnType = { "boolean", "double", "integer", "long", "date" }, + description = "The maximum value of a field.", + isAggregation = true, + examples = { + @Example(file = "stats", tag = "max"), + @Example( + description = "The expression can use inline functions. For example, to calculate the maximum " + + "over an average of a multivalued column, use `MV_AVG` to first average the " + + "multiple values per row, and use the result with the `MAX` function", + file = "stats", + tag = "docsStatsMaxNestedExpression" + ) } ) - public Max(Source source, @Param(name = "number", type = { "double", "integer", "long", "date" }) Expression field) { + public Max(Source source, @Param(name = "field", type = { "boolean", "double", "integer", "long", "date" }) Expression field) { super(source, field); } @@ -57,8 +73,16 @@ public Max replaceChildren(List newChildren) { } @Override - protected boolean supportsDates() { - return true; + protected TypeResolution resolveType() { + return TypeResolutions.isType( + this, + e -> e == DataType.BOOLEAN || e == DataType.DATETIME || (e.isNumeric() && e != DataType.UNSIGNED_LONG), + sourceText(), + DEFAULT, + "boolean", + "datetime", + "numeric except unsigned_long or counter types" + ); } @Override @@ -67,18 +91,21 @@ public DataType dataType() { } @Override - protected AggregatorFunctionSupplier longSupplier(List inputChannels) { - return new MaxLongAggregatorFunctionSupplier(inputChannels); - } - - @Override - protected AggregatorFunctionSupplier intSupplier(List inputChannels) { - return new MaxIntAggregatorFunctionSupplier(inputChannels); - } - - @Override - protected AggregatorFunctionSupplier doubleSupplier(List inputChannels) { - return new MaxDoubleAggregatorFunctionSupplier(inputChannels); + public final AggregatorFunctionSupplier supplier(List inputChannels) { + DataType type = field().dataType(); + if (type == DataType.BOOLEAN) { + return new MaxBooleanAggregatorFunctionSupplier(inputChannels); + } + if (type == DataType.LONG || type == DataType.DATETIME) { + return new MaxLongAggregatorFunctionSupplier(inputChannels); + } + if (type == DataType.INTEGER) { + return new MaxIntAggregatorFunctionSupplier(inputChannels); + } + if (type == DataType.DOUBLE) { + return new MaxDoubleAggregatorFunctionSupplier(inputChannels); + } + throw EsqlIllegalArgumentException.illegalDataType(type); } @Override diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Min.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Min.java index 2dd3e973937f5..f712786bcff4b 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Min.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/Min.java @@ -10,30 +10,46 @@ import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.compute.aggregation.AggregatorFunctionSupplier; +import org.elasticsearch.compute.aggregation.MinBooleanAggregatorFunctionSupplier; import org.elasticsearch.compute.aggregation.MinDoubleAggregatorFunctionSupplier; import org.elasticsearch.compute.aggregation.MinIntAggregatorFunctionSupplier; import org.elasticsearch.compute.aggregation.MinLongAggregatorFunctionSupplier; +import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException; import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.TypeResolutions; import org.elasticsearch.xpack.esql.core.tree.NodeInfo; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; import org.elasticsearch.xpack.esql.expression.SurrogateExpression; +import org.elasticsearch.xpack.esql.expression.function.Example; import org.elasticsearch.xpack.esql.expression.function.FunctionInfo; import org.elasticsearch.xpack.esql.expression.function.Param; import org.elasticsearch.xpack.esql.expression.function.scalar.multivalue.MvMin; +import org.elasticsearch.xpack.esql.planner.ToAggregator; import java.io.IOException; import java.util.List; -public class Min extends NumericAggregate implements SurrogateExpression { +import static org.elasticsearch.xpack.esql.core.expression.TypeResolutions.ParamOrdinal.DEFAULT; + +public class Min extends AggregateFunction implements ToAggregator, SurrogateExpression { public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry(Expression.class, "Min", Min::new); @FunctionInfo( - returnType = { "double", "integer", "long", "date" }, - description = "The minimum value of a numeric field.", - isAggregation = true + returnType = { "boolean", "double", "integer", "long", "date" }, + description = "The minimum value of a field.", + isAggregation = true, + examples = { + @Example(file = "stats", tag = "min"), + @Example( + description = "The expression can use inline functions. For example, to calculate the minimum " + + "over an average of a multivalued column, use `MV_AVG` to first average the " + + "multiple values per row, and use the result with the `MIN` function", + file = "stats", + tag = "docsStatsMinNestedExpression" + ) } ) - public Min(Source source, @Param(name = "number", type = { "double", "integer", "long", "date" }) Expression field) { + public Min(Source source, @Param(name = "field", type = { "boolean", "double", "integer", "long", "date" }) Expression field) { super(source, field); } @@ -57,28 +73,39 @@ public Min replaceChildren(List newChildren) { } @Override - public DataType dataType() { - return field().dataType(); - } - - @Override - protected boolean supportsDates() { - return true; + protected TypeResolution resolveType() { + return TypeResolutions.isType( + this, + e -> e == DataType.BOOLEAN || e == DataType.DATETIME || (e.isNumeric() && e != DataType.UNSIGNED_LONG), + sourceText(), + DEFAULT, + "boolean", + "datetime", + "numeric except unsigned_long or counter types" + ); } @Override - protected AggregatorFunctionSupplier longSupplier(List inputChannels) { - return new MinLongAggregatorFunctionSupplier(inputChannels); - } - - @Override - protected AggregatorFunctionSupplier intSupplier(List inputChannels) { - return new MinIntAggregatorFunctionSupplier(inputChannels); + public DataType dataType() { + return field().dataType(); } @Override - protected AggregatorFunctionSupplier doubleSupplier(List inputChannels) { - return new MinDoubleAggregatorFunctionSupplier(inputChannels); + public final AggregatorFunctionSupplier supplier(List inputChannels) { + DataType type = field().dataType(); + if (type == DataType.BOOLEAN) { + return new MinBooleanAggregatorFunctionSupplier(inputChannels); + } + if (type == DataType.LONG || type == DataType.DATETIME) { + return new MinLongAggregatorFunctionSupplier(inputChannels); + } + if (type == DataType.INTEGER) { + return new MinIntAggregatorFunctionSupplier(inputChannels); + } + if (type == DataType.DOUBLE) { + return new MinDoubleAggregatorFunctionSupplier(inputChannels); + } + throw EsqlIllegalArgumentException.illegalDataType(type); } @Override diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/ToPartial.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/ToPartial.java index f94c8e0508cd7..c1da400185944 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/ToPartial.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/aggregate/ToPartial.java @@ -65,12 +65,7 @@ public class ToPartial extends AggregateFunction implements ToAggregator { private final Expression function; - public ToPartial(Source source, AggregateFunction function) { - super(source, function.field(), List.of(function)); - this.function = function; - } - - private ToPartial(Source source, Expression field, Expression function) { + public ToPartial(Source source, Expression field, Expression function) { super(source, field, List.of(function)); this.function = function; } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/grouping/Bucket.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/grouping/Bucket.java index 40e927404befd..3ce51b8086dd0 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/grouping/Bucket.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/grouping/Bucket.java @@ -16,7 +16,7 @@ import org.elasticsearch.core.TimeValue; import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException; import org.elasticsearch.xpack.esql.capabilities.Validatable; -import org.elasticsearch.xpack.esql.core.common.Failures; +import org.elasticsearch.xpack.esql.common.Failures; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.Foldables; import org.elasticsearch.xpack.esql.core.expression.Literal; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvSort.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvSort.java index 444c0e319fc6a..ee83236ac6a63 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvSort.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvSort.java @@ -12,6 +12,7 @@ import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.logging.LoggerMessageFormat; import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.BlockFactory; import org.elasticsearch.compute.data.BooleanBlock; @@ -29,7 +30,8 @@ import org.elasticsearch.compute.operator.mvdedupe.MultivalueDedupeInt; import org.elasticsearch.compute.operator.mvdedupe.MultivalueDedupeLong; import org.elasticsearch.xpack.esql.capabilities.Validatable; -import org.elasticsearch.xpack.esql.core.common.Failures; +import org.elasticsearch.xpack.esql.common.Failure; +import org.elasticsearch.xpack.esql.common.Failures; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.tree.NodeInfo; @@ -64,6 +66,9 @@ public class MvSort extends EsqlScalarFunction implements OptionalArgument, Vali private final Expression field, order; private static final Literal ASC = new Literal(Source.EMPTY, "ASC", DataType.KEYWORD); + private static final Literal DESC = new Literal(Source.EMPTY, "DESC", DataType.KEYWORD); + + private static final String INVALID_ORDER_ERROR = "Invalid order value in [{}], expected one of [{}, {}] but got [{}]"; @FunctionInfo( returnType = { "boolean", "date", "double", "integer", "ip", "keyword", "long", "text", "version" }, @@ -84,7 +89,7 @@ public MvSort( optional = true ) Expression order ) { - super(source, order == null ? Arrays.asList(field, ASC) : Arrays.asList(field, order)); + super(source, order == null ? Arrays.asList(field) : Arrays.asList(field, order)); this.field = field; this.order = order; } @@ -128,6 +133,7 @@ protected TypeResolution resolveType() { if (resolution.unresolved()) { return resolution; } + if (order == null) { return resolution; } @@ -144,10 +150,23 @@ public boolean foldable() { public EvalOperator.ExpressionEvaluator.Factory toEvaluator( Function toEvaluator ) { - Expression nonNullOrder = order == null ? ASC : order; - boolean ordering = nonNullOrder.foldable() && ((BytesRef) nonNullOrder.fold()).utf8ToString().equalsIgnoreCase("DESC") - ? false - : true; + boolean ordering = true; + if (isValidOrder() == false) { + throw new IllegalArgumentException( + LoggerMessageFormat.format( + null, + INVALID_ORDER_ERROR, + sourceText(), + ASC.value(), + DESC.value(), + ((BytesRef) order.fold()).utf8ToString() + ) + ); + } + if (order != null && order.foldable()) { + ordering = ((BytesRef) order.fold()).utf8ToString().equalsIgnoreCase((String) ASC.value()); + } + return switch (PlannerUtils.toElementType(field.dataType())) { case BOOLEAN -> new MvSort.EvaluatorFactory( toEvaluator.apply(field), @@ -216,8 +235,33 @@ public DataType dataType() { @Override public void validate(Failures failures) { + if (order == null) { + return; + } String operation = sourceText(); failures.add(isFoldable(order, operation, SECOND)); + if (isValidOrder() == false) { + failures.add( + Failure.fail(order, INVALID_ORDER_ERROR, sourceText(), ASC.value(), DESC.value(), ((BytesRef) order.fold()).utf8ToString()) + ); + } + } + + private boolean isValidOrder() { + boolean isValidOrder = true; + if (order != null && order.foldable()) { + Object obj = order.fold(); + String o = null; + if (obj instanceof BytesRef ob) { + o = ob.utf8ToString(); + } else if (obj instanceof String os) { + o = os; + } + if (o == null || o.equalsIgnoreCase((String) ASC.value()) == false && o.equalsIgnoreCase((String) DESC.value()) == false) { + isValidOrder = false; + } + } + return isValidOrder; } private record EvaluatorFactory( diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/io/stream/PlanNamedTypes.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/io/stream/PlanNamedTypes.java index 8034eba20690d..e4051523c7a5e 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/io/stream/PlanNamedTypes.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/io/stream/PlanNamedTypes.java @@ -26,10 +26,6 @@ import org.elasticsearch.xpack.esql.core.expression.NamedExpression; import org.elasticsearch.xpack.esql.core.expression.Order; import org.elasticsearch.xpack.esql.core.index.EsIndex; -import org.elasticsearch.xpack.esql.core.plan.logical.Filter; -import org.elasticsearch.xpack.esql.core.plan.logical.Limit; -import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan; -import org.elasticsearch.xpack.esql.core.plan.logical.OrderBy; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.EsField; import org.elasticsearch.xpack.esql.plan.logical.Aggregate; @@ -38,9 +34,13 @@ import org.elasticsearch.xpack.esql.plan.logical.Enrich; import org.elasticsearch.xpack.esql.plan.logical.EsRelation; import org.elasticsearch.xpack.esql.plan.logical.Eval; +import org.elasticsearch.xpack.esql.plan.logical.Filter; import org.elasticsearch.xpack.esql.plan.logical.Grok; +import org.elasticsearch.xpack.esql.plan.logical.Limit; +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.plan.logical.Lookup; import org.elasticsearch.xpack.esql.plan.logical.MvExpand; +import org.elasticsearch.xpack.esql.plan.logical.OrderBy; import org.elasticsearch.xpack.esql.plan.logical.Project; import org.elasticsearch.xpack.esql.plan.logical.TopN; import org.elasticsearch.xpack.esql.plan.logical.join.Join; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/io/stream/PlanStreamInput.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/io/stream/PlanStreamInput.java index be2a9454b3bef..0633595a5796d 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/io/stream/PlanStreamInput.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/io/stream/PlanStreamInput.java @@ -24,9 +24,9 @@ import org.elasticsearch.core.Releasables; import org.elasticsearch.xpack.esql.Column; import org.elasticsearch.xpack.esql.core.expression.NameId; -import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.io.stream.PlanNameRegistry.PlanNamedReader; import org.elasticsearch.xpack.esql.io.stream.PlanNameRegistry.PlanReader; +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan; import org.elasticsearch.xpack.esql.session.EsqlConfiguration; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/io/stream/PlanStreamOutput.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/io/stream/PlanStreamOutput.java index 58cd2465e1584..674476ec4f736 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/io/stream/PlanStreamOutput.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/io/stream/PlanStreamOutput.java @@ -19,8 +19,8 @@ import org.elasticsearch.compute.data.LongBigArrayBlock; import org.elasticsearch.core.Nullable; import org.elasticsearch.xpack.esql.Column; -import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.io.stream.PlanNameRegistry.PlanWriter; +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.plan.logical.join.Join; import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan; import org.elasticsearch.xpack.esql.session.EsqlConfiguration; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LocalLogicalPlanOptimizer.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LocalLogicalPlanOptimizer.java index ba5e8316a666c..9a2ae742c2feb 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LocalLogicalPlanOptimizer.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LocalLogicalPlanOptimizer.java @@ -21,11 +21,6 @@ import org.elasticsearch.xpack.esql.core.expression.NamedExpression; import org.elasticsearch.xpack.esql.core.expression.predicate.Predicates; import org.elasticsearch.xpack.esql.core.expression.predicate.nulls.IsNotNull; -import org.elasticsearch.xpack.esql.core.optimizer.OptimizerRules; -import org.elasticsearch.xpack.esql.core.plan.logical.Filter; -import org.elasticsearch.xpack.esql.core.plan.logical.Limit; -import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan; -import org.elasticsearch.xpack.esql.core.plan.logical.OrderBy; import org.elasticsearch.xpack.esql.core.rule.ParameterizedRule; import org.elasticsearch.xpack.esql.core.rule.ParameterizedRuleExecutor; import org.elasticsearch.xpack.esql.core.rule.Rule; @@ -34,10 +29,15 @@ import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction; import org.elasticsearch.xpack.esql.expression.function.aggregate.Count; import org.elasticsearch.xpack.esql.expression.function.scalar.nulls.Coalesce; +import org.elasticsearch.xpack.esql.optimizer.rules.OptimizerRules; import org.elasticsearch.xpack.esql.optimizer.rules.PropagateEmptyRelation; import org.elasticsearch.xpack.esql.plan.logical.Aggregate; import org.elasticsearch.xpack.esql.plan.logical.EsRelation; import org.elasticsearch.xpack.esql.plan.logical.Eval; +import org.elasticsearch.xpack.esql.plan.logical.Filter; +import org.elasticsearch.xpack.esql.plan.logical.Limit; +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; +import org.elasticsearch.xpack.esql.plan.logical.OrderBy; import org.elasticsearch.xpack.esql.plan.logical.Project; import org.elasticsearch.xpack.esql.plan.logical.RegexExtract; import org.elasticsearch.xpack.esql.plan.logical.TopN; @@ -54,9 +54,9 @@ import static java.util.Arrays.asList; import static java.util.Collections.emptySet; -import static org.elasticsearch.xpack.esql.core.optimizer.OptimizerRules.TransformDirection.UP; import static org.elasticsearch.xpack.esql.optimizer.LogicalPlanOptimizer.cleanup; import static org.elasticsearch.xpack.esql.optimizer.LogicalPlanOptimizer.operators; +import static org.elasticsearch.xpack.esql.optimizer.rules.OptimizerRules.TransformDirection.UP; public class LocalLogicalPlanOptimizer extends ParameterizedRuleExecutor { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizer.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizer.java index 9447e018bc142..c03dc46216621 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizer.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LocalPhysicalPlanOptimizer.java @@ -17,7 +17,7 @@ import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.xpack.esql.VerificationException; -import org.elasticsearch.xpack.esql.core.common.Failure; +import org.elasticsearch.xpack.esql.common.Failure; import org.elasticsearch.xpack.esql.core.expression.Alias; import org.elasticsearch.xpack.esql.core.expression.Attribute; import org.elasticsearch.xpack.esql.core.expression.AttributeMap; @@ -92,7 +92,7 @@ import static java.util.Collections.emptyList; import static java.util.Collections.singletonList; import static org.elasticsearch.xpack.esql.core.expression.predicate.Predicates.splitAnd; -import static org.elasticsearch.xpack.esql.core.optimizer.OptimizerRules.TransformDirection.UP; +import static org.elasticsearch.xpack.esql.optimizer.rules.OptimizerRules.TransformDirection.UP; import static org.elasticsearch.xpack.esql.plan.physical.EsStatsQueryExec.StatsType.COUNT; public class LocalPhysicalPlanOptimizer extends ParameterizedRuleExecutor { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizer.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizer.java index ca4b5d17deed3..50819b8ee7480 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizer.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizer.java @@ -9,7 +9,7 @@ import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException; import org.elasticsearch.xpack.esql.VerificationException; -import org.elasticsearch.xpack.esql.core.common.Failures; +import org.elasticsearch.xpack.esql.common.Failures; import org.elasticsearch.xpack.esql.core.expression.Alias; import org.elasticsearch.xpack.esql.core.expression.Attribute; import org.elasticsearch.xpack.esql.core.expression.AttributeMap; @@ -17,9 +17,6 @@ import org.elasticsearch.xpack.esql.core.expression.Expressions; import org.elasticsearch.xpack.esql.core.expression.Order; import org.elasticsearch.xpack.esql.core.expression.ReferenceAttribute; -import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan; -import org.elasticsearch.xpack.esql.core.plan.logical.OrderBy; -import org.elasticsearch.xpack.esql.core.plan.logical.UnaryPlan; import org.elasticsearch.xpack.esql.core.rule.ParameterizedRule; import org.elasticsearch.xpack.esql.core.rule.ParameterizedRuleExecutor; import org.elasticsearch.xpack.esql.optimizer.rules.AddDefaultTopN; @@ -68,7 +65,10 @@ import org.elasticsearch.xpack.esql.optimizer.rules.SubstituteSurrogates; import org.elasticsearch.xpack.esql.optimizer.rules.TranslateMetricsAggregate; import org.elasticsearch.xpack.esql.plan.logical.Eval; +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; +import org.elasticsearch.xpack.esql.plan.logical.OrderBy; import org.elasticsearch.xpack.esql.plan.logical.Project; +import org.elasticsearch.xpack.esql.plan.logical.UnaryPlan; import org.elasticsearch.xpack.esql.plan.logical.local.LocalRelation; import org.elasticsearch.xpack.esql.plan.logical.local.LocalSupplier; import org.elasticsearch.xpack.esql.type.EsqlDataTypes; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalVerifier.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalVerifier.java index 2387a4a210de3..cd61b4eb8892c 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalVerifier.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalVerifier.java @@ -8,9 +8,9 @@ package org.elasticsearch.xpack.esql.optimizer; import org.elasticsearch.xpack.esql.capabilities.Validatable; -import org.elasticsearch.xpack.esql.core.common.Failures; -import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan; +import org.elasticsearch.xpack.esql.common.Failures; import org.elasticsearch.xpack.esql.optimizer.OptimizerRules.LogicalPlanDependencyCheck; +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; public final class LogicalVerifier { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/OptimizerRules.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/OptimizerRules.java index 4c5d9efb449f7..bff76fb1a706e 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/OptimizerRules.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/OptimizerRules.java @@ -7,15 +7,15 @@ package org.elasticsearch.xpack.esql.optimizer; -import org.elasticsearch.xpack.esql.core.common.Failures; +import org.elasticsearch.xpack.esql.common.Failures; import org.elasticsearch.xpack.esql.core.expression.AttributeSet; import org.elasticsearch.xpack.esql.core.expression.Expressions; import org.elasticsearch.xpack.esql.core.plan.QueryPlan; -import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.plan.logical.Aggregate; import org.elasticsearch.xpack.esql.plan.logical.Enrich; import org.elasticsearch.xpack.esql.plan.logical.EsRelation; import org.elasticsearch.xpack.esql.plan.logical.Eval; +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.plan.logical.MvExpand; import org.elasticsearch.xpack.esql.plan.logical.RegexExtract; import org.elasticsearch.xpack.esql.plan.logical.Row; @@ -36,7 +36,7 @@ import org.elasticsearch.xpack.esql.plan.physical.RowExec; import org.elasticsearch.xpack.esql.plan.physical.ShowExec; -import static org.elasticsearch.xpack.esql.core.common.Failure.fail; +import static org.elasticsearch.xpack.esql.common.Failure.fail; class OptimizerRules { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/PhysicalOptimizerRules.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/PhysicalOptimizerRules.java index 1def5a4133a3f..c669853d3357e 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/PhysicalOptimizerRules.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/PhysicalOptimizerRules.java @@ -8,10 +8,10 @@ package org.elasticsearch.xpack.esql.optimizer; import org.elasticsearch.xpack.esql.core.expression.Expression; -import org.elasticsearch.xpack.esql.core.optimizer.OptimizerRules.TransformDirection; import org.elasticsearch.xpack.esql.core.rule.ParameterizedRule; import org.elasticsearch.xpack.esql.core.rule.Rule; import org.elasticsearch.xpack.esql.core.util.ReflectionUtils; +import org.elasticsearch.xpack.esql.optimizer.rules.OptimizerRules.TransformDirection; import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan; public class PhysicalOptimizerRules { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/PhysicalPlanOptimizer.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/PhysicalPlanOptimizer.java index 70c2a9007408a..e9fd6a713945c 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/PhysicalPlanOptimizer.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/PhysicalPlanOptimizer.java @@ -8,7 +8,7 @@ package org.elasticsearch.xpack.esql.optimizer; import org.elasticsearch.xpack.esql.VerificationException; -import org.elasticsearch.xpack.esql.core.common.Failure; +import org.elasticsearch.xpack.esql.common.Failure; import org.elasticsearch.xpack.esql.core.expression.Alias; import org.elasticsearch.xpack.esql.core.expression.Attribute; import org.elasticsearch.xpack.esql.core.expression.AttributeMap; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/PhysicalVerifier.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/PhysicalVerifier.java index 77c8e7da5d895..7843464650e37 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/PhysicalVerifier.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/PhysicalVerifier.java @@ -7,7 +7,7 @@ package org.elasticsearch.xpack.esql.optimizer; -import org.elasticsearch.xpack.esql.core.common.Failure; +import org.elasticsearch.xpack.esql.common.Failure; import org.elasticsearch.xpack.esql.core.expression.Attribute; import org.elasticsearch.xpack.esql.core.expression.Expressions; import org.elasticsearch.xpack.esql.optimizer.OptimizerRules.PhysicalPlanDependencyCheck; @@ -18,7 +18,7 @@ import java.util.LinkedHashSet; import java.util.Set; -import static org.elasticsearch.xpack.esql.core.common.Failure.fail; +import static org.elasticsearch.xpack.esql.common.Failure.fail; /** Physical plan verifier. */ public final class PhysicalVerifier { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/AddDefaultTopN.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/AddDefaultTopN.java index 28a7ba4bf7084..9208eba740100 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/AddDefaultTopN.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/AddDefaultTopN.java @@ -8,14 +8,14 @@ package org.elasticsearch.xpack.esql.optimizer.rules; import org.elasticsearch.xpack.esql.core.expression.Literal; -import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan; -import org.elasticsearch.xpack.esql.core.plan.logical.OrderBy; -import org.elasticsearch.xpack.esql.core.plan.logical.UnaryPlan; import org.elasticsearch.xpack.esql.core.type.DataType; import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext; import org.elasticsearch.xpack.esql.optimizer.LogicalPlanOptimizer; import org.elasticsearch.xpack.esql.plan.logical.EsRelation; +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; +import org.elasticsearch.xpack.esql.plan.logical.OrderBy; import org.elasticsearch.xpack.esql.plan.logical.TopN; +import org.elasticsearch.xpack.esql.plan.logical.UnaryPlan; /** * This adds an explicit TopN node to a plan that only has an OrderBy right before Lucene. diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/BooleanFunctionEqualsElimination.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/BooleanFunctionEqualsElimination.java index cf62f9219f3c8..1cdc2c02c8469 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/BooleanFunctionEqualsElimination.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/BooleanFunctionEqualsElimination.java @@ -21,11 +21,10 @@ * This rule must always be placed after {@link LiteralsOnTheRight} * since it looks at TRUE/FALSE literals' existence on the right hand-side of the {@link Equals}/{@link NotEquals} expressions. */ -public final class BooleanFunctionEqualsElimination extends - org.elasticsearch.xpack.esql.core.optimizer.OptimizerRules.OptimizerExpressionRule { +public final class BooleanFunctionEqualsElimination extends OptimizerRules.OptimizerExpressionRule { public BooleanFunctionEqualsElimination() { - super(org.elasticsearch.xpack.esql.core.optimizer.OptimizerRules.TransformDirection.UP); + super(OptimizerRules.TransformDirection.UP); } @Override diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/BooleanSimplification.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/BooleanSimplification.java index b01525cc447fc..2a3f7fb9d1244 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/BooleanSimplification.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/BooleanSimplification.java @@ -9,7 +9,7 @@ import org.elasticsearch.xpack.esql.core.expression.Expression; -public final class BooleanSimplification extends org.elasticsearch.xpack.esql.core.optimizer.OptimizerRules.BooleanSimplification { +public final class BooleanSimplification extends OptimizerRules.BooleanSimplification { public BooleanSimplification() { super(); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/CombineDisjunctionsToIn.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/CombineDisjunctionsToIn.java index c34252300350c..2dc2f0e504303 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/CombineDisjunctionsToIn.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/CombineDisjunctionsToIn.java @@ -35,9 +35,9 @@ * This rule does NOT check for type compatibility as that phase has been * already be verified in the analyzer. */ -public final class CombineDisjunctionsToIn extends org.elasticsearch.xpack.esql.core.optimizer.OptimizerRules.OptimizerExpressionRule { +public final class CombineDisjunctionsToIn extends OptimizerRules.OptimizerExpressionRule { public CombineDisjunctionsToIn() { - super(org.elasticsearch.xpack.esql.core.optimizer.OptimizerRules.TransformDirection.UP); + super(OptimizerRules.TransformDirection.UP); } protected In createIn(Expression key, List values, ZoneId zoneId) { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/CombineEvals.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/CombineEvals.java index 40e9836d0afa1..f8210d06e4439 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/CombineEvals.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/CombineEvals.java @@ -7,10 +7,9 @@ package org.elasticsearch.xpack.esql.optimizer.rules; -import org.elasticsearch.xpack.esql.core.optimizer.OptimizerRules; -import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.core.util.CollectionUtils; import org.elasticsearch.xpack.esql.plan.logical.Eval; +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; /** * Combine multiple Evals into one in order to reduce the number of nodes in a plan. diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/CombineProjections.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/CombineProjections.java index 2070139519ea0..3c0ac9056c8c5 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/CombineProjections.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/CombineProjections.java @@ -15,11 +15,10 @@ import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.Expressions; import org.elasticsearch.xpack.esql.core.expression.NamedExpression; -import org.elasticsearch.xpack.esql.core.optimizer.OptimizerRules; -import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan; -import org.elasticsearch.xpack.esql.core.plan.logical.UnaryPlan; import org.elasticsearch.xpack.esql.plan.logical.Aggregate; +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.plan.logical.Project; +import org.elasticsearch.xpack.esql.plan.logical.UnaryPlan; import java.util.ArrayList; import java.util.List; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/ConstantFolding.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/ConstantFolding.java index f2638333c9601..2178013c42148 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/ConstantFolding.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/ConstantFolding.java @@ -9,7 +9,6 @@ import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.Literal; -import org.elasticsearch.xpack.esql.core.optimizer.OptimizerRules; public final class ConstantFolding extends OptimizerRules.OptimizerExpressionRule { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/ConvertStringToByteRef.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/ConvertStringToByteRef.java index 384f56d96de73..a1969df3f898a 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/ConvertStringToByteRef.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/ConvertStringToByteRef.java @@ -10,7 +10,6 @@ import org.apache.lucene.util.BytesRef; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.Literal; -import org.elasticsearch.xpack.esql.core.optimizer.OptimizerRules; import java.util.ArrayList; import java.util.List; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/DuplicateLimitAfterMvExpand.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/DuplicateLimitAfterMvExpand.java index 6b944bf7adf4f..ab1dc407a7a4a 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/DuplicateLimitAfterMvExpand.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/DuplicateLimitAfterMvExpand.java @@ -9,18 +9,17 @@ import org.elasticsearch.xpack.esql.core.expression.AttributeSet; import org.elasticsearch.xpack.esql.core.expression.ReferenceAttribute; -import org.elasticsearch.xpack.esql.core.optimizer.OptimizerRules; -import org.elasticsearch.xpack.esql.core.plan.logical.Filter; -import org.elasticsearch.xpack.esql.core.plan.logical.Limit; -import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan; -import org.elasticsearch.xpack.esql.core.plan.logical.OrderBy; -import org.elasticsearch.xpack.esql.core.plan.logical.UnaryPlan; import org.elasticsearch.xpack.esql.plan.logical.Aggregate; import org.elasticsearch.xpack.esql.plan.logical.Enrich; import org.elasticsearch.xpack.esql.plan.logical.Eval; +import org.elasticsearch.xpack.esql.plan.logical.Filter; +import org.elasticsearch.xpack.esql.plan.logical.Limit; +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.plan.logical.MvExpand; +import org.elasticsearch.xpack.esql.plan.logical.OrderBy; import org.elasticsearch.xpack.esql.plan.logical.Project; import org.elasticsearch.xpack.esql.plan.logical.RegexExtract; +import org.elasticsearch.xpack.esql.plan.logical.UnaryPlan; public final class DuplicateLimitAfterMvExpand extends OptimizerRules.OptimizerRule { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/FoldNull.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/FoldNull.java index 25ad5e3966f21..6e01811b8527c 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/FoldNull.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/FoldNull.java @@ -8,7 +8,6 @@ package org.elasticsearch.xpack.esql.optimizer.rules; import org.elasticsearch.xpack.esql.core.expression.Expression; -import org.elasticsearch.xpack.esql.core.optimizer.OptimizerRules; public class FoldNull extends OptimizerRules.FoldNull { @Override diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/LiteralsOnTheRight.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/LiteralsOnTheRight.java index 528fe65766972..36d39e0ee1c73 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/LiteralsOnTheRight.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/LiteralsOnTheRight.java @@ -9,7 +9,6 @@ import org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.expression.predicate.BinaryOperator; -import org.elasticsearch.xpack.esql.core.optimizer.OptimizerRules; public final class LiteralsOnTheRight extends OptimizerRules.OptimizerExpressionRule> { diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/optimizer/OptimizerRules.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/OptimizerRules.java similarity index 63% rename from x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/optimizer/OptimizerRules.java rename to x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/OptimizerRules.java index ba19a73f91c06..6f6260fd0de27 100644 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/optimizer/OptimizerRules.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/OptimizerRules.java @@ -4,7 +4,7 @@ * 2.0; you may not use this file except in compliance with the Elastic License * 2.0. */ -package org.elasticsearch.xpack.esql.core.optimizer; +package org.elasticsearch.xpack.esql.optimizer.rules; import org.elasticsearch.common.util.set.Sets; import org.elasticsearch.xpack.esql.core.expression.Alias; @@ -12,36 +12,24 @@ import org.elasticsearch.xpack.esql.core.expression.Expressions; import org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.expression.Nullability; -import org.elasticsearch.xpack.esql.core.expression.function.Function; import org.elasticsearch.xpack.esql.core.expression.function.scalar.ScalarFunction; -import org.elasticsearch.xpack.esql.core.expression.function.scalar.SurrogateFunction; import org.elasticsearch.xpack.esql.core.expression.predicate.BinaryPredicate; import org.elasticsearch.xpack.esql.core.expression.predicate.Negatable; import org.elasticsearch.xpack.esql.core.expression.predicate.Predicates; import org.elasticsearch.xpack.esql.core.expression.predicate.logical.And; -import org.elasticsearch.xpack.esql.core.expression.predicate.logical.BinaryLogic; import org.elasticsearch.xpack.esql.core.expression.predicate.logical.Not; import org.elasticsearch.xpack.esql.core.expression.predicate.logical.Or; import org.elasticsearch.xpack.esql.core.expression.predicate.nulls.IsNotNull; import org.elasticsearch.xpack.esql.core.expression.predicate.nulls.IsNull; -import org.elasticsearch.xpack.esql.core.expression.predicate.operator.comparison.BinaryComparison; -import org.elasticsearch.xpack.esql.core.expression.predicate.operator.comparison.Equals; import org.elasticsearch.xpack.esql.core.expression.predicate.operator.comparison.In; -import org.elasticsearch.xpack.esql.core.expression.predicate.operator.comparison.NotEquals; -import org.elasticsearch.xpack.esql.core.plan.logical.Filter; -import org.elasticsearch.xpack.esql.core.plan.logical.Limit; -import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.core.rule.Rule; import org.elasticsearch.xpack.esql.core.type.DataType; import org.elasticsearch.xpack.esql.core.util.ReflectionUtils; +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; -import java.time.ZoneId; -import java.util.ArrayList; -import java.util.LinkedHashMap; import java.util.LinkedHashSet; import java.util.LinkedList; import java.util.List; -import java.util.Map; import java.util.Set; import java.util.function.BiFunction; @@ -56,34 +44,6 @@ import static org.elasticsearch.xpack.esql.core.util.CollectionUtils.combine; public final class OptimizerRules { - - /** - * This rule must always be placed after LiteralsOnTheRight, since it looks at TRUE/FALSE literals' existence - * on the right hand-side of the {@link Equals}/{@link NotEquals} expressions. - */ - public static final class BooleanFunctionEqualsElimination extends OptimizerExpressionRule { - - public BooleanFunctionEqualsElimination() { - super(TransformDirection.UP); - } - - @Override - protected Expression rule(BinaryComparison bc) { - if ((bc instanceof Equals || bc instanceof NotEquals) && bc.left() instanceof Function) { - // for expression "==" or "!=" TRUE/FALSE, return the expression itself or its negated variant - - if (TRUE.equals(bc.right())) { - return bc instanceof Equals ? bc.left() : new Not(bc.left().source(), bc.left()); - } - if (FALSE.equals(bc.right())) { - return bc instanceof Equals ? new Not(bc.left().source(), bc.left()) : bc.left(); - } - } - - return bc; - } - } - public static class BooleanSimplification extends OptimizerExpressionRule { public BooleanSimplification() { @@ -220,178 +180,6 @@ protected Expression maybeSimplifyNegatable(Expression e) { } } - /** - * Combine disjunctions on the same field into an In expression. - * This rule looks for both simple equalities: - * 1. a == 1 OR a == 2 becomes a IN (1, 2) - * and combinations of In - * 2. a == 1 OR a IN (2) becomes a IN (1, 2) - * 3. a IN (1) OR a IN (2) becomes a IN (1, 2) - * - * This rule does NOT check for type compatibility as that phase has been - * already be verified in the analyzer. - */ - public static class CombineDisjunctionsToIn extends OptimizerExpressionRule { - public CombineDisjunctionsToIn() { - super(TransformDirection.UP); - } - - @Override - protected Expression rule(Or or) { - Expression e = or; - // look only at equals and In - List exps = splitOr(e); - - Map> found = new LinkedHashMap<>(); - ZoneId zoneId = null; - List ors = new LinkedList<>(); - - for (Expression exp : exps) { - if (exp instanceof Equals eq) { - // consider only equals against foldables - if (eq.right().foldable()) { - found.computeIfAbsent(eq.left(), k -> new LinkedHashSet<>()).add(eq.right()); - } else { - ors.add(exp); - } - if (zoneId == null) { - zoneId = eq.zoneId(); - } - } else if (exp instanceof In in) { - found.computeIfAbsent(in.value(), k -> new LinkedHashSet<>()).addAll(in.list()); - if (zoneId == null) { - zoneId = in.zoneId(); - } - } else { - ors.add(exp); - } - } - - if (found.isEmpty() == false) { - // combine equals alongside the existing ors - final ZoneId finalZoneId = zoneId; - found.forEach( - (k, v) -> { ors.add(v.size() == 1 ? createEquals(k, v, finalZoneId) : createIn(k, new ArrayList<>(v), finalZoneId)); } - ); - - Expression combineOr = combineOr(ors); - // check the result semantically since the result might different in order - // but be actually the same which can trigger a loop - // e.g. a == 1 OR a == 2 OR null --> null OR a in (1,2) --> literalsOnTheRight --> cycle - if (e.semanticEquals(combineOr) == false) { - e = combineOr; - } - } - - return e; - } - - protected Equals createEquals(Expression k, Set v, ZoneId finalZoneId) { - return new Equals(k.source(), k, v.iterator().next(), finalZoneId); - } - - protected In createIn(Expression key, List values, ZoneId zoneId) { - return new In(key.source(), key, values, zoneId); - } - } - - public static class ReplaceSurrogateFunction extends OptimizerExpressionRule { - - public ReplaceSurrogateFunction() { - super(TransformDirection.DOWN); - } - - @Override - protected Expression rule(Expression e) { - if (e instanceof SurrogateFunction) { - e = ((SurrogateFunction) e).substitute(); - } - return e; - } - } - - public abstract static class PruneFilters extends OptimizerRule { - - @Override - protected LogicalPlan rule(Filter filter) { - Expression condition = filter.condition().transformUp(BinaryLogic.class, PruneFilters::foldBinaryLogic); - - if (condition instanceof Literal) { - if (TRUE.equals(condition)) { - return filter.child(); - } - if (FALSE.equals(condition) || Expressions.isNull(condition)) { - return skipPlan(filter); - } - } - - if (condition.equals(filter.condition()) == false) { - return new Filter(filter.source(), filter.child(), condition); - } - return filter; - } - - protected abstract LogicalPlan skipPlan(Filter filter); - - private static Expression foldBinaryLogic(BinaryLogic binaryLogic) { - if (binaryLogic instanceof Or or) { - boolean nullLeft = Expressions.isNull(or.left()); - boolean nullRight = Expressions.isNull(or.right()); - if (nullLeft && nullRight) { - return new Literal(binaryLogic.source(), null, DataType.NULL); - } - if (nullLeft) { - return or.right(); - } - if (nullRight) { - return or.left(); - } - } - if (binaryLogic instanceof And and) { - if (Expressions.isNull(and.left()) || Expressions.isNull(and.right())) { - return new Literal(binaryLogic.source(), null, DataType.NULL); - } - } - return binaryLogic; - } - } - - // NB: it is important to start replacing casts from the bottom to properly replace aliases - public abstract static class PruneCast extends Rule { - - private final Class castType; - - public PruneCast(Class castType) { - this.castType = castType; - } - - @Override - public final LogicalPlan apply(LogicalPlan plan) { - return rule(plan); - } - - protected final LogicalPlan rule(LogicalPlan plan) { - // eliminate redundant casts - return plan.transformExpressionsUp(castType, this::maybePruneCast); - } - - protected abstract Expression maybePruneCast(C cast); - } - - public abstract static class SkipQueryOnLimitZero extends OptimizerRule { - @Override - protected LogicalPlan rule(Limit limit) { - if (limit.limit().foldable()) { - if (Integer.valueOf(0).equals((limit.limit().fold()))) { - return skipPlan(limit); - } - } - return limit; - } - - protected abstract LogicalPlan skipPlan(Limit limit); - } - public static class FoldNull extends OptimizerExpressionRule { public FoldNull() { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/PartiallyFoldCase.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/PartiallyFoldCase.java index 6b900d91eb061..78435f852982e 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/PartiallyFoldCase.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/PartiallyFoldCase.java @@ -8,10 +8,9 @@ package org.elasticsearch.xpack.esql.optimizer.rules; import org.elasticsearch.xpack.esql.core.expression.Expression; -import org.elasticsearch.xpack.esql.core.optimizer.OptimizerRules; import org.elasticsearch.xpack.esql.expression.function.scalar.conditional.Case; -import static org.elasticsearch.xpack.esql.core.optimizer.OptimizerRules.TransformDirection.DOWN; +import static org.elasticsearch.xpack.esql.optimizer.rules.OptimizerRules.TransformDirection.DOWN; /** * Fold the arms of {@code CASE} statements. diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/PropagateEmptyRelation.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/PropagateEmptyRelation.java index 8a3281dd7df81..c57e490423ce8 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/PropagateEmptyRelation.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/PropagateEmptyRelation.java @@ -13,13 +13,12 @@ import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException; import org.elasticsearch.xpack.esql.core.expression.Alias; import org.elasticsearch.xpack.esql.core.expression.NamedExpression; -import org.elasticsearch.xpack.esql.core.optimizer.OptimizerRules; -import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan; -import org.elasticsearch.xpack.esql.core.plan.logical.UnaryPlan; import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction; import org.elasticsearch.xpack.esql.expression.function.aggregate.Count; import org.elasticsearch.xpack.esql.optimizer.LogicalPlanOptimizer; import org.elasticsearch.xpack.esql.plan.logical.Aggregate; +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; +import org.elasticsearch.xpack.esql.plan.logical.UnaryPlan; import org.elasticsearch.xpack.esql.plan.logical.local.LocalRelation; import org.elasticsearch.xpack.esql.plan.logical.local.LocalSupplier; import org.elasticsearch.xpack.esql.planner.PlannerUtils; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/PropagateEquals.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/PropagateEquals.java index 5f08363abdbaf..8e5d203942c7a 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/PropagateEquals.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/PropagateEquals.java @@ -35,10 +35,10 @@ * When encountering a different Equals, non-containing {@link Range} or {@link BinaryComparison}, the conjunction becomes false. * When encountering a containing {@link Range}, {@link BinaryComparison} or {@link NotEquals}, these get eliminated by the equality. */ -public final class PropagateEquals extends org.elasticsearch.xpack.esql.core.optimizer.OptimizerRules.OptimizerExpressionRule { +public final class PropagateEquals extends OptimizerRules.OptimizerExpressionRule { public PropagateEquals() { - super(org.elasticsearch.xpack.esql.core.optimizer.OptimizerRules.TransformDirection.DOWN); + super(OptimizerRules.TransformDirection.DOWN); } public Expression rule(BinaryLogic e) { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/PropagateEvalFoldables.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/PropagateEvalFoldables.java index 872bff80926d6..9231105c9b663 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/PropagateEvalFoldables.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/PropagateEvalFoldables.java @@ -12,10 +12,10 @@ import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.expression.ReferenceAttribute; -import org.elasticsearch.xpack.esql.core.plan.logical.Filter; -import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.core.rule.Rule; import org.elasticsearch.xpack.esql.plan.logical.Eval; +import org.elasticsearch.xpack.esql.plan.logical.Filter; +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; /** * Replace any reference attribute with its source, if it does not affect the result. diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/PropagateNullable.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/PropagateNullable.java index 73ea21f9c8191..08c560c326e81 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/PropagateNullable.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/PropagateNullable.java @@ -9,7 +9,6 @@ import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.Literal; -import org.elasticsearch.xpack.esql.core.optimizer.OptimizerRules; import org.elasticsearch.xpack.esql.expression.function.scalar.nulls.Coalesce; import java.util.ArrayList; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/PruneColumns.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/PruneColumns.java index 9403e3996ec49..baeabb534aa3c 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/PruneColumns.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/PruneColumns.java @@ -13,12 +13,12 @@ import org.elasticsearch.xpack.esql.core.expression.EmptyAttribute; import org.elasticsearch.xpack.esql.core.expression.Expressions; import org.elasticsearch.xpack.esql.core.expression.NamedExpression; -import org.elasticsearch.xpack.esql.core.plan.logical.Limit; -import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.core.rule.Rule; import org.elasticsearch.xpack.esql.core.util.Holder; import org.elasticsearch.xpack.esql.plan.logical.Aggregate; import org.elasticsearch.xpack.esql.plan.logical.Eval; +import org.elasticsearch.xpack.esql.plan.logical.Limit; +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.plan.logical.Project; import org.elasticsearch.xpack.esql.plan.logical.local.LocalRelation; import org.elasticsearch.xpack.esql.plan.logical.local.LocalSupplier; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/PruneEmptyPlans.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/PruneEmptyPlans.java index 5c9ef44207366..739d59d8b0df6 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/PruneEmptyPlans.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/PruneEmptyPlans.java @@ -7,10 +7,9 @@ package org.elasticsearch.xpack.esql.optimizer.rules; -import org.elasticsearch.xpack.esql.core.optimizer.OptimizerRules; -import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan; -import org.elasticsearch.xpack.esql.core.plan.logical.UnaryPlan; import org.elasticsearch.xpack.esql.optimizer.LogicalPlanOptimizer; +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; +import org.elasticsearch.xpack.esql.plan.logical.UnaryPlan; public final class PruneEmptyPlans extends OptimizerRules.OptimizerRule { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/PruneFilters.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/PruneFilters.java index 72df4261663e5..7e9ff7c5f5f02 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/PruneFilters.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/PruneFilters.java @@ -7,15 +7,60 @@ package org.elasticsearch.xpack.esql.optimizer.rules; -import org.elasticsearch.xpack.esql.core.optimizer.OptimizerRules; -import org.elasticsearch.xpack.esql.core.plan.logical.Filter; -import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.Expressions; +import org.elasticsearch.xpack.esql.core.expression.Literal; +import org.elasticsearch.xpack.esql.core.expression.predicate.logical.And; +import org.elasticsearch.xpack.esql.core.expression.predicate.logical.BinaryLogic; +import org.elasticsearch.xpack.esql.core.expression.predicate.logical.Or; +import org.elasticsearch.xpack.esql.core.type.DataType; import org.elasticsearch.xpack.esql.optimizer.LogicalPlanOptimizer; +import org.elasticsearch.xpack.esql.plan.logical.Filter; +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; -public final class PruneFilters extends OptimizerRules.PruneFilters { +import static org.elasticsearch.xpack.esql.core.expression.Literal.FALSE; +import static org.elasticsearch.xpack.esql.core.expression.Literal.TRUE; +public final class PruneFilters extends OptimizerRules.OptimizerRule { @Override - protected LogicalPlan skipPlan(Filter filter) { - return LogicalPlanOptimizer.skipPlan(filter); + protected LogicalPlan rule(Filter filter) { + Expression condition = filter.condition().transformUp(BinaryLogic.class, PruneFilters::foldBinaryLogic); + + if (condition instanceof Literal) { + if (TRUE.equals(condition)) { + return filter.child(); + } + if (FALSE.equals(condition) || Expressions.isNull(condition)) { + return LogicalPlanOptimizer.skipPlan(filter); + } + } + + if (condition.equals(filter.condition()) == false) { + return new Filter(filter.source(), filter.child(), condition); + } + return filter; } + + private static Expression foldBinaryLogic(BinaryLogic binaryLogic) { + if (binaryLogic instanceof Or or) { + boolean nullLeft = Expressions.isNull(or.left()); + boolean nullRight = Expressions.isNull(or.right()); + if (nullLeft && nullRight) { + return new Literal(binaryLogic.source(), null, DataType.NULL); + } + if (nullLeft) { + return or.right(); + } + if (nullRight) { + return or.left(); + } + } + if (binaryLogic instanceof And and) { + if (Expressions.isNull(and.left()) || Expressions.isNull(and.right())) { + return new Literal(binaryLogic.source(), null, DataType.NULL); + } + } + return binaryLogic; + } + } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/PruneLiteralsInOrderBy.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/PruneLiteralsInOrderBy.java index 591cfe043c00d..1fe67c2c435c2 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/PruneLiteralsInOrderBy.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/PruneLiteralsInOrderBy.java @@ -8,9 +8,8 @@ package org.elasticsearch.xpack.esql.optimizer.rules; import org.elasticsearch.xpack.esql.core.expression.Order; -import org.elasticsearch.xpack.esql.core.optimizer.OptimizerRules; -import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan; -import org.elasticsearch.xpack.esql.core.plan.logical.OrderBy; +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; +import org.elasticsearch.xpack.esql.plan.logical.OrderBy; import java.util.ArrayList; import java.util.List; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/PruneOrderByBeforeStats.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/PruneOrderByBeforeStats.java index 690bc92b1c338..f2ef524f2c91e 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/PruneOrderByBeforeStats.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/PruneOrderByBeforeStats.java @@ -7,16 +7,15 @@ package org.elasticsearch.xpack.esql.optimizer.rules; -import org.elasticsearch.xpack.esql.core.optimizer.OptimizerRules; -import org.elasticsearch.xpack.esql.core.plan.logical.Filter; -import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan; -import org.elasticsearch.xpack.esql.core.plan.logical.OrderBy; -import org.elasticsearch.xpack.esql.core.plan.logical.UnaryPlan; import org.elasticsearch.xpack.esql.plan.logical.Aggregate; import org.elasticsearch.xpack.esql.plan.logical.Enrich; import org.elasticsearch.xpack.esql.plan.logical.Eval; +import org.elasticsearch.xpack.esql.plan.logical.Filter; +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; +import org.elasticsearch.xpack.esql.plan.logical.OrderBy; import org.elasticsearch.xpack.esql.plan.logical.Project; import org.elasticsearch.xpack.esql.plan.logical.RegexExtract; +import org.elasticsearch.xpack.esql.plan.logical.UnaryPlan; public final class PruneOrderByBeforeStats extends OptimizerRules.OptimizerRule { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/PruneRedundantSortClauses.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/PruneRedundantSortClauses.java index 3a9421ee7f159..dc68ae5981429 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/PruneRedundantSortClauses.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/PruneRedundantSortClauses.java @@ -9,9 +9,8 @@ import org.elasticsearch.xpack.esql.core.expression.ExpressionSet; import org.elasticsearch.xpack.esql.core.expression.Order; -import org.elasticsearch.xpack.esql.core.optimizer.OptimizerRules; -import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan; -import org.elasticsearch.xpack.esql.core.plan.logical.OrderBy; +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; +import org.elasticsearch.xpack.esql.plan.logical.OrderBy; import java.util.ArrayList; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/PushDownAndCombineFilters.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/PushDownAndCombineFilters.java index 647c5c3730157..48013e113fe43 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/PushDownAndCombineFilters.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/PushDownAndCombineFilters.java @@ -12,18 +12,17 @@ import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.Expressions; import org.elasticsearch.xpack.esql.core.expression.predicate.Predicates; -import org.elasticsearch.xpack.esql.core.optimizer.OptimizerRules; -import org.elasticsearch.xpack.esql.core.plan.logical.Filter; -import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan; -import org.elasticsearch.xpack.esql.core.plan.logical.OrderBy; -import org.elasticsearch.xpack.esql.core.plan.logical.UnaryPlan; import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction; import org.elasticsearch.xpack.esql.optimizer.LogicalPlanOptimizer; import org.elasticsearch.xpack.esql.plan.logical.Aggregate; import org.elasticsearch.xpack.esql.plan.logical.Enrich; import org.elasticsearch.xpack.esql.plan.logical.Eval; +import org.elasticsearch.xpack.esql.plan.logical.Filter; +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; +import org.elasticsearch.xpack.esql.plan.logical.OrderBy; import org.elasticsearch.xpack.esql.plan.logical.Project; import org.elasticsearch.xpack.esql.plan.logical.RegexExtract; +import org.elasticsearch.xpack.esql.plan.logical.UnaryPlan; import java.util.ArrayList; import java.util.List; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/PushDownAndCombineLimits.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/PushDownAndCombineLimits.java index 46fb654d03760..62ecf9ccd09be 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/PushDownAndCombineLimits.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/PushDownAndCombineLimits.java @@ -8,16 +8,15 @@ package org.elasticsearch.xpack.esql.optimizer.rules; import org.elasticsearch.xpack.esql.core.expression.Literal; -import org.elasticsearch.xpack.esql.core.optimizer.OptimizerRules; -import org.elasticsearch.xpack.esql.core.plan.logical.Limit; -import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan; -import org.elasticsearch.xpack.esql.core.plan.logical.UnaryPlan; import org.elasticsearch.xpack.esql.plan.logical.Aggregate; import org.elasticsearch.xpack.esql.plan.logical.Enrich; import org.elasticsearch.xpack.esql.plan.logical.Eval; +import org.elasticsearch.xpack.esql.plan.logical.Limit; +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.plan.logical.MvExpand; import org.elasticsearch.xpack.esql.plan.logical.Project; import org.elasticsearch.xpack.esql.plan.logical.RegexExtract; +import org.elasticsearch.xpack.esql.plan.logical.UnaryPlan; import org.elasticsearch.xpack.esql.plan.logical.join.Join; import org.elasticsearch.xpack.esql.plan.logical.join.JoinType; import org.elasticsearch.xpack.esql.plan.logical.local.LocalRelation; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/PushDownAndCombineOrderBy.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/PushDownAndCombineOrderBy.java index f01616953427d..286695abda25b 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/PushDownAndCombineOrderBy.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/PushDownAndCombineOrderBy.java @@ -7,10 +7,9 @@ package org.elasticsearch.xpack.esql.optimizer.rules; -import org.elasticsearch.xpack.esql.core.optimizer.OptimizerRules; -import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan; -import org.elasticsearch.xpack.esql.core.plan.logical.OrderBy; import org.elasticsearch.xpack.esql.optimizer.LogicalPlanOptimizer; +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; +import org.elasticsearch.xpack.esql.plan.logical.OrderBy; import org.elasticsearch.xpack.esql.plan.logical.Project; public final class PushDownAndCombineOrderBy extends OptimizerRules.OptimizerRule { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/PushDownEnrich.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/PushDownEnrich.java index f6a0154108f2d..7185f63964c34 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/PushDownEnrich.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/PushDownEnrich.java @@ -7,10 +7,9 @@ package org.elasticsearch.xpack.esql.optimizer.rules; -import org.elasticsearch.xpack.esql.core.optimizer.OptimizerRules; -import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.optimizer.LogicalPlanOptimizer; import org.elasticsearch.xpack.esql.plan.logical.Enrich; +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; import static org.elasticsearch.xpack.esql.core.expression.Expressions.asAttributes; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/PushDownEval.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/PushDownEval.java index b936e5569c950..92c25a60bba77 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/PushDownEval.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/PushDownEval.java @@ -7,10 +7,9 @@ package org.elasticsearch.xpack.esql.optimizer.rules; -import org.elasticsearch.xpack.esql.core.optimizer.OptimizerRules; -import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.optimizer.LogicalPlanOptimizer; import org.elasticsearch.xpack.esql.plan.logical.Eval; +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; import static org.elasticsearch.xpack.esql.core.expression.Expressions.asAttributes; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/PushDownRegexExtract.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/PushDownRegexExtract.java index f247d0a631b29..d24a61f89dd7f 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/PushDownRegexExtract.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/PushDownRegexExtract.java @@ -7,9 +7,8 @@ package org.elasticsearch.xpack.esql.optimizer.rules; -import org.elasticsearch.xpack.esql.core.optimizer.OptimizerRules; -import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.optimizer.LogicalPlanOptimizer; +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.plan.logical.RegexExtract; public final class PushDownRegexExtract extends OptimizerRules.OptimizerRule { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/RemoveStatsOverride.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/RemoveStatsOverride.java index cbcde663f8b14..5592a04e2f813 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/RemoveStatsOverride.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/RemoveStatsOverride.java @@ -8,11 +8,11 @@ package org.elasticsearch.xpack.esql.optimizer.rules; import org.elasticsearch.common.util.set.Sets; -import org.elasticsearch.xpack.esql.core.analyzer.AnalyzerRules; +import org.elasticsearch.xpack.esql.analysis.AnalyzerRules; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.Expressions; -import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.plan.logical.Aggregate; +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; import java.util.ArrayList; import java.util.List; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/ReplaceAliasingEvalWithProject.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/ReplaceAliasingEvalWithProject.java index 2bbfeaac965ef..34b75cd89f68c 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/ReplaceAliasingEvalWithProject.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/ReplaceAliasingEvalWithProject.java @@ -12,11 +12,11 @@ import org.elasticsearch.xpack.esql.core.expression.AttributeMap; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.NamedExpression; -import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.core.rule.Rule; import org.elasticsearch.xpack.esql.core.util.Holder; import org.elasticsearch.xpack.esql.plan.logical.Aggregate; import org.elasticsearch.xpack.esql.plan.logical.Eval; +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.plan.logical.Project; import java.util.ArrayList; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/ReplaceLimitAndSortAsTopN.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/ReplaceLimitAndSortAsTopN.java index ec912735f8451..6394d11bb68c8 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/ReplaceLimitAndSortAsTopN.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/ReplaceLimitAndSortAsTopN.java @@ -7,10 +7,9 @@ package org.elasticsearch.xpack.esql.optimizer.rules; -import org.elasticsearch.xpack.esql.core.optimizer.OptimizerRules; -import org.elasticsearch.xpack.esql.core.plan.logical.Limit; -import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan; -import org.elasticsearch.xpack.esql.core.plan.logical.OrderBy; +import org.elasticsearch.xpack.esql.plan.logical.Limit; +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; +import org.elasticsearch.xpack.esql.plan.logical.OrderBy; import org.elasticsearch.xpack.esql.plan.logical.TopN; public final class ReplaceLimitAndSortAsTopN extends OptimizerRules.OptimizerRule { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/ReplaceLookupWithJoin.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/ReplaceLookupWithJoin.java index f6c8f4a59a70c..f258ea97bfa33 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/ReplaceLookupWithJoin.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/ReplaceLookupWithJoin.java @@ -7,8 +7,7 @@ package org.elasticsearch.xpack.esql.optimizer.rules; -import org.elasticsearch.xpack.esql.core.optimizer.OptimizerRules; -import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan; +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.plan.logical.Lookup; import org.elasticsearch.xpack.esql.plan.logical.join.Join; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/ReplaceOrderByExpressionWithEval.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/ReplaceOrderByExpressionWithEval.java index 476da7476f7fb..02fc98428f14a 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/ReplaceOrderByExpressionWithEval.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/ReplaceOrderByExpressionWithEval.java @@ -10,10 +10,9 @@ import org.elasticsearch.xpack.esql.core.expression.Alias; import org.elasticsearch.xpack.esql.core.expression.Attribute; import org.elasticsearch.xpack.esql.core.expression.Order; -import org.elasticsearch.xpack.esql.core.optimizer.OptimizerRules; -import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan; -import org.elasticsearch.xpack.esql.core.plan.logical.OrderBy; import org.elasticsearch.xpack.esql.plan.logical.Eval; +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; +import org.elasticsearch.xpack.esql.plan.logical.OrderBy; import org.elasticsearch.xpack.esql.plan.logical.Project; import java.util.ArrayList; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/ReplaceRegexMatch.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/ReplaceRegexMatch.java index 5cba7349debfd..cc18940e68924 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/ReplaceRegexMatch.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/ReplaceRegexMatch.java @@ -15,11 +15,10 @@ import org.elasticsearch.xpack.esql.core.type.DataType; import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.Equals; -public final class ReplaceRegexMatch extends org.elasticsearch.xpack.esql.core.optimizer.OptimizerRules.OptimizerExpressionRule< - RegexMatch> { +public final class ReplaceRegexMatch extends OptimizerRules.OptimizerExpressionRule> { public ReplaceRegexMatch() { - super(org.elasticsearch.xpack.esql.core.optimizer.OptimizerRules.TransformDirection.DOWN); + super(OptimizerRules.TransformDirection.DOWN); } @Override diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/ReplaceStatsAggExpressionWithEval.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/ReplaceStatsAggExpressionWithEval.java index 012d6e307df6c..31b543cd115df 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/ReplaceStatsAggExpressionWithEval.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/ReplaceStatsAggExpressionWithEval.java @@ -12,14 +12,13 @@ import org.elasticsearch.xpack.esql.core.expression.AttributeMap; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.NamedExpression; -import org.elasticsearch.xpack.esql.core.optimizer.OptimizerRules; -import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.util.CollectionUtils; import org.elasticsearch.xpack.esql.core.util.Holder; import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction; import org.elasticsearch.xpack.esql.plan.logical.Aggregate; import org.elasticsearch.xpack.esql.plan.logical.Eval; +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.plan.logical.Project; import java.util.ArrayList; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/ReplaceStatsNestedExpressionWithEval.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/ReplaceStatsNestedExpressionWithEval.java index 99b0c8047f2ba..0979b745a6607 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/ReplaceStatsNestedExpressionWithEval.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/ReplaceStatsNestedExpressionWithEval.java @@ -11,13 +11,12 @@ import org.elasticsearch.xpack.esql.core.expression.Attribute; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.NamedExpression; -import org.elasticsearch.xpack.esql.core.optimizer.OptimizerRules; -import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.core.util.Holder; import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction; import org.elasticsearch.xpack.esql.expression.function.grouping.GroupingFunction; import org.elasticsearch.xpack.esql.plan.logical.Aggregate; import org.elasticsearch.xpack.esql.plan.logical.Eval; +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; import java.util.ArrayList; import java.util.HashMap; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/ReplaceTrivialTypeConversions.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/ReplaceTrivialTypeConversions.java index 2763c71c4bcb6..dc877a99010f8 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/ReplaceTrivialTypeConversions.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/ReplaceTrivialTypeConversions.java @@ -9,10 +9,9 @@ import org.elasticsearch.xpack.esql.core.expression.FieldAttribute; import org.elasticsearch.xpack.esql.core.expression.ReferenceAttribute; -import org.elasticsearch.xpack.esql.core.optimizer.OptimizerRules; -import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.expression.function.scalar.convert.AbstractConvertFunction; import org.elasticsearch.xpack.esql.plan.logical.Eval; +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; /** * Replace type converting eval with aliasing eval when type change does not occur. diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/SetAsOptimized.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/SetAsOptimized.java index 168270b68db2d..89d2e7613d2c7 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/SetAsOptimized.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/SetAsOptimized.java @@ -7,8 +7,8 @@ package org.elasticsearch.xpack.esql.optimizer.rules; -import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.core.rule.Rule; +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; public final class SetAsOptimized extends Rule { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/SimplifyComparisonsArithmetics.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/SimplifyComparisonsArithmetics.java index 151d11fa575ae..4ef069ea16d04 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/SimplifyComparisonsArithmetics.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/SimplifyComparisonsArithmetics.java @@ -32,12 +32,11 @@ /** * Simplifies arithmetic expressions with BinaryComparisons and fixed point fields, such as: (int + 2) / 3 > 4 => int > 10 */ -public final class SimplifyComparisonsArithmetics extends - org.elasticsearch.xpack.esql.core.optimizer.OptimizerRules.OptimizerExpressionRule { +public final class SimplifyComparisonsArithmetics extends OptimizerRules.OptimizerExpressionRule { BiFunction typesCompatible; public SimplifyComparisonsArithmetics(BiFunction typesCompatible) { - super(org.elasticsearch.xpack.esql.core.optimizer.OptimizerRules.TransformDirection.UP); + super(OptimizerRules.TransformDirection.UP); this.typesCompatible = typesCompatible; } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/SkipQueryOnEmptyMappings.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/SkipQueryOnEmptyMappings.java index 7ec215db65626..99efacd4ea39a 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/SkipQueryOnEmptyMappings.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/SkipQueryOnEmptyMappings.java @@ -7,9 +7,8 @@ package org.elasticsearch.xpack.esql.optimizer.rules; -import org.elasticsearch.xpack.esql.core.optimizer.OptimizerRules; -import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.plan.logical.EsRelation; +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.plan.logical.local.LocalRelation; import org.elasticsearch.xpack.esql.plan.logical.local.LocalSupplier; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/SkipQueryOnLimitZero.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/SkipQueryOnLimitZero.java index 7cb4f2926045d..199520d648a26 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/SkipQueryOnLimitZero.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/SkipQueryOnLimitZero.java @@ -7,15 +7,18 @@ package org.elasticsearch.xpack.esql.optimizer.rules; -import org.elasticsearch.xpack.esql.core.optimizer.OptimizerRules; -import org.elasticsearch.xpack.esql.core.plan.logical.Limit; -import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.optimizer.LogicalPlanOptimizer; +import org.elasticsearch.xpack.esql.plan.logical.Limit; +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; -public final class SkipQueryOnLimitZero extends OptimizerRules.SkipQueryOnLimitZero { - +public final class SkipQueryOnLimitZero extends OptimizerRules.OptimizerRule { @Override - protected LogicalPlan skipPlan(Limit limit) { - return LogicalPlanOptimizer.skipPlan(limit); + protected LogicalPlan rule(Limit limit) { + if (limit.limit().foldable()) { + if (Integer.valueOf(0).equals((limit.limit().fold()))) { + return LogicalPlanOptimizer.skipPlan(limit); + } + } + return limit; } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/SplitInWithFoldableValue.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/SplitInWithFoldableValue.java index c762f396a6f43..1d4e90fe0d5ca 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/SplitInWithFoldableValue.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/SplitInWithFoldableValue.java @@ -10,7 +10,6 @@ import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.Expressions; import org.elasticsearch.xpack.esql.core.expression.predicate.logical.Or; -import org.elasticsearch.xpack.esql.core.optimizer.OptimizerRules; import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.In; import java.util.ArrayList; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/SubstituteSpatialSurrogates.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/SubstituteSpatialSurrogates.java index c5293785bf1ba..e6501452eeb65 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/SubstituteSpatialSurrogates.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/SubstituteSpatialSurrogates.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.esql.optimizer.rules; -import org.elasticsearch.xpack.esql.core.optimizer.OptimizerRules; import org.elasticsearch.xpack.esql.expression.function.scalar.spatial.SpatialRelatesFunction; /** diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/SubstituteSurrogates.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/SubstituteSurrogates.java index fa4049b0e5a3a..2307f6324e942 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/SubstituteSurrogates.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/SubstituteSurrogates.java @@ -15,13 +15,12 @@ import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.Expressions; import org.elasticsearch.xpack.esql.core.expression.NamedExpression; -import org.elasticsearch.xpack.esql.core.optimizer.OptimizerRules; -import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.expression.SurrogateExpression; import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction; import org.elasticsearch.xpack.esql.expression.function.aggregate.Rate; import org.elasticsearch.xpack.esql.plan.logical.Aggregate; import org.elasticsearch.xpack.esql.plan.logical.Eval; +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.plan.logical.Project; import org.elasticsearch.xpack.esql.plan.logical.local.LocalRelation; import org.elasticsearch.xpack.esql.plan.logical.local.LocalSupplier; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/TranslateMetricsAggregate.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/TranslateMetricsAggregate.java index 88486bcb864dc..10c7a7325debc 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/TranslateMetricsAggregate.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/TranslateMetricsAggregate.java @@ -16,8 +16,6 @@ import org.elasticsearch.xpack.esql.core.expression.Expressions; import org.elasticsearch.xpack.esql.core.expression.MetadataAttribute; import org.elasticsearch.xpack.esql.core.expression.NamedExpression; -import org.elasticsearch.xpack.esql.core.optimizer.OptimizerRules; -import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.core.util.Holder; import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction; import org.elasticsearch.xpack.esql.expression.function.aggregate.FromPartial; @@ -27,6 +25,7 @@ import org.elasticsearch.xpack.esql.expression.function.grouping.Bucket; import org.elasticsearch.xpack.esql.plan.logical.Aggregate; import org.elasticsearch.xpack.esql.plan.logical.EsRelation; +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; import java.util.ArrayList; import java.util.HashMap; @@ -150,7 +149,7 @@ LogicalPlan translate(Aggregate metrics) { if (changed.get()) { secondPassAggs.add(new Alias(alias.source(), alias.name(), null, outerAgg, agg.id())); } else { - var toPartial = new Alias(agg.source(), alias.name(), new ToPartial(agg.source(), af)); + var toPartial = new Alias(agg.source(), alias.name(), new ToPartial(agg.source(), af.field(), af)); var fromPartial = new FromPartial(agg.source(), toPartial.toAttribute(), af); firstPassAggs.add(toPartial); secondPassAggs.add(new Alias(alias.source(), alias.name(), null, fromPartial, alias.id())); @@ -218,7 +217,7 @@ private static Aggregate toStandardAggregate(Aggregate metrics) { final LogicalPlan child = metrics.child().transformDown(EsRelation.class, r -> { var attributes = new ArrayList<>(new AttributeSet(metrics.inputSet())); attributes.removeIf(a -> a.name().equals(MetadataAttribute.TSID_FIELD)); - if (attributes.stream().noneMatch(a -> a.name().equals(MetadataAttribute.TIMESTAMP_FIELD)) == false) { + if (attributes.stream().noneMatch(a -> a.name().equals(MetadataAttribute.TIMESTAMP_FIELD))) { attributes.removeIf(a -> a.name().equals(MetadataAttribute.TIMESTAMP_FIELD)); } return new EsRelation(r.source(), r.index(), new ArrayList<>(attributes), IndexMode.STANDARD); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlParser.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlParser.java index ddf6031445f7f..ebbcfa3b2863b 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlParser.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/EsqlParser.java @@ -20,7 +20,7 @@ import org.elasticsearch.logging.LogManager; import org.elasticsearch.logging.Logger; import org.elasticsearch.xpack.esql.core.parser.CaseChangingCharStream; -import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan; +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; import java.util.BitSet; import java.util.function.BiFunction; @@ -51,7 +51,7 @@ private T invokeParser( BiFunction result ) { try { - EsqlBaseLexer lexer = new EsqlBaseLexer(new CaseChangingCharStream(CharStreams.fromString(query), false)); + EsqlBaseLexer lexer = new EsqlBaseLexer(new CaseChangingCharStream(CharStreams.fromString(query))); lexer.removeErrorListeners(); lexer.addErrorListener(ERROR_LISTENER); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/LogicalPlanBuilder.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/LogicalPlanBuilder.java index fee51c40a2525..d1e0bdac0bf2f 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/LogicalPlanBuilder.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/parser/LogicalPlanBuilder.java @@ -16,7 +16,7 @@ import org.elasticsearch.dissect.DissectParser; import org.elasticsearch.index.IndexMode; import org.elasticsearch.xpack.esql.VerificationException; -import org.elasticsearch.xpack.esql.core.common.Failure; +import org.elasticsearch.xpack.esql.common.Failure; import org.elasticsearch.xpack.esql.core.expression.Alias; import org.elasticsearch.xpack.esql.core.expression.Attribute; import org.elasticsearch.xpack.esql.core.expression.EmptyAttribute; @@ -31,10 +31,6 @@ import org.elasticsearch.xpack.esql.core.expression.UnresolvedStar; import org.elasticsearch.xpack.esql.core.parser.ParserUtils; import org.elasticsearch.xpack.esql.core.plan.TableIdentifier; -import org.elasticsearch.xpack.esql.core.plan.logical.Filter; -import org.elasticsearch.xpack.esql.core.plan.logical.Limit; -import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan; -import org.elasticsearch.xpack.esql.core.plan.logical.OrderBy; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; import org.elasticsearch.xpack.esql.core.util.Holder; @@ -49,11 +45,15 @@ import org.elasticsearch.xpack.esql.plan.logical.EsqlUnresolvedRelation; import org.elasticsearch.xpack.esql.plan.logical.Eval; import org.elasticsearch.xpack.esql.plan.logical.Explain; +import org.elasticsearch.xpack.esql.plan.logical.Filter; import org.elasticsearch.xpack.esql.plan.logical.Grok; import org.elasticsearch.xpack.esql.plan.logical.InlineStats; import org.elasticsearch.xpack.esql.plan.logical.Keep; +import org.elasticsearch.xpack.esql.plan.logical.Limit; +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.plan.logical.Lookup; import org.elasticsearch.xpack.esql.plan.logical.MvExpand; +import org.elasticsearch.xpack.esql.plan.logical.OrderBy; import org.elasticsearch.xpack.esql.plan.logical.Rename; import org.elasticsearch.xpack.esql.plan.logical.Row; import org.elasticsearch.xpack.esql.plan.logical.meta.MetaFunctions; @@ -146,12 +146,30 @@ public PlanFactory visitEvalCommand(EsqlBaseParser.EvalCommandContext ctx) { @Override public PlanFactory visitGrokCommand(EsqlBaseParser.GrokCommandContext ctx) { return p -> { + Source source = source(ctx); String pattern = visitString(ctx.string()).fold().toString(); - Grok result = new Grok(source(ctx), p, expression(ctx.primaryExpression()), Grok.pattern(source(ctx), pattern)); + Grok.Parser grokParser = Grok.pattern(source, pattern); + validateGrokPattern(source, grokParser, pattern); + Grok result = new Grok(source(ctx), p, expression(ctx.primaryExpression()), grokParser); return result; }; } + private void validateGrokPattern(Source source, Grok.Parser grokParser, String pattern) { + Map definedAttributes = new HashMap<>(); + for (Attribute field : grokParser.extractedFields()) { + String name = field.name(); + DataType type = field.dataType(); + DataType prev = definedAttributes.put(name, type); + if (prev != null) { + throw new ParsingException( + source, + "Invalid GROK pattern [" + pattern + "]: the attribute [" + name + "] is defined multiple times with different types" + ); + } + } + } + @Override public PlanFactory visitDissectCommand(EsqlBaseParser.DissectCommandContext ctx) { return p -> { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Aggregate.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Aggregate.java index bc7282857dbbe..5ab483e60d7b0 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Aggregate.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Aggregate.java @@ -14,8 +14,6 @@ import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.Expressions; import org.elasticsearch.xpack.esql.core.expression.NamedExpression; -import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan; -import org.elasticsearch.xpack.esql.core.plan.logical.UnaryPlan; import org.elasticsearch.xpack.esql.core.tree.NodeInfo; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/plan/logical/BinaryPlan.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/BinaryPlan.java similarity index 95% rename from x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/plan/logical/BinaryPlan.java rename to x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/BinaryPlan.java index 051c3d7946b4b..579b67eb891ac 100644 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/plan/logical/BinaryPlan.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/BinaryPlan.java @@ -4,7 +4,7 @@ * 2.0; you may not use this file except in compliance with the Elastic License * 2.0. */ -package org.elasticsearch.xpack.esql.core.plan.logical; +package org.elasticsearch.xpack.esql.plan.logical; import org.elasticsearch.xpack.esql.core.tree.Source; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Dissect.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Dissect.java index 1307d1870bba4..c0c564b1b36eb 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Dissect.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Dissect.java @@ -10,8 +10,6 @@ import org.elasticsearch.dissect.DissectParser; import org.elasticsearch.xpack.esql.core.expression.Attribute; import org.elasticsearch.xpack.esql.core.expression.Expression; -import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan; -import org.elasticsearch.xpack.esql.core.plan.logical.UnaryPlan; import org.elasticsearch.xpack.esql.core.tree.NodeInfo; import org.elasticsearch.xpack.esql.core.tree.Source; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Drop.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Drop.java index 2946287ae21f0..d1c5d70018d91 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Drop.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Drop.java @@ -9,8 +9,6 @@ import org.elasticsearch.xpack.esql.core.capabilities.Resolvables; import org.elasticsearch.xpack.esql.core.expression.NamedExpression; -import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan; -import org.elasticsearch.xpack.esql.core.plan.logical.UnaryPlan; import org.elasticsearch.xpack.esql.core.tree.NodeInfo; import org.elasticsearch.xpack.esql.core.tree.Source; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Enrich.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Enrich.java index f418ab5da1c9d..a4d553eae4749 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Enrich.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Enrich.java @@ -14,8 +14,6 @@ import org.elasticsearch.xpack.esql.core.expression.EmptyAttribute; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.NamedExpression; -import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan; -import org.elasticsearch.xpack.esql.core.plan.logical.UnaryPlan; import org.elasticsearch.xpack.esql.core.tree.NodeInfo; import org.elasticsearch.xpack.esql.core.tree.Source; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/EsRelation.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/EsRelation.java index 08916c14e91bf..382838a5968cc 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/EsRelation.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/EsRelation.java @@ -10,7 +10,6 @@ import org.elasticsearch.xpack.esql.core.expression.Attribute; import org.elasticsearch.xpack.esql.core.expression.FieldAttribute; import org.elasticsearch.xpack.esql.core.index.EsIndex; -import org.elasticsearch.xpack.esql.core.plan.logical.LeafPlan; import org.elasticsearch.xpack.esql.core.tree.NodeInfo; import org.elasticsearch.xpack.esql.core.tree.NodeUtils; import org.elasticsearch.xpack.esql.core.tree.Source; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/EsqlAggregate.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/EsqlAggregate.java index 7f16ecd24dc1a..cc72823507f02 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/EsqlAggregate.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/EsqlAggregate.java @@ -11,7 +11,6 @@ import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.Expressions; import org.elasticsearch.xpack.esql.core.expression.NamedExpression; -import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.core.tree.NodeInfo; import org.elasticsearch.xpack.esql.core.tree.Source; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Eval.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Eval.java index bfe11c3d33d87..20117a873c143 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Eval.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Eval.java @@ -10,8 +10,6 @@ import org.elasticsearch.xpack.esql.core.capabilities.Resolvables; import org.elasticsearch.xpack.esql.core.expression.Alias; import org.elasticsearch.xpack.esql.core.expression.Attribute; -import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan; -import org.elasticsearch.xpack.esql.core.plan.logical.UnaryPlan; import org.elasticsearch.xpack.esql.core.tree.NodeInfo; import org.elasticsearch.xpack.esql.core.tree.Source; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Explain.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Explain.java index 86f3e0bdf349a..8d2640a43f38c 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Explain.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Explain.java @@ -9,8 +9,6 @@ import org.elasticsearch.xpack.esql.core.expression.Attribute; import org.elasticsearch.xpack.esql.core.expression.ReferenceAttribute; -import org.elasticsearch.xpack.esql.core.plan.logical.LeafPlan; -import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.core.tree.NodeInfo; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/plan/logical/Filter.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Filter.java similarity index 97% rename from x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/plan/logical/Filter.java rename to x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Filter.java index a09ffb3e07c96..46fafe57e7d26 100644 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/plan/logical/Filter.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Filter.java @@ -4,7 +4,7 @@ * 2.0; you may not use this file except in compliance with the Elastic License * 2.0. */ -package org.elasticsearch.xpack.esql.core.plan.logical; +package org.elasticsearch.xpack.esql.plan.logical; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.tree.NodeInfo; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Grok.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Grok.java index 0c1c400f3ab4d..963fd318f814c 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Grok.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Grok.java @@ -15,8 +15,6 @@ import org.elasticsearch.xpack.esql.core.expression.Attribute; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.ReferenceAttribute; -import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan; -import org.elasticsearch.xpack.esql.core.plan.logical.UnaryPlan; import org.elasticsearch.xpack.esql.core.tree.NodeInfo; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; @@ -32,7 +30,7 @@ public class Grok extends RegexExtract { public record Parser(String pattern, org.elasticsearch.grok.Grok grok) { - private List extractedFields() { + public List extractedFields() { return grok.captureConfig() .stream() .sorted(Comparator.comparing(GrokCaptureConfig::name)) diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/InlineStats.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/InlineStats.java index 4e7dc70904189..46ec56223384c 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/InlineStats.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/InlineStats.java @@ -12,8 +12,6 @@ import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.Expressions; import org.elasticsearch.xpack.esql.core.expression.NamedExpression; -import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan; -import org.elasticsearch.xpack.esql.core.plan.logical.UnaryPlan; import org.elasticsearch.xpack.esql.core.tree.NodeInfo; import org.elasticsearch.xpack.esql.core.tree.Source; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Keep.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Keep.java index a4e733437e80f..c1c8c9aff5ca6 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Keep.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Keep.java @@ -8,7 +8,6 @@ package org.elasticsearch.xpack.esql.plan.logical; import org.elasticsearch.xpack.esql.core.expression.NamedExpression; -import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.core.tree.NodeInfo; import org.elasticsearch.xpack.esql.core.tree.Source; diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/plan/logical/LeafPlan.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/LeafPlan.java similarity index 92% rename from x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/plan/logical/LeafPlan.java rename to x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/LeafPlan.java index 4def8356b316a..d21b61a81cd9e 100644 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/plan/logical/LeafPlan.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/LeafPlan.java @@ -4,7 +4,7 @@ * 2.0; you may not use this file except in compliance with the Elastic License * 2.0. */ -package org.elasticsearch.xpack.esql.core.plan.logical; +package org.elasticsearch.xpack.esql.plan.logical; import org.elasticsearch.xpack.esql.core.tree.Source; diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/plan/logical/Limit.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Limit.java similarity index 96% rename from x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/plan/logical/Limit.java rename to x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Limit.java index 610572f1e73ed..df5e1cf23275c 100644 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/plan/logical/Limit.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Limit.java @@ -4,7 +4,7 @@ * 2.0; you may not use this file except in compliance with the Elastic License * 2.0. */ -package org.elasticsearch.xpack.esql.core.plan.logical; +package org.elasticsearch.xpack.esql.plan.logical; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.tree.NodeInfo; diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/plan/logical/LogicalPlan.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/LogicalPlan.java similarity index 97% rename from x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/plan/logical/LogicalPlan.java rename to x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/LogicalPlan.java index 56e09b4e1189a..0397183c6a6c3 100644 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/plan/logical/LogicalPlan.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/LogicalPlan.java @@ -4,7 +4,7 @@ * 2.0; you may not use this file except in compliance with the Elastic License * 2.0. */ -package org.elasticsearch.xpack.esql.core.plan.logical; +package org.elasticsearch.xpack.esql.plan.logical; import org.elasticsearch.xpack.esql.core.capabilities.Resolvable; import org.elasticsearch.xpack.esql.core.capabilities.Resolvables; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Lookup.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Lookup.java index f28a1d11a5990..6893935f20b5b 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Lookup.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Lookup.java @@ -11,8 +11,6 @@ import org.elasticsearch.xpack.esql.core.capabilities.Resolvables; import org.elasticsearch.xpack.esql.core.expression.Attribute; import org.elasticsearch.xpack.esql.core.expression.Expression; -import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan; -import org.elasticsearch.xpack.esql.core.plan.logical.UnaryPlan; import org.elasticsearch.xpack.esql.core.tree.NodeInfo; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/MvExpand.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/MvExpand.java index 869d8d7dc3a26..5e9dca26a6863 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/MvExpand.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/MvExpand.java @@ -9,8 +9,6 @@ import org.elasticsearch.xpack.esql.core.expression.Attribute; import org.elasticsearch.xpack.esql.core.expression.NamedExpression; -import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan; -import org.elasticsearch.xpack.esql.core.plan.logical.UnaryPlan; import org.elasticsearch.xpack.esql.core.tree.NodeInfo; import org.elasticsearch.xpack.esql.core.tree.Source; diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/plan/logical/OrderBy.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/OrderBy.java similarity index 96% rename from x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/plan/logical/OrderBy.java rename to x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/OrderBy.java index c13b3a028f0e8..68d089980074c 100644 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/plan/logical/OrderBy.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/OrderBy.java @@ -4,7 +4,7 @@ * 2.0; you may not use this file except in compliance with the Elastic License * 2.0. */ -package org.elasticsearch.xpack.esql.core.plan.logical; +package org.elasticsearch.xpack.esql.plan.logical; import org.elasticsearch.xpack.esql.core.capabilities.Resolvables; import org.elasticsearch.xpack.esql.core.expression.Order; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Project.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Project.java index fe28ddcc43b40..d3896b1dfc844 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Project.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Project.java @@ -10,8 +10,6 @@ import org.elasticsearch.xpack.esql.core.expression.Attribute; import org.elasticsearch.xpack.esql.core.expression.Expressions; import org.elasticsearch.xpack.esql.core.expression.NamedExpression; -import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan; -import org.elasticsearch.xpack.esql.core.plan.logical.UnaryPlan; import org.elasticsearch.xpack.esql.core.tree.NodeInfo; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.expression.function.Functions; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/RegexExtract.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/RegexExtract.java index 5bf45fc0f61ad..649173f11dfaf 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/RegexExtract.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/RegexExtract.java @@ -9,8 +9,6 @@ import org.elasticsearch.xpack.esql.core.expression.Attribute; import org.elasticsearch.xpack.esql.core.expression.Expression; -import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan; -import org.elasticsearch.xpack.esql.core.plan.logical.UnaryPlan; import org.elasticsearch.xpack.esql.core.tree.Source; import java.util.List; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Rename.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Rename.java index 7d99c566aa0c7..5e4b45d7127fe 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Rename.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Rename.java @@ -8,8 +8,6 @@ package org.elasticsearch.xpack.esql.plan.logical; import org.elasticsearch.xpack.esql.core.expression.Alias; -import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan; -import org.elasticsearch.xpack.esql.core.plan.logical.UnaryPlan; import org.elasticsearch.xpack.esql.core.tree.NodeInfo; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.expression.function.UnsupportedAttribute; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Row.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Row.java index 9af3e08a6734b..30e16d9e1b227 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Row.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/Row.java @@ -11,8 +11,6 @@ import org.elasticsearch.xpack.esql.core.expression.Alias; import org.elasticsearch.xpack.esql.core.expression.Attribute; import org.elasticsearch.xpack.esql.core.expression.Expressions; -import org.elasticsearch.xpack.esql.core.plan.logical.LeafPlan; -import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.core.tree.NodeInfo; import org.elasticsearch.xpack.esql.core.tree.Source; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/TopN.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/TopN.java index ac576eaa2cb96..227d7785804d4 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/TopN.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/TopN.java @@ -10,8 +10,6 @@ import org.elasticsearch.xpack.esql.core.capabilities.Resolvables; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.Order; -import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan; -import org.elasticsearch.xpack.esql.core.plan.logical.UnaryPlan; import org.elasticsearch.xpack.esql.core.tree.NodeInfo; import org.elasticsearch.xpack.esql.core.tree.Source; diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/plan/logical/UnaryPlan.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/UnaryPlan.java similarity index 96% rename from x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/plan/logical/UnaryPlan.java rename to x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/UnaryPlan.java index 75ce38127394e..ea9a760ef5dc4 100644 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/plan/logical/UnaryPlan.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/UnaryPlan.java @@ -4,7 +4,7 @@ * 2.0; you may not use this file except in compliance with the Elastic License * 2.0. */ -package org.elasticsearch.xpack.esql.core.plan.logical; +package org.elasticsearch.xpack.esql.plan.logical; import org.elasticsearch.xpack.esql.core.expression.Attribute; import org.elasticsearch.xpack.esql.core.tree.Source; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/UnresolvedRelation.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/UnresolvedRelation.java index eb6627bbdd0f8..af19bc87f2c54 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/UnresolvedRelation.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/UnresolvedRelation.java @@ -9,7 +9,6 @@ import org.elasticsearch.xpack.esql.core.capabilities.Unresolvable; import org.elasticsearch.xpack.esql.core.expression.Attribute; import org.elasticsearch.xpack.esql.core.plan.TableIdentifier; -import org.elasticsearch.xpack.esql.core.plan.logical.LeafPlan; import org.elasticsearch.xpack.esql.core.tree.NodeInfo; import org.elasticsearch.xpack.esql.core.tree.Source; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/join/Join.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/join/Join.java index d6d328686d8f1..79278995b29bd 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/join/Join.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/join/Join.java @@ -12,12 +12,12 @@ import org.elasticsearch.xpack.esql.core.expression.Expressions; import org.elasticsearch.xpack.esql.core.expression.Nullability; import org.elasticsearch.xpack.esql.core.expression.ReferenceAttribute; -import org.elasticsearch.xpack.esql.core.plan.logical.BinaryPlan; -import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.core.tree.NodeInfo; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; import org.elasticsearch.xpack.esql.io.stream.PlanStreamOutput; +import org.elasticsearch.xpack.esql.plan.logical.BinaryPlan; +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; import java.io.IOException; import java.util.ArrayList; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/local/EsqlProject.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/local/EsqlProject.java index 03a9c2b68b327..e359c6f928f7c 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/local/EsqlProject.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/local/EsqlProject.java @@ -8,10 +8,10 @@ package org.elasticsearch.xpack.esql.plan.logical.local; import org.elasticsearch.xpack.esql.core.expression.NamedExpression; -import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.core.tree.NodeInfo; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.expression.function.UnsupportedAttribute; +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.plan.logical.Project; import java.util.List; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/local/LocalRelation.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/local/LocalRelation.java index 862098621e9ee..195eb3b6304e4 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/local/LocalRelation.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/local/LocalRelation.java @@ -7,11 +7,11 @@ package org.elasticsearch.xpack.esql.plan.logical.local; import org.elasticsearch.xpack.esql.core.expression.Attribute; -import org.elasticsearch.xpack.esql.core.plan.logical.LeafPlan; import org.elasticsearch.xpack.esql.core.tree.NodeInfo; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; import org.elasticsearch.xpack.esql.io.stream.PlanStreamOutput; +import org.elasticsearch.xpack.esql.plan.logical.LeafPlan; import java.io.IOException; import java.util.List; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/meta/MetaFunctions.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/meta/MetaFunctions.java index f137cf392f8ad..9ac9ccdf2a876 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/meta/MetaFunctions.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/meta/MetaFunctions.java @@ -11,11 +11,11 @@ import org.elasticsearch.common.Strings; import org.elasticsearch.xpack.esql.core.expression.Attribute; import org.elasticsearch.xpack.esql.core.expression.ReferenceAttribute; -import org.elasticsearch.xpack.esql.core.plan.logical.LeafPlan; -import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.core.tree.NodeInfo; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.expression.function.EsqlFunctionRegistry; +import org.elasticsearch.xpack.esql.plan.logical.LeafPlan; +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; import java.util.ArrayList; import java.util.Arrays; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/show/ShowInfo.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/show/ShowInfo.java index 4867d8ca77a39..6e98df32580ae 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/show/ShowInfo.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/logical/show/ShowInfo.java @@ -11,10 +11,10 @@ import org.elasticsearch.Build; import org.elasticsearch.xpack.esql.core.expression.Attribute; import org.elasticsearch.xpack.esql.core.expression.ReferenceAttribute; -import org.elasticsearch.xpack.esql.core.plan.logical.LeafPlan; -import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.core.tree.NodeInfo; import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.plan.logical.LeafPlan; +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; import java.util.ArrayList; import java.util.List; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/FragmentExec.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/FragmentExec.java index 95cd732eabd45..5c01658760632 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/FragmentExec.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plan/physical/FragmentExec.java @@ -9,9 +9,9 @@ import org.elasticsearch.index.query.QueryBuilder; import org.elasticsearch.xpack.esql.core.expression.Attribute; -import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.core.tree.NodeInfo; import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; import java.util.List; import java.util.Objects; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AggregateMapper.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AggregateMapper.java index 91433e42033c5..87775d5048752 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AggregateMapper.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/AggregateMapper.java @@ -146,6 +146,8 @@ private static Stream, Tuple>> typeAndNames(Class List extraConfigs = List.of(""); if (NumericAggregate.class.isAssignableFrom(clazz)) { types = NUMERIC; + } else if (Max.class.isAssignableFrom(clazz) || Min.class.isAssignableFrom(clazz)) { + types = List.of("Boolean", "Int", "Long", "Double"); } else if (clazz == Count.class) { types = List.of(""); // no extra type distinction } else if (SpatialAggregateFunction.class.isAssignableFrom(clazz)) { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/EsPhysicalOperationProviders.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/EsPhysicalOperationProviders.java index 9e1e1a50fe8f0..8611d2c6fa9fb 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/EsPhysicalOperationProviders.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/EsPhysicalOperationProviders.java @@ -233,8 +233,9 @@ public final Operator.OperatorFactory ordinalGroupingOperatorFactory( // The grouping-by values are ready, let's group on them directly. // Costin: why are they ready and not already exposed in the layout? boolean isUnsupported = attrSource.dataType() == DataType.UNSUPPORTED; + var unionTypes = findUnionTypes(attrSource); return new OrdinalsGroupingOperator.OrdinalsGroupingOperatorFactory( - shardIdx -> shardContexts.get(shardIdx).blockLoader(attrSource.name(), isUnsupported, NONE), + shardIdx -> getBlockLoaderFor(shardIdx, attrSource.name(), isUnsupported, NONE, unionTypes), vsShardContexts, groupElementType, docChannel, @@ -434,12 +435,13 @@ public StoredFieldsSpec rowStrideStoredFieldSpec() { @Override public boolean supportsOrdinals() { - return delegate.supportsOrdinals(); + // Fields with mismatching types cannot use ordinals for uniqueness determination, but must convert the values first + return false; } @Override - public SortedSetDocValues ordinals(LeafReaderContext context) throws IOException { - return delegate.ordinals(context); + public SortedSetDocValues ordinals(LeafReaderContext context) { + throw new IllegalArgumentException("Ordinals are not supported for type conversion"); } @Override diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/Mapper.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/Mapper.java index 5ba2a205d52d0..84ed4663496de 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/Mapper.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/Mapper.java @@ -9,23 +9,23 @@ import org.elasticsearch.common.lucene.BytesRefs; import org.elasticsearch.xpack.esql.EsqlIllegalArgumentException; -import org.elasticsearch.xpack.esql.core.plan.logical.BinaryPlan; -import org.elasticsearch.xpack.esql.core.plan.logical.Filter; -import org.elasticsearch.xpack.esql.core.plan.logical.Limit; -import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan; -import org.elasticsearch.xpack.esql.core.plan.logical.OrderBy; -import org.elasticsearch.xpack.esql.core.plan.logical.UnaryPlan; import org.elasticsearch.xpack.esql.expression.function.EsqlFunctionRegistry; import org.elasticsearch.xpack.esql.plan.logical.Aggregate; +import org.elasticsearch.xpack.esql.plan.logical.BinaryPlan; import org.elasticsearch.xpack.esql.plan.logical.Dissect; import org.elasticsearch.xpack.esql.plan.logical.Enrich; import org.elasticsearch.xpack.esql.plan.logical.EsRelation; import org.elasticsearch.xpack.esql.plan.logical.Eval; +import org.elasticsearch.xpack.esql.plan.logical.Filter; import org.elasticsearch.xpack.esql.plan.logical.Grok; +import org.elasticsearch.xpack.esql.plan.logical.Limit; +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.plan.logical.MvExpand; +import org.elasticsearch.xpack.esql.plan.logical.OrderBy; import org.elasticsearch.xpack.esql.plan.logical.Project; import org.elasticsearch.xpack.esql.plan.logical.Row; import org.elasticsearch.xpack.esql.plan.logical.TopN; +import org.elasticsearch.xpack.esql.plan.logical.UnaryPlan; import org.elasticsearch.xpack.esql.plan.logical.join.Join; import org.elasticsearch.xpack.esql.plan.logical.join.JoinConfig; import org.elasticsearch.xpack.esql.plan.logical.join.JoinType; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/PlannerUtils.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/PlannerUtils.java index a729cec893126..d9f073d952a37 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/PlannerUtils.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/planner/PlannerUtils.java @@ -21,11 +21,6 @@ import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.FieldAttribute; import org.elasticsearch.xpack.esql.core.expression.predicate.Predicates; -import org.elasticsearch.xpack.esql.core.plan.logical.Filter; -import org.elasticsearch.xpack.esql.core.plan.logical.Limit; -import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan; -import org.elasticsearch.xpack.esql.core.plan.logical.OrderBy; -import org.elasticsearch.xpack.esql.core.plan.logical.UnaryPlan; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; import org.elasticsearch.xpack.esql.core.util.Holder; @@ -36,7 +31,12 @@ import org.elasticsearch.xpack.esql.optimizer.LocalPhysicalPlanOptimizer; import org.elasticsearch.xpack.esql.plan.logical.Aggregate; import org.elasticsearch.xpack.esql.plan.logical.EsRelation; +import org.elasticsearch.xpack.esql.plan.logical.Filter; +import org.elasticsearch.xpack.esql.plan.logical.Limit; +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; +import org.elasticsearch.xpack.esql.plan.logical.OrderBy; import org.elasticsearch.xpack.esql.plan.logical.TopN; +import org.elasticsearch.xpack.esql.plan.logical.UnaryPlan; import org.elasticsearch.xpack.esql.plan.physical.AggregateExec; import org.elasticsearch.xpack.esql.plan.physical.EsSourceExec; import org.elasticsearch.xpack.esql.plan.physical.EstimatesRowSize; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeListener.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeListener.java new file mode 100644 index 0000000000000..01d50d505f7f2 --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeListener.java @@ -0,0 +1,91 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.plugin; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.RefCountingListener; +import org.elasticsearch.compute.operator.DriverProfile; +import org.elasticsearch.compute.operator.FailureCollector; +import org.elasticsearch.compute.operator.ResponseHeadersCollector; +import org.elasticsearch.core.Releasable; +import org.elasticsearch.logging.LogManager; +import org.elasticsearch.logging.Logger; +import org.elasticsearch.tasks.CancellableTask; +import org.elasticsearch.transport.TransportService; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.atomic.AtomicBoolean; + +/** + * A variant of {@link RefCountingListener} with the following differences: + * 1. Automatically cancels sub tasks on failure. + * 2. Collects driver profiles from sub tasks. + * 3. Collects response headers from sub tasks, specifically warnings emitted during compute + * 4. Collects failures and returns the most appropriate exception to the caller. + */ +final class ComputeListener implements Releasable { + private static final Logger LOGGER = LogManager.getLogger(ComputeService.class); + + private final RefCountingListener refs; + private final FailureCollector failureCollector = new FailureCollector(); + private final AtomicBoolean cancelled = new AtomicBoolean(); + private final CancellableTask task; + private final TransportService transportService; + private final List collectedProfiles; + private final ResponseHeadersCollector responseHeaders; + + ComputeListener(TransportService transportService, CancellableTask task, ActionListener delegate) { + this.transportService = transportService; + this.task = task; + this.responseHeaders = new ResponseHeadersCollector(transportService.getThreadPool().getThreadContext()); + this.collectedProfiles = Collections.synchronizedList(new ArrayList<>()); + this.refs = new RefCountingListener(1, ActionListener.wrap(ignored -> { + responseHeaders.finish(); + var result = new ComputeResponse(collectedProfiles.isEmpty() ? List.of() : collectedProfiles.stream().toList()); + delegate.onResponse(result); + }, e -> delegate.onFailure(failureCollector.getFailure()))); + } + + /** + * Acquires a new listener that doesn't collect result + */ + ActionListener acquireAvoid() { + return refs.acquire().delegateResponse((l, e) -> { + failureCollector.unwrapAndCollect(e); + try { + if (cancelled.compareAndSet(false, true)) { + LOGGER.debug("cancelling ESQL task {} on failure", task); + transportService.getTaskManager().cancelTaskAndDescendants(task, "cancelled on failure", false, ActionListener.noop()); + } + } finally { + l.onFailure(e); + } + }); + } + + /** + * Acquires a new listener that collects compute result. This listener will also collects warnings emitted during compute + */ + ActionListener acquireCompute() { + return acquireAvoid().map(resp -> { + responseHeaders.collect(); + var profiles = resp.getProfiles(); + if (profiles != null && profiles.isEmpty() == false) { + collectedProfiles.addAll(profiles); + } + return null; + }); + } + + @Override + public void close() { + refs.close(); + } +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeService.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeService.java index 4ebc4af258134..673e320e5106b 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeService.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeService.java @@ -27,9 +27,7 @@ import org.elasticsearch.compute.data.BlockFactory; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.Driver; -import org.elasticsearch.compute.operator.DriverProfile; import org.elasticsearch.compute.operator.DriverTaskRunner; -import org.elasticsearch.compute.operator.ResponseHeadersCollector; import org.elasticsearch.compute.operator.exchange.ExchangeService; import org.elasticsearch.compute.operator.exchange.ExchangeSink; import org.elasticsearch.compute.operator.exchange.ExchangeSinkHandler; @@ -72,6 +70,7 @@ import org.elasticsearch.xpack.esql.planner.LocalExecutionPlanner; import org.elasticsearch.xpack.esql.planner.PlannerUtils; import org.elasticsearch.xpack.esql.session.EsqlConfiguration; +import org.elasticsearch.xpack.esql.session.Result; import java.util.ArrayList; import java.util.Collections; @@ -81,7 +80,6 @@ import java.util.Set; import java.util.concurrent.Executor; import java.util.concurrent.atomic.AtomicBoolean; -import java.util.function.Supplier; import static org.elasticsearch.xpack.esql.plugin.EsqlPlugin.ESQL_WORKER_THREAD_POOL_NAME; @@ -89,8 +87,6 @@ * Computes the result of a {@link PhysicalPlan}. */ public class ComputeService { - public record Result(List pages, List profiles) {} - private static final Logger LOGGER = LogManager.getLogger(ComputeService.class); private final SearchService searchService; private final BigArrays bigArrays; @@ -172,13 +168,16 @@ public void execute( null, null ); - runCompute( - rootTask, - computeContext, - coordinatorPlan, - listener.map(driverProfiles -> new Result(collectedPages, driverProfiles)) - ); - return; + try ( + var computeListener = new ComputeListener( + transportService, + rootTask, + listener.map(r -> new Result(physicalPlan.output(), collectedPages, r.getProfiles())) + ) + ) { + runCompute(rootTask, computeContext, coordinatorPlan, computeListener.acquireCompute()); + return; + } } else { if (clusterToConcreteIndices.values().stream().allMatch(v -> v.indices().length == 0)) { var error = "expected concrete indices with data node plan but got empty; data node plan " + dataNodePlan; @@ -191,31 +190,25 @@ public void execute( .groupIndices(SearchRequest.DEFAULT_INDICES_OPTIONS, PlannerUtils.planOriginalIndices(physicalPlan)); var localOriginalIndices = clusterToOriginalIndices.remove(RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY); var localConcreteIndices = clusterToConcreteIndices.remove(RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY); - final var responseHeadersCollector = new ResponseHeadersCollector(transportService.getThreadPool().getThreadContext()); - listener = ActionListener.runBefore(listener, responseHeadersCollector::finish); - final AtomicBoolean cancelled = new AtomicBoolean(); - final List collectedProfiles = configuration.profile() ? Collections.synchronizedList(new ArrayList<>()) : List.of(); final var exchangeSource = new ExchangeSourceHandler( queryPragmas.exchangeBufferSize(), transportService.getThreadPool().executor(ThreadPool.Names.SEARCH) ); try ( Releasable ignored = exchangeSource.addEmptySink(); - RefCountingListener refs = new RefCountingListener(listener.map(unused -> new Result(collectedPages, collectedProfiles))) + var computeListener = new ComputeListener( + transportService, + rootTask, + listener.map(r -> new Result(physicalPlan.output(), collectedPages, r.getProfiles())) + ) ) { // run compute on the coordinator - exchangeSource.addCompletionListener(refs.acquire()); + exchangeSource.addCompletionListener(computeListener.acquireAvoid()); runCompute( rootTask, new ComputeContext(sessionId, RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY, List.of(), configuration, exchangeSource, null), coordinatorPlan, - cancelOnFailure(rootTask, cancelled, refs.acquire()).map(driverProfiles -> { - responseHeadersCollector.collect(); - if (configuration.profile()) { - collectedProfiles.addAll(driverProfiles); - } - return null; - }) + computeListener.acquireCompute() ); // starts computes on data nodes on the main cluster if (localConcreteIndices != null && localConcreteIndices.indices().length > 0) { @@ -228,17 +221,10 @@ public void execute( Set.of(localConcreteIndices.indices()), localOriginalIndices.indices(), exchangeSource, - ActionListener.releaseAfter(refs.acquire(), exchangeSource.addEmptySink()), - () -> cancelOnFailure(rootTask, cancelled, refs.acquire()).map(response -> { - responseHeadersCollector.collect(); - if (configuration.profile()) { - collectedProfiles.addAll(response.getProfiles()); - } - return null; - }) + computeListener ); } - // starts computes on remote cluster + // starts computes on remote clusters startComputeOnRemoteClusters( sessionId, rootTask, @@ -246,13 +232,7 @@ public void execute( dataNodePlan, exchangeSource, getRemoteClusters(clusterToConcreteIndices, clusterToOriginalIndices), - () -> cancelOnFailure(rootTask, cancelled, refs.acquire()).map(response -> { - responseHeadersCollector.collect(); - if (configuration.profile()) { - collectedProfiles.addAll(response.getProfiles()); - } - return null; - }) + computeListener ); } } @@ -288,8 +268,7 @@ private void startComputeOnDataNodes( Set concreteIndices, String[] originalIndices, ExchangeSourceHandler exchangeSource, - ActionListener parentListener, - Supplier> dataNodeListenerSupplier + ComputeListener computeListener ) { var planWithReducer = configuration.pragmas().nodeLevelReduction() == false ? dataNodePlan @@ -303,12 +282,12 @@ private void startComputeOnDataNodes( // Since it's used only for @timestamp, it is relatively safe to assume it's not needed // but it would be better to have a proper impl. QueryBuilder requestFilter = PlannerUtils.requestFilter(planWithReducer, x -> true); + var lookupListener = ActionListener.releaseAfter(computeListener.acquireAvoid(), exchangeSource.addEmptySink()); lookupDataNodes(parentTask, clusterAlias, requestFilter, concreteIndices, originalIndices, ActionListener.wrap(dataNodes -> { - try (RefCountingRunnable refs = new RefCountingRunnable(() -> parentListener.onResponse(null))) { + try (RefCountingListener refs = new RefCountingListener(lookupListener)) { // For each target node, first open a remote exchange on the remote node, then link the exchange source to // the new remote exchange sink, and initialize the computation on the target node via data-node-request. for (DataNode node : dataNodes) { - var dataNodeListener = ActionListener.releaseAfter(dataNodeListenerSupplier.get(), refs.acquire()); var queryPragmas = configuration.pragmas(); ExchangeService.openExchange( transportService, @@ -316,9 +295,10 @@ private void startComputeOnDataNodes( sessionId, queryPragmas.exchangeBufferSize(), esqlExecutor, - dataNodeListener.delegateFailureAndWrap((delegate, unused) -> { + refs.acquire().delegateFailureAndWrap((l, unused) -> { var remoteSink = exchangeService.newRemoteSink(parentTask, sessionId, transportService, node.connection); exchangeSource.addRemoteSink(remoteSink, queryPragmas.concurrentExchangeClients()); + var dataNodeListener = ActionListener.runBefore(computeListener.acquireCompute(), () -> l.onResponse(null)); transportService.sendChildRequest( node.connection, DATA_ACTION_NAME, @@ -332,13 +312,13 @@ private void startComputeOnDataNodes( ), parentTask, TransportRequestOptions.EMPTY, - new ActionListenerResponseHandler<>(delegate, ComputeResponse::new, esqlExecutor) + new ActionListenerResponseHandler<>(dataNodeListener, ComputeResponse::new, esqlExecutor) ); }) ); } } - }, parentListener::onFailure)); + }, lookupListener::onFailure)); } private void startComputeOnRemoteClusters( @@ -348,19 +328,19 @@ private void startComputeOnRemoteClusters( PhysicalPlan plan, ExchangeSourceHandler exchangeSource, List clusters, - Supplier> listener + ComputeListener computeListener ) { - try (RefCountingRunnable refs = new RefCountingRunnable(exchangeSource.addEmptySink()::close)) { + var queryPragmas = configuration.pragmas(); + var linkExchangeListeners = ActionListener.releaseAfter(computeListener.acquireAvoid(), exchangeSource.addEmptySink()); + try (RefCountingListener refs = new RefCountingListener(linkExchangeListeners)) { for (RemoteCluster cluster : clusters) { - var targetNodeListener = ActionListener.releaseAfter(listener.get(), refs.acquire()); - var queryPragmas = configuration.pragmas(); ExchangeService.openExchange( transportService, cluster.connection, sessionId, queryPragmas.exchangeBufferSize(), esqlExecutor, - targetNodeListener.delegateFailureAndWrap((l, unused) -> { + refs.acquire().delegateFailureAndWrap((l, unused) -> { var remoteSink = exchangeService.newRemoteSink(rootTask, sessionId, transportService, cluster.connection); exchangeSource.addRemoteSink(remoteSink, queryPragmas.concurrentExchangeClients()); var clusterRequest = new ClusterComputeRequest( @@ -371,13 +351,14 @@ private void startComputeOnRemoteClusters( cluster.concreteIndices, cluster.originalIndices ); + var clusterListener = ActionListener.runBefore(computeListener.acquireCompute(), () -> l.onResponse(null)); transportService.sendChildRequest( cluster.connection, CLUSTER_ACTION_NAME, clusterRequest, rootTask, TransportRequestOptions.EMPTY, - new ActionListenerResponseHandler<>(l, ComputeResponse::new, esqlExecutor) + new ActionListenerResponseHandler<>(clusterListener, ComputeResponse::new, esqlExecutor) ); }) ); @@ -385,17 +366,7 @@ private void startComputeOnRemoteClusters( } } - private ActionListener cancelOnFailure(CancellableTask task, AtomicBoolean cancelled, ActionListener listener) { - return listener.delegateResponse((l, e) -> { - l.onFailure(e); - if (cancelled.compareAndSet(false, true)) { - LOGGER.debug("cancelling ESQL task {} on failure", task); - transportService.getTaskManager().cancelTaskAndDescendants(task, "cancelled", false, ActionListener.noop()); - } - }); - } - - void runCompute(CancellableTask task, ComputeContext context, PhysicalPlan plan, ActionListener> listener) { + void runCompute(CancellableTask task, ComputeContext context, PhysicalPlan plan, ActionListener listener) { listener = ActionListener.runBefore(listener, () -> Releasables.close(context.searchContexts)); List contexts = new ArrayList<>(context.searchContexts.size()); for (int i = 0; i < context.searchContexts.size(); i++) { @@ -445,9 +416,10 @@ void runCompute(CancellableTask task, ComputeContext context, PhysicalPlan plan, } ActionListener listenerCollectingStatus = listener.map(ignored -> { if (context.configuration.profile()) { - return drivers.stream().map(Driver::profile).toList(); + return new ComputeResponse(drivers.stream().map(Driver::profile).toList()); + } else { + return new ComputeResponse(List.of()); } - return null; }); listenerCollectingStatus = ActionListener.releaseAfter(listenerCollectingStatus, () -> Releasables.close(drivers)); driverRunner.executeDrivers( @@ -612,8 +584,7 @@ private class DataNodeRequestExecutor { private final DataNodeRequest request; private final CancellableTask parentTask; private final ExchangeSinkHandler exchangeSink; - private final ActionListener listener; - private final List driverProfiles; + private final ComputeListener computeListener; private final int maxConcurrentShards; private final ExchangeSink blockingSink; // block until we have completed on all shards or the coordinator has enough data @@ -622,14 +593,12 @@ private class DataNodeRequestExecutor { CancellableTask parentTask, ExchangeSinkHandler exchangeSink, int maxConcurrentShards, - List driverProfiles, - ActionListener listener + ComputeListener computeListener ) { this.request = request; this.parentTask = parentTask; this.exchangeSink = exchangeSink; - this.listener = listener; - this.driverProfiles = driverProfiles; + this.computeListener = computeListener; this.maxConcurrentShards = maxConcurrentShards; this.blockingSink = exchangeSink.createExchangeSink(); } @@ -647,40 +616,46 @@ private void runBatch(int startBatchIndex) { final var sessionId = request.sessionId(); final int endBatchIndex = Math.min(startBatchIndex + maxConcurrentShards, request.shardIds().size()); List shardIds = request.shardIds().subList(startBatchIndex, endBatchIndex); + ActionListener batchListener = new ActionListener<>() { + final ActionListener ref = computeListener.acquireCompute(); + + @Override + public void onResponse(ComputeResponse result) { + try { + onBatchCompleted(endBatchIndex); + } finally { + ref.onResponse(result); + } + } + + @Override + public void onFailure(Exception e) { + try { + exchangeService.finishSinkHandler(request.sessionId(), e); + } finally { + ref.onFailure(e); + } + } + }; acquireSearchContexts(clusterAlias, shardIds, configuration, request.aliasFilters(), ActionListener.wrap(searchContexts -> { assert ThreadPool.assertCurrentThreadPool(ThreadPool.Names.SEARCH, ESQL_WORKER_THREAD_POOL_NAME); var computeContext = new ComputeContext(sessionId, clusterAlias, searchContexts, configuration, null, exchangeSink); - runCompute( - parentTask, - computeContext, - request.plan(), - ActionListener.wrap(profiles -> onBatchCompleted(endBatchIndex, profiles), this::onFailure) - ); - }, this::onFailure)); + runCompute(parentTask, computeContext, request.plan(), batchListener); + }, batchListener::onFailure)); } - private void onBatchCompleted(int lastBatchIndex, List batchProfiles) { - if (request.configuration().profile()) { - driverProfiles.addAll(batchProfiles); - } + private void onBatchCompleted(int lastBatchIndex) { if (lastBatchIndex < request.shardIds().size() && exchangeSink.isFinished() == false) { runBatch(lastBatchIndex); } else { - blockingSink.finish(); // don't return until all pages are fetched + var completionListener = computeListener.acquireAvoid(); exchangeSink.addCompletionListener( - ContextPreservingActionListener.wrapPreservingContext( - ActionListener.runBefore(listener, () -> exchangeService.finishSinkHandler(request.sessionId(), null)), - transportService.getThreadPool().getThreadContext() - ) + ActionListener.runAfter(completionListener, () -> exchangeService.finishSinkHandler(request.sessionId(), null)) ); + blockingSink.finish(); } } - - private void onFailure(Exception e) { - exchangeService.finishSinkHandler(request.sessionId(), e); - listener.onFailure(e); - } } private void runComputeOnDataNode( @@ -688,17 +663,10 @@ private void runComputeOnDataNode( String externalId, PhysicalPlan reducePlan, DataNodeRequest request, - ActionListener listener + ComputeListener computeListener ) { - final List collectedProfiles = request.configuration().profile() - ? Collections.synchronizedList(new ArrayList<>()) - : List.of(); - final var responseHeadersCollector = new ResponseHeadersCollector(transportService.getThreadPool().getThreadContext()); - final RefCountingListener listenerRefs = new RefCountingListener( - ActionListener.runBefore(listener.map(unused -> new ComputeResponse(collectedProfiles)), responseHeadersCollector::finish) - ); + var parentListener = computeListener.acquireAvoid(); try { - final AtomicBoolean cancelled = new AtomicBoolean(); // run compute with target shards var internalSink = exchangeService.createSinkHandler(request.sessionId(), request.pragmas().exchangeBufferSize()); DataNodeRequestExecutor dataNodeRequestExecutor = new DataNodeRequestExecutor( @@ -706,17 +674,16 @@ private void runComputeOnDataNode( task, internalSink, request.configuration().pragmas().maxConcurrentShardsPerNode(), - collectedProfiles, - ActionListener.runBefore(cancelOnFailure(task, cancelled, listenerRefs.acquire()), responseHeadersCollector::collect) + computeListener ); dataNodeRequestExecutor.start(); // run the node-level reduction var externalSink = exchangeService.getSinkHandler(externalId); task.addListener(() -> exchangeService.finishSinkHandler(externalId, new TaskCancelledException(task.getReasonCancelled()))); var exchangeSource = new ExchangeSourceHandler(1, esqlExecutor); - exchangeSource.addCompletionListener(listenerRefs.acquire()); + exchangeSource.addCompletionListener(computeListener.acquireAvoid()); exchangeSource.addRemoteSink(internalSink::fetchPageAsync, 1); - ActionListener reductionListener = cancelOnFailure(task, cancelled, listenerRefs.acquire()); + ActionListener reductionListener = computeListener.acquireCompute(); runCompute( task, new ComputeContext( @@ -728,26 +695,22 @@ private void runComputeOnDataNode( externalSink ), reducePlan, - ActionListener.wrap(driverProfiles -> { - responseHeadersCollector.collect(); - if (request.configuration().profile()) { - collectedProfiles.addAll(driverProfiles); - } + ActionListener.wrap(resp -> { // don't return until all pages are fetched - externalSink.addCompletionListener( - ActionListener.runBefore(reductionListener, () -> exchangeService.finishSinkHandler(externalId, null)) - ); + externalSink.addCompletionListener(ActionListener.running(() -> { + exchangeService.finishSinkHandler(externalId, null); + reductionListener.onResponse(resp); + })); }, e -> { exchangeService.finishSinkHandler(externalId, e); reductionListener.onFailure(e); }) ); + parentListener.onResponse(null); } catch (Exception e) { exchangeService.finishSinkHandler(externalId, e); exchangeService.finishSinkHandler(request.sessionId(), e); - listenerRefs.acquire().onFailure(e); - } finally { - listenerRefs.close(); + parentListener.onFailure(e); } } @@ -784,7 +747,9 @@ public void messageReceived(DataNodeRequest request, TransportChannel channel, T request.aliasFilters(), request.plan() ); - runComputeOnDataNode((CancellableTask) task, sessionId, reducePlan, request, listener); + try (var computeListener = new ComputeListener(transportService, (CancellableTask) task, listener)) { + runComputeOnDataNode((CancellableTask) task, sessionId, reducePlan, request, computeListener); + } } } @@ -798,16 +763,18 @@ public void messageReceived(ClusterComputeRequest request, TransportChannel chan listener.onFailure(new IllegalStateException("expected exchange sink for a remote compute; got " + request.plan())); return; } - runComputeOnRemoteCluster( - request.clusterAlias(), - request.sessionId(), - (CancellableTask) task, - request.configuration(), - (ExchangeSinkExec) request.plan(), - Set.of(request.indices()), - request.originalIndices(), - listener - ); + try (var computeListener = new ComputeListener(transportService, (CancellableTask) task, listener)) { + runComputeOnRemoteCluster( + request.clusterAlias(), + request.sessionId(), + (CancellableTask) task, + request.configuration(), + (ExchangeSinkExec) request.plan(), + Set.of(request.indices()), + request.originalIndices(), + computeListener + ); + } } } @@ -828,28 +795,20 @@ void runComputeOnRemoteCluster( ExchangeSinkExec plan, Set concreteIndices, String[] originalIndices, - ActionListener listener + ComputeListener computeListener ) { final var exchangeSink = exchangeService.getSinkHandler(globalSessionId); parentTask.addListener( () -> exchangeService.finishSinkHandler(globalSessionId, new TaskCancelledException(parentTask.getReasonCancelled())) ); - ThreadPool threadPool = transportService.getThreadPool(); - final var responseHeadersCollector = new ResponseHeadersCollector(threadPool.getThreadContext()); - listener = ActionListener.runBefore(listener, responseHeadersCollector::finish); - final AtomicBoolean cancelled = new AtomicBoolean(); - final List collectedProfiles = configuration.profile() ? Collections.synchronizedList(new ArrayList<>()) : List.of(); final String localSessionId = clusterAlias + ":" + globalSessionId; var exchangeSource = new ExchangeSourceHandler( configuration.pragmas().exchangeBufferSize(), transportService.getThreadPool().executor(ThreadPool.Names.SEARCH) ); - try ( - Releasable ignored = exchangeSource.addEmptySink(); - RefCountingListener refs = new RefCountingListener(listener.map(unused -> new ComputeResponse(collectedProfiles))) - ) { - exchangeSink.addCompletionListener(refs.acquire()); - exchangeSource.addCompletionListener(refs.acquire()); + try (Releasable ignored = exchangeSource.addEmptySink()) { + exchangeSink.addCompletionListener(computeListener.acquireAvoid()); + exchangeSource.addCompletionListener(computeListener.acquireAvoid()); PhysicalPlan coordinatorPlan = new ExchangeSinkExec( plan.source(), plan.output(), @@ -860,13 +819,7 @@ void runComputeOnRemoteCluster( parentTask, new ComputeContext(localSessionId, clusterAlias, List.of(), configuration, exchangeSource, exchangeSink), coordinatorPlan, - cancelOnFailure(parentTask, cancelled, refs.acquire()).map(driverProfiles -> { - responseHeadersCollector.collect(); - if (configuration.profile()) { - collectedProfiles.addAll(driverProfiles); - } - return null; - }) + computeListener.acquireCompute() ); startComputeOnDataNodes( localSessionId, @@ -877,14 +830,7 @@ void runComputeOnRemoteCluster( concreteIndices, originalIndices, exchangeSource, - ActionListener.releaseAfter(refs.acquire(), exchangeSource.addEmptySink()), - () -> cancelOnFailure(parentTask, cancelled, refs.acquire()).map(r -> { - responseHeadersCollector.collect(); - if (configuration.profile()) { - collectedProfiles.addAll(r.getProfiles()); - } - return null; - }) + computeListener ); } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/TransportEsqlQueryAction.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/TransportEsqlQueryAction.java index 9328992120c08..5a6812c969757 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/TransportEsqlQueryAction.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/TransportEsqlQueryAction.java @@ -37,7 +37,9 @@ import org.elasticsearch.xpack.esql.enrich.EnrichLookupService; import org.elasticsearch.xpack.esql.enrich.EnrichPolicyResolver; import org.elasticsearch.xpack.esql.execution.PlanExecutor; +import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan; import org.elasticsearch.xpack.esql.session.EsqlConfiguration; +import org.elasticsearch.xpack.esql.session.Result; import java.io.IOException; import java.time.ZoneOffset; @@ -45,6 +47,7 @@ import java.util.Locale; import java.util.Map; import java.util.concurrent.Executor; +import java.util.function.BiConsumer; import static org.elasticsearch.xpack.core.ClientHelper.ASYNC_SEARCH_ORIGIN; @@ -157,37 +160,37 @@ private void innerExecute(Task task, EsqlQueryRequest request, ActionListener> runPhase = (physicalPlan, resultListener) -> computeService.execute( + sessionId, + (CancellableTask) task, + physicalPlan, + configuration, + resultListener + ); + planExecutor.esql( request, sessionId, configuration, enrichPolicyResolver, - listener.delegateFailureAndWrap( - (delegate, physicalPlan) -> computeService.execute( - sessionId, - (CancellableTask) task, - physicalPlan, - configuration, - delegate.map(result -> { - List columns = physicalPlan.output() - .stream() - .map(c -> new ColumnInfoImpl(c.qualifiedName(), c.dataType().outputType())) - .toList(); - EsqlQueryResponse.Profile profile = configuration.profile() - ? new EsqlQueryResponse.Profile(result.profiles()) - : null; - if (task instanceof EsqlQueryTask asyncTask && request.keepOnCompletion()) { - String id = asyncTask.getExecutionId().getEncoded(); - return new EsqlQueryResponse(columns, result.pages(), profile, request.columnar(), id, false, request.async()); - } else { - return new EsqlQueryResponse(columns, result.pages(), profile, request.columnar(), request.async()); - } - }) - ) - ) + runPhase, + listener.map(result -> toResponse(task, request, configuration, result)) ); } + private EsqlQueryResponse toResponse(Task task, EsqlQueryRequest request, EsqlConfiguration configuration, Result result) { + List columns = result.schema() + .stream() + .map(c -> new ColumnInfoImpl(c.qualifiedName(), c.dataType().outputType())) + .toList(); + EsqlQueryResponse.Profile profile = configuration.profile() ? new EsqlQueryResponse.Profile(result.profiles()) : null; + if (task instanceof EsqlQueryTask asyncTask && request.keepOnCompletion()) { + String id = asyncTask.getExecutionId().getEncoded(); + return new EsqlQueryResponse(columns, result.pages(), profile, request.columnar(), id, false, request.async()); + } + return new EsqlQueryResponse(columns, result.pages(), profile, request.columnar(), request.async()); + } + /** * Returns the ID for this compute session. The ID is unique within the cluster, and is used * to identify the compute-session across nodes. The ID is just the TaskID of the task that diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/EsqlSession.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/EsqlSession.java index 3119b328e8074..8c831cc260e03 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/EsqlSession.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/EsqlSession.java @@ -31,7 +31,6 @@ import org.elasticsearch.xpack.esql.core.index.IndexResolution; import org.elasticsearch.xpack.esql.core.index.MappingException; import org.elasticsearch.xpack.esql.core.plan.TableIdentifier; -import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.core.type.InvalidMappedField; import org.elasticsearch.xpack.esql.core.util.Holder; import org.elasticsearch.xpack.esql.enrich.EnrichPolicyResolver; @@ -46,6 +45,7 @@ import org.elasticsearch.xpack.esql.plan.logical.Aggregate; import org.elasticsearch.xpack.esql.plan.logical.Enrich; import org.elasticsearch.xpack.esql.plan.logical.Keep; +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.plan.logical.Project; import org.elasticsearch.xpack.esql.plan.logical.RegexExtract; import org.elasticsearch.xpack.esql.plan.physical.EstimatesRowSize; @@ -58,12 +58,12 @@ import java.util.List; import java.util.Map; import java.util.Set; +import java.util.function.BiConsumer; import java.util.function.BiFunction; import java.util.function.Predicate; import java.util.stream.Collectors; import static org.elasticsearch.index.query.QueryBuilders.boolQuery; -import static org.elasticsearch.xpack.esql.core.util.ActionListeners.map; import static org.elasticsearch.xpack.esql.core.util.StringUtils.WILDCARD; public class EsqlSession { @@ -110,26 +110,31 @@ public String sessionId() { return sessionId; } - public void execute(EsqlQueryRequest request, ActionListener listener) { + /** + * Execute an ESQL request. + */ + public void execute( + EsqlQueryRequest request, + BiConsumer> runPhase, + ActionListener listener + ) { LOGGER.debug("ESQL query:\n{}", request.query()); - optimizedPhysicalPlan( + analyzedPlan( parse(request.query(), request.params()), - listener.map(plan -> EstimatesRowSize.estimateRowSize(0, plan.transformUp(FragmentExec.class, f -> { - QueryBuilder filter = request.filter(); - if (filter != null) { - var fragmentFilter = f.esFilter(); - // TODO: have an ESFilter and push down to EsQueryExec / EsSource - // This is an ugly hack to push the filter parameter to Lucene - // TODO: filter integration testing - filter = fragmentFilter != null ? boolQuery().filter(fragmentFilter).must(filter) : filter; - LOGGER.debug("Fold filter {} to EsQueryExec", filter); - f = f.withFilter(filter); - } - return f; - }))) + listener.delegateFailureAndWrap((next, analyzedPlan) -> executeAnalyzedPlan(request, runPhase, analyzedPlan, next)) ); } + public void executeAnalyzedPlan( + EsqlQueryRequest request, + BiConsumer> runPhase, + LogicalPlan analyzedPlan, + ActionListener listener + ) { + // TODO phased execution lands here. + runPhase.accept(logicalPlanToPhysicalPlan(analyzedPlan, request), listener); + } + private LogicalPlan parse(String query, QueryParams params) { var parsed = new EsqlParser().createStatement(query, params); LOGGER.debug("Parsed logical plan:\n{}", parsed); @@ -145,6 +150,7 @@ public void analyzedPlan(LogicalPlan parsed, ActionListener listene preAnalyze(parsed, (indices, policies) -> { Analyzer analyzer = new Analyzer(new AnalyzerContext(configuration, functionRegistry, indices, policies), verifier); var plan = analyzer.analyze(parsed); + plan.setAnalyzed(); LOGGER.debug("Analyzed plan:\n{}", plan); return plan; }, listener); @@ -305,28 +311,41 @@ private static Set subfields(Set names) { return names.stream().filter(name -> name.endsWith(WILDCARD) == false).map(name -> name + ".*").collect(Collectors.toSet()); } - public void optimizedPlan(LogicalPlan logicalPlan, ActionListener listener) { - analyzedPlan(logicalPlan, map(listener, p -> { - var plan = logicalPlanOptimizer.optimize(p); - LOGGER.debug("Optimized logicalPlan plan:\n{}", plan); - return plan; - })); + private PhysicalPlan logicalPlanToPhysicalPlan(LogicalPlan logicalPlan, EsqlQueryRequest request) { + PhysicalPlan physicalPlan = optimizedPhysicalPlan(logicalPlan); + physicalPlan = physicalPlan.transformUp(FragmentExec.class, f -> { + QueryBuilder filter = request.filter(); + if (filter != null) { + var fragmentFilter = f.esFilter(); + // TODO: have an ESFilter and push down to EsQueryExec / EsSource + // This is an ugly hack to push the filter parameter to Lucene + // TODO: filter integration testing + filter = fragmentFilter != null ? boolQuery().filter(fragmentFilter).must(filter) : filter; + LOGGER.debug("Fold filter {} to EsQueryExec", filter); + f = f.withFilter(filter); + } + return f; + }); + return EstimatesRowSize.estimateRowSize(0, physicalPlan); } - public void physicalPlan(LogicalPlan optimized, ActionListener listener) { - optimizedPlan(optimized, map(listener, p -> { - var plan = mapper.map(p); - LOGGER.debug("Physical plan:\n{}", plan); - return plan; - })); + public LogicalPlan optimizedPlan(LogicalPlan logicalPlan) { + assert logicalPlan.analyzed(); + var plan = logicalPlanOptimizer.optimize(logicalPlan); + LOGGER.debug("Optimized logicalPlan plan:\n{}", plan); + return plan; } - public void optimizedPhysicalPlan(LogicalPlan logicalPlan, ActionListener listener) { - physicalPlan(logicalPlan, map(listener, p -> { - var plan = physicalPlanOptimizer.optimize(p); - LOGGER.debug("Optimized physical plan:\n{}", plan); - return plan; - })); + public PhysicalPlan physicalPlan(LogicalPlan logicalPlan) { + var plan = mapper.map(optimizedPlan(logicalPlan)); + LOGGER.debug("Physical plan:\n{}", plan); + return plan; + } + + public PhysicalPlan optimizedPhysicalPlan(LogicalPlan logicalPlan) { + var plan = physicalPlanOptimizer.optimize(physicalPlan(logicalPlan)); + LOGGER.debug("Optimized physical plan:\n{}", plan); + return plan; } public static InvalidMappedField specificValidity(String fieldName, Map types) { diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/Result.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/Result.java index 7cbf3987af2cb..42beb88bbe38b 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/Result.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/session/Result.java @@ -7,8 +7,23 @@ package org.elasticsearch.xpack.esql.session; +import org.elasticsearch.compute.data.Block; +import org.elasticsearch.compute.data.Page; +import org.elasticsearch.compute.operator.DriverProfile; import org.elasticsearch.xpack.esql.core.expression.Attribute; +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; import java.util.List; -public record Result(List columns, List> values) {} +/** + * Results from running a chunk of ESQL. + * @param schema "Schema" of the {@link Attribute}s that are produced by the {@link LogicalPlan} + * that was run. Each {@link Page} contains a {@link Block} of values for each + * attribute in this list. + * @param pages Actual values produced by running the ESQL. + * @param profiles {@link DriverProfile}s from all drivers that ran to produce the output. These + * are quite cheap to build, so we build them for all ESQL runs, regardless of if + * users have asked for them. But we only include them in the results if users ask + * for them. + */ +public record Result(List schema, List pages, List profiles) {} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/stats/FeatureMetric.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/stats/FeatureMetric.java index d5c4a67b01e8b..c4d890a818ec7 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/stats/FeatureMetric.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/stats/FeatureMetric.java @@ -7,18 +7,18 @@ package org.elasticsearch.xpack.esql.stats; -import org.elasticsearch.xpack.esql.core.plan.logical.Filter; -import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan; -import org.elasticsearch.xpack.esql.core.plan.logical.OrderBy; import org.elasticsearch.xpack.esql.plan.logical.Aggregate; import org.elasticsearch.xpack.esql.plan.logical.Dissect; import org.elasticsearch.xpack.esql.plan.logical.Drop; import org.elasticsearch.xpack.esql.plan.logical.Enrich; import org.elasticsearch.xpack.esql.plan.logical.EsRelation; import org.elasticsearch.xpack.esql.plan.logical.Eval; +import org.elasticsearch.xpack.esql.plan.logical.Filter; import org.elasticsearch.xpack.esql.plan.logical.Grok; import org.elasticsearch.xpack.esql.plan.logical.Keep; +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.plan.logical.MvExpand; +import org.elasticsearch.xpack.esql.plan.logical.OrderBy; import org.elasticsearch.xpack.esql.plan.logical.Rename; import org.elasticsearch.xpack.esql.plan.logical.Row; import org.elasticsearch.xpack.esql.plan.logical.meta.MetaFunctions; diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java index b63a24556c31f..20b4d3a503f0c 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/CsvTests.java @@ -31,7 +31,6 @@ import org.elasticsearch.compute.operator.exchange.ExchangeSinkHandler; import org.elasticsearch.compute.operator.exchange.ExchangeSourceHandler; import org.elasticsearch.core.Releasables; -import org.elasticsearch.core.TimeValue; import org.elasticsearch.core.Tuple; import org.elasticsearch.logging.LogManager; import org.elasticsearch.logging.Logger; @@ -48,17 +47,16 @@ import org.elasticsearch.xpack.esql.CsvTestUtils.ActualResults; import org.elasticsearch.xpack.esql.CsvTestUtils.Type; import org.elasticsearch.xpack.esql.action.EsqlCapabilities; +import org.elasticsearch.xpack.esql.action.EsqlQueryRequest; import org.elasticsearch.xpack.esql.analysis.Analyzer; import org.elasticsearch.xpack.esql.analysis.AnalyzerContext; import org.elasticsearch.xpack.esql.analysis.EnrichResolution; import org.elasticsearch.xpack.esql.analysis.PreAnalyzer; import org.elasticsearch.xpack.esql.core.CsvSpecReader; import org.elasticsearch.xpack.esql.core.SpecReader; -import org.elasticsearch.xpack.esql.core.expression.Expressions; +import org.elasticsearch.xpack.esql.core.expression.Attribute; import org.elasticsearch.xpack.esql.core.index.EsIndex; import org.elasticsearch.xpack.esql.core.index.IndexResolution; -import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan; -import org.elasticsearch.xpack.esql.core.type.DataType; import org.elasticsearch.xpack.esql.enrich.EnrichLookupService; import org.elasticsearch.xpack.esql.enrich.ResolvedEnrichPolicy; import org.elasticsearch.xpack.esql.expression.function.EsqlFunctionRegistry; @@ -73,7 +71,7 @@ import org.elasticsearch.xpack.esql.optimizer.TestPhysicalPlanOptimizer; import org.elasticsearch.xpack.esql.parser.EsqlParser; import org.elasticsearch.xpack.esql.plan.logical.Enrich; -import org.elasticsearch.xpack.esql.plan.physical.EstimatesRowSize; +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.plan.physical.LocalSourceExec; import org.elasticsearch.xpack.esql.plan.physical.OutputExec; import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan; @@ -85,6 +83,8 @@ import org.elasticsearch.xpack.esql.plugin.EsqlFeatures; import org.elasticsearch.xpack.esql.plugin.QueryPragmas; import org.elasticsearch.xpack.esql.session.EsqlConfiguration; +import org.elasticsearch.xpack.esql.session.EsqlSession; +import org.elasticsearch.xpack.esql.session.Result; import org.elasticsearch.xpack.esql.stats.DisabledSearchStats; import org.junit.After; import org.junit.Before; @@ -100,6 +100,7 @@ import java.util.TreeMap; import java.util.concurrent.Executor; import java.util.concurrent.TimeUnit; +import java.util.function.BiConsumer; import static org.elasticsearch.xpack.esql.CsvTestUtils.ExpectedResults; import static org.elasticsearch.xpack.esql.CsvTestUtils.isEnabled; @@ -330,16 +331,14 @@ private static EnrichPolicy loadEnrichPolicyMapping(String policyFileName) { } } - private PhysicalPlan physicalPlan(LogicalPlan parsed, CsvTestsDataLoader.TestsDataset dataset) { + private LogicalPlan analyzedPlan(LogicalPlan parsed, CsvTestsDataLoader.TestsDataset dataset) { var indexResolution = loadIndexResolution(dataset.mappingFileName(), dataset.indexName()); var enrichPolicies = loadEnrichPolicies(); var analyzer = new Analyzer(new AnalyzerContext(configuration, functionRegistry, indexResolution, enrichPolicies), TEST_VERIFIER); - var analyzed = analyzer.analyze(parsed); - var logicalOptimized = new LogicalPlanOptimizer(new LogicalOptimizerContext(configuration)).optimize(analyzed); - var physicalPlan = mapper.map(logicalOptimized); - var optimizedPlan = EstimatesRowSize.estimateRowSize(0, physicalPlanOptimizer.optimize(physicalPlan)); - opportunisticallyAssertPlanSerialization(physicalPlan, optimizedPlan); // comment out to disable serialization - return optimizedPlan; + LogicalPlan plan = analyzer.analyze(parsed); + plan.setAnalyzed(); + LOGGER.debug("Analyzed plan:\n{}", plan); + return plan; } private static CsvTestsDataLoader.TestsDataset testsDataset(LogicalPlan parsed) { @@ -381,90 +380,43 @@ private static TestPhysicalOperationProviders testOperationProviders(CsvTestsDat } private ActualResults executePlan(BigArrays bigArrays) throws Exception { - var parsed = parser.createStatement(testCase.query); + LogicalPlan parsed = parser.createStatement(testCase.query); var testDataset = testsDataset(parsed); + LogicalPlan analyzed = analyzedPlan(parsed, testDataset); - String sessionId = "csv-test"; - BlockFactory blockFactory = new BlockFactory( - bigArrays.breakerService().getBreaker(CircuitBreaker.REQUEST), - bigArrays, - ByteSizeValue.ofBytes(randomLongBetween(1, BlockFactory.DEFAULT_MAX_BLOCK_PRIMITIVE_ARRAY_SIZE.getBytes() * 2)) - ); - ExchangeSourceHandler exchangeSource = new ExchangeSourceHandler(between(1, 64), executor); - ExchangeSinkHandler exchangeSink = new ExchangeSinkHandler(blockFactory, between(1, 64), threadPool::relativeTimeInMillis); - LocalExecutionPlanner executionPlanner = new LocalExecutionPlanner( - sessionId, - "", - new CancellableTask(1, "transport", "esql", null, TaskId.EMPTY_TASK_ID, Map.of()), - bigArrays, - blockFactory, - randomNodeSettings(), + EsqlSession session = new EsqlSession( + getTestName(), configuration, - exchangeSource, - exchangeSink, - Mockito.mock(EnrichLookupService.class), - testOperationProviders(testDataset) + null, + null, + null, + functionRegistry, + new LogicalPlanOptimizer(new LogicalOptimizerContext(configuration)), + mapper, + TEST_VERIFIER ); - // - // Keep in sync with ComputeService#execute - // - PhysicalPlan physicalPlan = physicalPlan(parsed, testDataset); - Tuple coordinatorAndDataNodePlan = PlannerUtils.breakPlanBetweenCoordinatorAndDataNode( - physicalPlan, - configuration + TestPhysicalOperationProviders physicalOperationProviders = testOperationProviders(testDataset); + + PlainActionFuture listener = new PlainActionFuture<>(); + + session.executeAnalyzedPlan( + new EsqlQueryRequest(), + runPhase(bigArrays, physicalOperationProviders), + analyzed, + listener.delegateFailureAndWrap( + // Wrap so we can capture the warnings in the calling thread + (next, result) -> next.onResponse( + new ActualResults( + result.schema().stream().map(Attribute::name).toList(), + result.schema().stream().map(a -> Type.asType(a.dataType().nameUpper())).toList(), + result.schema().stream().map(Attribute::dataType).toList(), + result.pages(), + threadPool.getThreadContext().getResponseHeaders() + ) + ) + ) ); - PhysicalPlan coordinatorPlan = coordinatorAndDataNodePlan.v1(); - PhysicalPlan dataNodePlan = coordinatorAndDataNodePlan.v2(); - - if (LOGGER.isTraceEnabled()) { - LOGGER.trace("Coordinator plan\n" + coordinatorPlan); - LOGGER.trace("DataNode plan\n" + dataNodePlan); - } - - List columnNames = Expressions.names(coordinatorPlan.output()); - List dataTypes = new ArrayList<>(columnNames.size()); - List columnTypes = coordinatorPlan.output() - .stream() - .peek(o -> dataTypes.add(o.dataType())) - .map(o -> Type.asType(o.dataType().nameUpper())) - .toList(); - - List drivers = new ArrayList<>(); - List collectedPages = Collections.synchronizedList(new ArrayList<>()); - - // replace fragment inside the coordinator plan - try { - LocalExecutionPlan coordinatorNodeExecutionPlan = executionPlanner.plan(new OutputExec(coordinatorPlan, collectedPages::add)); - drivers.addAll(coordinatorNodeExecutionPlan.createDrivers(sessionId)); - if (dataNodePlan != null) { - var searchStats = new DisabledSearchStats(); - var logicalTestOptimizer = new LocalLogicalPlanOptimizer(new LocalLogicalOptimizerContext(configuration, searchStats)); - var physicalTestOptimizer = new TestLocalPhysicalPlanOptimizer( - new LocalPhysicalOptimizerContext(configuration, searchStats) - ); - - var csvDataNodePhysicalPlan = PlannerUtils.localPlan(dataNodePlan, logicalTestOptimizer, physicalTestOptimizer); - exchangeSource.addRemoteSink(exchangeSink::fetchPageAsync, randomIntBetween(1, 3)); - LocalExecutionPlan dataNodeExecutionPlan = executionPlanner.plan(csvDataNodePhysicalPlan); - drivers.addAll(dataNodeExecutionPlan.createDrivers(sessionId)); - Randomness.shuffle(drivers); - } - // Execute the driver - DriverRunner runner = new DriverRunner(threadPool.getThreadContext()) { - @Override - protected void start(Driver driver, ActionListener driverListener) { - Driver.start(threadPool.getThreadContext(), executor, driver, between(1, 1000), driverListener); - } - }; - PlainActionFuture future = new PlainActionFuture<>(); - runner.runToCompletion(drivers, ActionListener.releaseAfter(future, () -> Releasables.close(drivers)).map(ignore -> { - var responseHeaders = threadPool.getThreadContext().getResponseHeaders(); - return new ActualResults(columnNames, columnTypes, dataTypes, collectedPages, responseHeaders); - })); - return future.actionGet(TimeValue.timeValueSeconds(30)); - } finally { - Releasables.close(() -> Releasables.close(drivers)); - } + return listener.get(); } private Settings randomNodeSettings() { @@ -487,17 +439,15 @@ private Throwable reworkException(Throwable th) { } // Asserts that the serialization and deserialization of the plan creates an equivalent plan. - private void opportunisticallyAssertPlanSerialization(PhysicalPlan... plans) { - for (var plan : plans) { - var tmp = plan; - do { - if (tmp instanceof LocalSourceExec) { - return; // skip plans with localSourceExec - } - } while (tmp.children().isEmpty() == false && (tmp = tmp.children().get(0)) != null); + private void opportunisticallyAssertPlanSerialization(PhysicalPlan plan) { + var tmp = plan; + do { + if (tmp instanceof LocalSourceExec) { + return; // skip plans with localSourceExec + } + } while (tmp.children().isEmpty() == false && (tmp = tmp.children().get(0)) != null); - SerializationTestUtils.assertSerialization(plan, configuration); - } + SerializationTestUtils.assertSerialization(plan, configuration); } private void assertWarnings(List warnings) { @@ -509,6 +459,84 @@ private void assertWarnings(List warnings) { normalized.add(normW); } } - EsqlTestUtils.assertWarnings(normalized, testCase.expectedWarnings(true), testCase.expectedWarningsRegex()); + EsqlTestUtils.assertWarnings(normalized, testCase.expectedWarnings(), testCase.expectedWarningsRegex()); + } + + BiConsumer> runPhase( + BigArrays bigArrays, + TestPhysicalOperationProviders physicalOperationProviders + ) { + return (physicalPlan, listener) -> runPhase(bigArrays, physicalOperationProviders, physicalPlan, listener); + } + + void runPhase( + BigArrays bigArrays, + TestPhysicalOperationProviders physicalOperationProviders, + PhysicalPlan physicalPlan, + ActionListener listener + ) { + // Keep in sync with ComputeService#execute + opportunisticallyAssertPlanSerialization(physicalPlan); + Tuple coordinatorAndDataNodePlan = PlannerUtils.breakPlanBetweenCoordinatorAndDataNode( + physicalPlan, + configuration + ); + PhysicalPlan coordinatorPlan = coordinatorAndDataNodePlan.v1(); + PhysicalPlan dataNodePlan = coordinatorAndDataNodePlan.v2(); + + if (LOGGER.isTraceEnabled()) { + LOGGER.trace("Coordinator plan\n" + coordinatorPlan); + LOGGER.trace("DataNode plan\n" + dataNodePlan); + } + + BlockFactory blockFactory = new BlockFactory( + bigArrays.breakerService().getBreaker(CircuitBreaker.REQUEST), + bigArrays, + ByteSizeValue.ofBytes(randomLongBetween(1, BlockFactory.DEFAULT_MAX_BLOCK_PRIMITIVE_ARRAY_SIZE.getBytes() * 2)) + ); + ExchangeSourceHandler exchangeSource = new ExchangeSourceHandler(between(1, 64), executor); + ExchangeSinkHandler exchangeSink = new ExchangeSinkHandler(blockFactory, between(1, 64), threadPool::relativeTimeInMillis); + + LocalExecutionPlanner executionPlanner = new LocalExecutionPlanner( + getTestName(), + "", + new CancellableTask(1, "transport", "esql", null, TaskId.EMPTY_TASK_ID, Map.of()), + bigArrays, + blockFactory, + randomNodeSettings(), + configuration, + exchangeSource, + exchangeSink, + Mockito.mock(EnrichLookupService.class), + physicalOperationProviders + ); + + List collectedPages = Collections.synchronizedList(new ArrayList<>()); + + // replace fragment inside the coordinator plan + List drivers = new ArrayList<>(); + LocalExecutionPlan coordinatorNodeExecutionPlan = executionPlanner.plan(new OutputExec(coordinatorPlan, collectedPages::add)); + drivers.addAll(coordinatorNodeExecutionPlan.createDrivers(getTestName())); + if (dataNodePlan != null) { + var searchStats = new DisabledSearchStats(); + var logicalTestOptimizer = new LocalLogicalPlanOptimizer(new LocalLogicalOptimizerContext(configuration, searchStats)); + var physicalTestOptimizer = new TestLocalPhysicalPlanOptimizer(new LocalPhysicalOptimizerContext(configuration, searchStats)); + + var csvDataNodePhysicalPlan = PlannerUtils.localPlan(dataNodePlan, logicalTestOptimizer, physicalTestOptimizer); + exchangeSource.addRemoteSink(exchangeSink::fetchPageAsync, randomIntBetween(1, 3)); + LocalExecutionPlan dataNodeExecutionPlan = executionPlanner.plan(csvDataNodePhysicalPlan); + + drivers.addAll(dataNodeExecutionPlan.createDrivers(getTestName())); + Randomness.shuffle(drivers); + } + // Execute the drivers + DriverRunner runner = new DriverRunner(threadPool.getThreadContext()) { + @Override + protected void start(Driver driver, ActionListener driverListener) { + Driver.start(threadPool.getThreadContext(), executor, driver, between(1, 1000), driverListener); + } + }; + listener = ActionListener.releaseAfter(listener, () -> Releasables.close(drivers)); + runner.runToCompletion(drivers, listener.map(ignore -> new Result(physicalPlan.output(), collectedPages, List.of()))); } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/SerializationTestUtils.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/SerializationTestUtils.java index fd811a2f2e217..8c5a5a4b3ba3b 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/SerializationTestUtils.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/SerializationTestUtils.java @@ -26,7 +26,6 @@ import org.elasticsearch.xpack.esql.core.expression.Attribute; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.NamedExpression; -import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.core.type.EsField; import org.elasticsearch.xpack.esql.expression.function.UnsupportedAttribute; import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction; @@ -34,6 +33,7 @@ import org.elasticsearch.xpack.esql.io.stream.PlanNameRegistry; import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; import org.elasticsearch.xpack.esql.io.stream.PlanStreamOutput; +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan; import org.elasticsearch.xpack.esql.querydsl.query.SingleValueQuery; import org.elasticsearch.xpack.esql.session.EsqlConfiguration; diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTestUtils.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTestUtils.java index c78baabcd03a7..7c5dc73fb62af 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTestUtils.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTestUtils.java @@ -11,11 +11,11 @@ import org.elasticsearch.xpack.esql.EsqlTestUtils; import org.elasticsearch.xpack.esql.core.index.EsIndex; import org.elasticsearch.xpack.esql.core.index.IndexResolution; -import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.enrich.ResolvedEnrichPolicy; import org.elasticsearch.xpack.esql.expression.function.EsqlFunctionRegistry; import org.elasticsearch.xpack.esql.parser.EsqlParser; import org.elasticsearch.xpack.esql.plan.logical.Enrich; +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.session.EsqlConfiguration; import java.util.ArrayList; diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java index 1f2ec0c236ecf..d6cd4a5e84d49 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/AnalyzerTests.java @@ -32,10 +32,6 @@ import org.elasticsearch.xpack.esql.core.index.EsIndex; import org.elasticsearch.xpack.esql.core.index.IndexResolution; import org.elasticsearch.xpack.esql.core.plan.TableIdentifier; -import org.elasticsearch.xpack.esql.core.plan.logical.Filter; -import org.elasticsearch.xpack.esql.core.plan.logical.Limit; -import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan; -import org.elasticsearch.xpack.esql.core.plan.logical.OrderBy; import org.elasticsearch.xpack.esql.core.type.DataType; import org.elasticsearch.xpack.esql.core.type.TypesTests; import org.elasticsearch.xpack.esql.enrich.ResolvedEnrichPolicy; @@ -49,7 +45,11 @@ import org.elasticsearch.xpack.esql.plan.logical.EsRelation; import org.elasticsearch.xpack.esql.plan.logical.EsqlUnresolvedRelation; import org.elasticsearch.xpack.esql.plan.logical.Eval; +import org.elasticsearch.xpack.esql.plan.logical.Filter; +import org.elasticsearch.xpack.esql.plan.logical.Limit; +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.plan.logical.Lookup; +import org.elasticsearch.xpack.esql.plan.logical.OrderBy; import org.elasticsearch.xpack.esql.plan.logical.Row; import org.elasticsearch.xpack.esql.plan.logical.local.EsqlProject; import org.elasticsearch.xpack.esql.plugin.EsqlPlugin; @@ -1832,13 +1832,13 @@ public void testUnsupportedTypesInStats() { found value [x] type [unsigned_long] line 2:20: argument of [count_distinct(x)] must be [any exact type except unsigned_long or counter types],\ found value [x] type [unsigned_long] - line 2:39: argument of [max(x)] must be [datetime or numeric except unsigned_long or counter types],\ + line 2:39: argument of [max(x)] must be [boolean, datetime or numeric except unsigned_long or counter types],\ found value [max(x)] type [unsigned_long] line 2:47: argument of [median(x)] must be [numeric except unsigned_long or counter types],\ found value [x] type [unsigned_long] line 2:58: argument of [median_absolute_deviation(x)] must be [numeric except unsigned_long or counter types],\ found value [x] type [unsigned_long] - line 2:88: argument of [min(x)] must be [datetime or numeric except unsigned_long or counter types],\ + line 2:88: argument of [min(x)] must be [boolean, datetime or numeric except unsigned_long or counter types],\ found value [min(x)] type [unsigned_long] line 2:96: first argument of [percentile(x, 10)] must be [numeric except unsigned_long],\ found value [x] type [unsigned_long] @@ -1852,13 +1852,13 @@ public void testUnsupportedTypesInStats() { Found 7 problems line 2:10: argument of [avg(x)] must be [numeric except unsigned_long or counter types],\ found value [x] type [version] - line 2:18: argument of [max(x)] must be [datetime or numeric except unsigned_long or counter types],\ + line 2:18: argument of [max(x)] must be [boolean, datetime or numeric except unsigned_long or counter types],\ found value [max(x)] type [version] line 2:26: argument of [median(x)] must be [numeric except unsigned_long or counter types],\ found value [x] type [version] line 2:37: argument of [median_absolute_deviation(x)] must be [numeric except unsigned_long or counter types],\ found value [x] type [version] - line 2:67: argument of [min(x)] must be [datetime or numeric except unsigned_long or counter types],\ + line 2:67: argument of [min(x)] must be [boolean, datetime or numeric except unsigned_long or counter types],\ found value [min(x)] type [version] line 2:75: first argument of [percentile(x, 10)] must be [numeric except unsigned_long], found value [x] type [version] line 2:94: argument of [sum(x)] must be [numeric except unsigned_long or counter types], found value [x] type [version]"""); diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/ParsingTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/ParsingTests.java index 8dfd8eee58c24..0231dc1f4a82b 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/ParsingTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/ParsingTests.java @@ -15,12 +15,12 @@ import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.index.EsIndex; import org.elasticsearch.xpack.esql.core.index.IndexResolution; -import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.core.type.DataType; import org.elasticsearch.xpack.esql.core.type.TypesTests; import org.elasticsearch.xpack.esql.expression.function.EsqlFunctionRegistry; import org.elasticsearch.xpack.esql.expression.function.FunctionDefinition; import org.elasticsearch.xpack.esql.parser.EsqlParser; +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.plan.logical.Row; import org.elasticsearch.xpack.esql.type.EsqlDataTypeConverter; diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java index ad08130c5b0d9..00d12240e67e5 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/analysis/VerifierTests.java @@ -493,7 +493,8 @@ public void testAggregateOnCounter() { assertThat( error("FROM tests | STATS min(network.bytes_in)", tsdb), equalTo( - "1:20: argument of [min(network.bytes_in)] must be [datetime or numeric except unsigned_long or counter types]," + "1:20: argument of [min(network.bytes_in)] must be" + + " [boolean, datetime or numeric except unsigned_long or counter types]," + " found value [min(network.bytes_in)] type [counter_long]" ) ); @@ -501,7 +502,8 @@ public void testAggregateOnCounter() { assertThat( error("FROM tests | STATS max(network.bytes_in)", tsdb), equalTo( - "1:20: argument of [max(network.bytes_in)] must be [datetime or numeric except unsigned_long or counter types]," + "1:20: argument of [max(network.bytes_in)] must be" + + " [boolean, datetime or numeric except unsigned_long or counter types]," + " found value [max(network.bytes_in)] type [counter_long]" ) ); diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractAggregationTestCase.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractAggregationTestCase.java index 05a6cec51284f..792c6b5139796 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractAggregationTestCase.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractAggregationTestCase.java @@ -15,6 +15,8 @@ import org.elasticsearch.compute.data.Page; import org.elasticsearch.core.Releasables; import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FieldAttribute; +import org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.type.DataType; import org.elasticsearch.xpack.esql.core.util.NumericUtils; import org.elasticsearch.xpack.esql.expression.SurrogateExpression; @@ -23,7 +25,10 @@ import org.elasticsearch.xpack.esql.planner.PlannerUtils; import org.elasticsearch.xpack.esql.planner.ToAggregator; +import java.util.ArrayList; +import java.util.HashSet; import java.util.List; +import java.util.Set; import java.util.function.Consumer; import java.util.stream.Collectors; import java.util.stream.IntStream; @@ -31,8 +36,11 @@ import static org.elasticsearch.compute.data.BlockUtils.toJavaObject; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.lessThan; import static org.hamcrest.Matchers.not; +import static org.hamcrest.Matchers.nullValue; +import static org.hamcrest.Matchers.oneOf; /** * Base class for aggregation tests. @@ -47,7 +55,43 @@ public abstract class AbstractAggregationTestCase extends AbstractFunctionTestCa */ protected static Iterable parameterSuppliersFromTypedDataWithDefaultChecks(List suppliers) { // TODO: Add case with no input expecting null - return parameterSuppliersFromTypedData(randomizeBytesRefsOffset(suppliers)); + return parameterSuppliersFromTypedData(withNoRowsExpectingNull(randomizeBytesRefsOffset(suppliers))); + } + + /** + * Adds a test case with no rows, expecting null, to the list of suppliers. + */ + protected static List withNoRowsExpectingNull(List suppliers) { + List newSuppliers = new ArrayList<>(suppliers); + Set> uniqueSignatures = new HashSet<>(); + + for (TestCaseSupplier original : suppliers) { + if (uniqueSignatures.add(original.types())) { + newSuppliers.add(new TestCaseSupplier(original.name() + " with no rows", original.types(), () -> { + var testCase = original.get(); + + if (testCase.getData().stream().noneMatch(TestCaseSupplier.TypedData::isMultiRow)) { + // Fail if no multi-row data, at least until a real case is found + fail("No multi-row data found in test case: " + testCase); + } + + var newData = testCase.getData().stream().map(td -> td.isMultiRow() ? td.withData(List.of()) : td).toList(); + + return new TestCaseSupplier.TestCase( + newData, + testCase.evaluatorToString(), + testCase.expectedType(), + nullValue(), + null, + testCase.getExpectedTypeError(), + null, + null + ); + })); + } + } + + return newSuppliers; } public void testAggregate() { @@ -56,6 +100,12 @@ public void testAggregate() { resolveExpression(expression, this::aggregateSingleMode, this::evaluate); } + public void testAggregateIntermediate() { + Expression expression = randomBoolean() ? buildDeepCopyOfFieldExpression(testCase) : buildFieldExpression(testCase); + + resolveExpression(expression, this::aggregateWithIntermediates, this::evaluate); + } + public void testFold() { Expression expression = buildLiteralExpression(testCase); @@ -80,17 +130,82 @@ public void testFold() { }); } - private void aggregateSingleMode(AggregatorFunctionSupplier aggregatorFunctionSupplier) { + private void aggregateSingleMode(Expression expression) { Object result; - try (var aggregator = new Aggregator(aggregatorFunctionSupplier.aggregator(driverContext()), AggregatorMode.SINGLE)) { - Page inputPage = rows(testCase.getMultiRowDataValues()); + try (var aggregator = aggregator(expression, initialInputChannels(), AggregatorMode.SINGLE)) { + for (Page inputPage : rows(testCase.getMultiRowFields())) { + try { + aggregator.processPage(inputPage); + } finally { + inputPage.releaseBlocks(); + } + } + + result = extractResultFromAggregator(aggregator, PlannerUtils.toElementType(testCase.expectedType())); + } + + assertThat(result, not(equalTo(Double.NaN))); + assert testCase.getMatcher().matches(Double.POSITIVE_INFINITY) == false; + assertThat(result, not(equalTo(Double.POSITIVE_INFINITY))); + assert testCase.getMatcher().matches(Double.NEGATIVE_INFINITY) == false; + assertThat(result, not(equalTo(Double.NEGATIVE_INFINITY))); + assertThat(result, testCase.getMatcher()); + if (testCase.getExpectedWarnings() != null) { + assertWarnings(testCase.getExpectedWarnings()); + } + } + + private void aggregateWithIntermediates(Expression expression) { + int intermediateBlockOffset = randomIntBetween(0, 10); + Block[] intermediateBlocks; + int intermediateStates; + + // Input rows to intermediate states + try (var aggregator = aggregator(expression, initialInputChannels(), AggregatorMode.INITIAL)) { + intermediateStates = aggregator.evaluateBlockCount(); + + int intermediateBlockExtraSize = randomIntBetween(0, 10); + intermediateBlocks = new Block[intermediateBlockOffset + intermediateStates + intermediateBlockExtraSize]; + + for (Page inputPage : rows(testCase.getMultiRowFields())) { + try { + aggregator.processPage(inputPage); + } finally { + inputPage.releaseBlocks(); + } + } + + aggregator.evaluate(intermediateBlocks, intermediateBlockOffset, driverContext()); + + int positionCount = intermediateBlocks[intermediateBlockOffset].getPositionCount(); + + // Fill offset and extra blocks with nulls + for (int i = 0; i < intermediateBlockOffset; i++) { + intermediateBlocks[i] = driverContext().blockFactory().newConstantNullBlock(positionCount); + } + for (int i = intermediateBlockOffset + intermediateStates; i < intermediateBlocks.length; i++) { + intermediateBlocks[i] = driverContext().blockFactory().newConstantNullBlock(positionCount); + } + } + + Object result; + // Intermediate states to final result + try ( + var aggregator = aggregator( + expression, + intermediaryInputChannels(intermediateStates, intermediateBlockOffset), + AggregatorMode.FINAL + ) + ) { + Page inputPage = new Page(intermediateBlocks); try { - aggregator.processPage(inputPage); + if (inputPage.getPositionCount() > 0) { + aggregator.processPage(inputPage); + } } finally { inputPage.releaseBlocks(); } - // ElementType from DataType result = extractResultFromAggregator(aggregator, PlannerUtils.toElementType(testCase.expectedType())); } @@ -124,11 +239,7 @@ private void evaluate(Expression evaluableExpression) { } } - private void resolveExpression( - Expression expression, - Consumer onAggregator, - Consumer onEvaluableExpression - ) { + private void resolveExpression(Expression expression, Consumer onAggregator, Consumer onEvaluableExpression) { logger.info( "Test Values: " + testCase.getData().stream().map(TestCaseSupplier.TypedData::toString).collect(Collectors.joining(",")) ); @@ -146,6 +257,13 @@ private void resolveExpression( expression = new FoldNull().rule(expression); assertThat(expression.dataType(), equalTo(testCase.expectedType())); + assumeTrue( + "Surrogate expression with non-trivial children cannot be evaluated", + expression.children() + .stream() + .allMatch(child -> child instanceof FieldAttribute || child instanceof DeepCopy || child instanceof Literal) + ); + if (expression instanceof AggregateFunction == false) { onEvaluableExpression.accept(expression); return; @@ -154,8 +272,7 @@ private void resolveExpression( assertThat(expression, instanceOf(ToAggregator.class)); logger.info("Result type: " + expression.dataType()); - var inputChannels = inputChannels(); - onAggregator.accept(((ToAggregator) expression).supplier(inputChannels)); + onAggregator.accept(expression); } private Object extractResultFromAggregator(Aggregator aggregator, ElementType expectedElementType) { @@ -167,7 +284,8 @@ private Object extractResultFromAggregator(Aggregator aggregator, ElementType ex var block = blocks[resultBlockIndex]; - assertThat(block.elementType(), equalTo(expectedElementType)); + // For null blocks, the element type is NULL, so if the provided matcher matches, the type works too + assertThat(block.elementType(), is(oneOf(expectedElementType, ElementType.NULL))); return toJavaObject(blocks[resultBlockIndex], 0); } finally { @@ -175,10 +293,14 @@ private Object extractResultFromAggregator(Aggregator aggregator, ElementType ex } } - private List inputChannels() { + private List initialInputChannels() { // TODO: Randomize channels // TODO: If surrogated, channels may change - return IntStream.range(0, testCase.getMultiRowDataValues().size()).boxed().toList(); + return IntStream.range(0, testCase.getMultiRowFields().size()).boxed().toList(); + } + + private List intermediaryInputChannels(int intermediaryStates, int offset) { + return IntStream.range(offset, offset + intermediaryStates).boxed().toList(); } /** @@ -210,4 +332,10 @@ private Expression resolveSurrogates(Expression expression) { return expression; } + + private Aggregator aggregator(Expression expression, List inputChannels, AggregatorMode mode) { + AggregatorFunctionSupplier aggregatorFunctionSupplier = ((ToAggregator) expression).supplier(inputChannels); + + return new Aggregator(aggregatorFunctionSupplier.aggregator(driverContext()), mode); + } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractFunctionTestCase.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractFunctionTestCase.java index dc650e3fcd965..80dc2e434ab0f 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractFunctionTestCase.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractFunctionTestCase.java @@ -49,6 +49,7 @@ import org.elasticsearch.xpack.esql.optimizer.FoldNull; import org.elasticsearch.xpack.esql.parser.ExpressionBuilder; import org.elasticsearch.xpack.esql.planner.Layout; +import org.elasticsearch.xpack.esql.planner.PlannerUtils; import org.elasticsearch.xpack.versionfield.Version; import org.junit.After; import org.junit.AfterClass; @@ -214,26 +215,55 @@ protected final Page row(List values) { } /** - * Creates a page based on a list of lists, where each list represents a column. + * Creates a list of pages based on a list of multi-row fields. */ - protected final Page rows(List> values) { - if (values.isEmpty()) { - return new Page(0, BlockUtils.NO_BLOCKS); + protected final List rows(List multirowFields) { + if (multirowFields.isEmpty()) { + return List.of(); } - var rowsCount = values.get(0).size(); + var rowsCount = multirowFields.get(0).multiRowData().size(); - values.stream().skip(1).forEach(l -> assertThat("All multi-row fields must have the same number of rows", l, hasSize(rowsCount))); + multirowFields.stream() + .skip(1) + .forEach( + field -> assertThat("All multi-row fields must have the same number of rows", field.multiRowData(), hasSize(rowsCount)) + ); - var rows = new ArrayList>(); - for (int i = 0; i < rowsCount; i++) { - final int index = i; - rows.add(values.stream().map(l -> l.get(index)).toList()); - } + List pages = new ArrayList<>(); + + int pageSize = randomIntBetween(1, 100); + for (int initialRow = 0; initialRow < rowsCount;) { + if (pageSize > rowsCount - initialRow) { + pageSize = rowsCount - initialRow; + } + + var blocks = new Block[multirowFields.size()]; + + for (int i = 0; i < multirowFields.size(); i++) { + var field = multirowFields.get(i); + try ( + var wrapper = BlockUtils.wrapperFor( + TestBlockFactory.getNonBreakingInstance(), + PlannerUtils.toElementType(field.type()), + pageSize + ) + ) { + var multiRowData = field.multiRowData(); + for (int row = initialRow; row < initialRow + pageSize; row++) { + wrapper.accept(multiRowData.get(row)); + } + + blocks[i] = wrapper.builder().build(); + } + } - var blocks = BlockUtils.fromList(TestBlockFactory.getNonBreakingInstance(), rows); + pages.add(new Page(pageSize, blocks)); + initialRow += pageSize; + pageSize = randomIntBetween(1, 100); + } - return new Page(rowsCount, blocks); + return pages; } /** diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/MultiRowTestCaseSupplier.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/MultiRowTestCaseSupplier.java new file mode 100644 index 0000000000000..68f5414302c9d --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/MultiRowTestCaseSupplier.java @@ -0,0 +1,303 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function; + +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.esql.core.type.DataType; + +import java.util.ArrayList; +import java.util.List; + +import static org.elasticsearch.test.ESTestCase.randomBoolean; +import static org.elasticsearch.test.ESTestCase.randomList; +import static org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier.TypedDataSupplier; + +/** + * Extension of {@link TestCaseSupplier} that provided multi-row test cases. + */ +public final class MultiRowTestCaseSupplier { + + private MultiRowTestCaseSupplier() {} + + public static List intCases(int minRows, int maxRows, int min, int max, boolean includeZero) { + List cases = new ArrayList<>(); + + if (0 <= max && 0 >= min && includeZero) { + cases.add(new TypedDataSupplier("<0 ints>", () -> randomList(minRows, maxRows, () -> 0), DataType.INTEGER, false, true)); + } + + if (max != 0) { + cases.add( + new TypedDataSupplier("<" + max + " ints>", () -> randomList(minRows, maxRows, () -> max), DataType.INTEGER, false, true) + ); + } + + if (min != 0 && min != max) { + cases.add( + new TypedDataSupplier("<" + min + " ints>", () -> randomList(minRows, maxRows, () -> min), DataType.INTEGER, false, true) + ); + } + + int lower = Math.max(min, 1); + int upper = Math.min(max, Integer.MAX_VALUE); + if (lower < upper) { + cases.add( + new TypedDataSupplier( + "", + () -> randomList(minRows, maxRows, () -> ESTestCase.randomIntBetween(lower, upper)), + DataType.INTEGER, + false, + true + ) + ); + } + + int lower1 = Math.max(min, Integer.MIN_VALUE); + int upper1 = Math.min(max, -1); + if (lower1 < upper1) { + cases.add( + new TypedDataSupplier( + "", + () -> randomList(minRows, maxRows, () -> ESTestCase.randomIntBetween(lower1, upper1)), + DataType.INTEGER, + false, + true + ) + ); + } + + if (min < 0 && max > 0) { + cases.add(new TypedDataSupplier("", () -> randomList(minRows, maxRows, () -> { + if (includeZero) { + return ESTestCase.randomIntBetween(min, max); + } + return randomBoolean() ? ESTestCase.randomIntBetween(min, -1) : ESTestCase.randomIntBetween(1, max); + }), DataType.INTEGER, false, true)); + } + + return cases; + } + + public static List longCases(int minRows, int maxRows, long min, long max, boolean includeZero) { + List cases = new ArrayList<>(); + + if (0 <= max && 0 >= min && includeZero) { + cases.add(new TypedDataSupplier("<0 longs>", () -> randomList(minRows, maxRows, () -> 0L), DataType.LONG, false, true)); + } + + if (max != 0) { + cases.add( + new TypedDataSupplier("<" + max + " longs>", () -> randomList(minRows, maxRows, () -> max), DataType.LONG, false, true) + ); + } + + if (min != 0 && min != max) { + cases.add( + new TypedDataSupplier("<" + min + " longs>", () -> randomList(minRows, maxRows, () -> min), DataType.LONG, false, true) + ); + } + + long lower = Math.max(min, 1); + long upper = Math.min(max, Long.MAX_VALUE); + if (lower < upper) { + cases.add( + new TypedDataSupplier( + "", + () -> randomList(minRows, maxRows, () -> ESTestCase.randomLongBetween(lower, upper)), + DataType.LONG, + false, + true + ) + ); + } + + long lower1 = Math.max(min, Long.MIN_VALUE); + long upper1 = Math.min(max, -1); + if (lower1 < upper1) { + cases.add( + new TypedDataSupplier( + "", + () -> randomList(minRows, maxRows, () -> ESTestCase.randomLongBetween(lower1, upper1)), + DataType.LONG, + false, + true + ) + ); + } + + if (min < 0 && max > 0) { + cases.add(new TypedDataSupplier("", () -> randomList(minRows, maxRows, () -> { + if (includeZero) { + return ESTestCase.randomLongBetween(min, max); + } + return randomBoolean() ? ESTestCase.randomLongBetween(min, -1) : ESTestCase.randomLongBetween(1, max); + }), DataType.LONG, false, true)); + } + + return cases; + } + + public static List doubleCases(int minRows, int maxRows, double min, double max, boolean includeZero) { + List cases = new ArrayList<>(); + + if (0d <= max && 0d >= min && includeZero) { + cases.add(new TypedDataSupplier("<0 doubles>", () -> randomList(minRows, maxRows, () -> 0d), DataType.DOUBLE, false, true)); + cases.add(new TypedDataSupplier("<-0 doubles>", () -> randomList(minRows, maxRows, () -> -0d), DataType.DOUBLE, false, true)); + } + + if (max != 0d) { + cases.add( + new TypedDataSupplier("<" + max + " doubles>", () -> randomList(minRows, maxRows, () -> max), DataType.DOUBLE, false, true) + ); + } + + if (min != 0d && min != max) { + cases.add( + new TypedDataSupplier("<" + min + " doubles>", () -> randomList(minRows, maxRows, () -> min), DataType.DOUBLE, false, true) + ); + } + + double lower1 = Math.max(min, 0d); + double upper1 = Math.min(max, 1d); + if (lower1 < upper1) { + cases.add( + new TypedDataSupplier( + "", + () -> randomList(minRows, maxRows, () -> ESTestCase.randomDoubleBetween(lower1, upper1, true)), + DataType.DOUBLE, + false, + true + ) + ); + } + + double lower2 = Math.max(min, -1d); + double upper2 = Math.min(max, 0d); + if (lower2 < upper2) { + cases.add( + new TypedDataSupplier( + "", + () -> randomList(minRows, maxRows, () -> ESTestCase.randomDoubleBetween(lower2, upper2, true)), + DataType.DOUBLE, + false, + true + ) + ); + } + + double lower3 = Math.max(min, 1d); + double upper3 = Math.min(max, Double.MAX_VALUE); + if (lower3 < upper3) { + cases.add( + new TypedDataSupplier( + "", + () -> randomList(minRows, maxRows, () -> ESTestCase.randomDoubleBetween(lower3, upper3, true)), + DataType.DOUBLE, + false, + true + ) + ); + } + + double lower4 = Math.max(min, -Double.MAX_VALUE); + double upper4 = Math.min(max, -1d); + if (lower4 < upper4) { + cases.add( + new TypedDataSupplier( + "", + () -> randomList(minRows, maxRows, () -> ESTestCase.randomDoubleBetween(lower4, upper4, true)), + DataType.DOUBLE, + false, + true + ) + ); + } + + if (min < 0 && max > 0) { + cases.add(new TypedDataSupplier("", () -> randomList(minRows, maxRows, () -> { + if (includeZero) { + return ESTestCase.randomDoubleBetween(min, max, true); + } + return randomBoolean() ? ESTestCase.randomDoubleBetween(min, -1, true) : ESTestCase.randomDoubleBetween(1, max, true); + }), DataType.DOUBLE, false, true)); + } + + return cases; + } + + public static List dateCases(int minRows, int maxRows) { + List cases = new ArrayList<>(); + + cases.add( + new TypedDataSupplier( + "<1970-01-01T00:00:00Z dates>", + () -> randomList(minRows, maxRows, () -> 0L), + DataType.DATETIME, + false, + true + ) + ); + + cases.add( + new TypedDataSupplier( + "", + // 1970-01-01T00:00:00Z - 2286-11-20T17:46:40Z + () -> randomList(minRows, maxRows, () -> ESTestCase.randomLongBetween(0, 10 * (long) 10e11)), + DataType.DATETIME, + false, + true + ) + ); + + cases.add( + new TypedDataSupplier( + "", + // 2286-11-20T17:46:40Z - +292278994-08-17T07:12:55.807Z + () -> randomList(minRows, maxRows, () -> ESTestCase.randomLongBetween(10 * (long) 10e11, Long.MAX_VALUE)), + DataType.DATETIME, + false, + true + ) + ); + + cases.add( + new TypedDataSupplier( + "", + // very close to +292278994-08-17T07:12:55.807Z, the maximum supported millis since epoch + () -> randomList(minRows, maxRows, () -> ESTestCase.randomLongBetween(Long.MAX_VALUE / 100 * 99, Long.MAX_VALUE)), + DataType.DATETIME, + false, + true + ) + ); + + return cases; + } + + public static List booleanCases(int minRows, int maxRows) { + List cases = new ArrayList<>(); + + cases.add(new TypedDataSupplier("", () -> randomList(minRows, maxRows, () -> true), DataType.BOOLEAN, false, true)); + + cases.add( + new TypedDataSupplier("", () -> randomList(minRows, maxRows, () -> false), DataType.BOOLEAN, false, true) + ); + + cases.add( + new TypedDataSupplier( + "", + () -> randomList(minRows, maxRows, ESTestCase::randomBoolean), + DataType.BOOLEAN, + false, + true + ) + ); + + return cases; + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/TestCaseSupplier.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/TestCaseSupplier.java index 9095f5da63bf3..6ece7151ccd7a 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/TestCaseSupplier.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/TestCaseSupplier.java @@ -820,6 +820,12 @@ public static void unary( unary(suppliers, expectedEvaluatorToString, valueSuppliers, expectedOutputType, expected, unused -> warnings); } + /** + * Generate cases for {@link DataType#INTEGER}. + *

+ * For multi-row parameters, see {@link MultiRowTestCaseSupplier#intCases}. + *

+ */ public static List intCases(int min, int max, boolean includeZero) { List cases = new ArrayList<>(); if (0 <= max && 0 >= min && includeZero) { @@ -844,6 +850,12 @@ public static List intCases(int min, int max, boolean include return cases; } + /** + * Generate cases for {@link DataType#LONG}. + *

+ * For multi-row parameters, see {@link MultiRowTestCaseSupplier#longCases}. + *

+ */ public static List longCases(long min, long max, boolean includeZero) { List cases = new ArrayList<>(); if (0L <= max && 0L >= min && includeZero) { @@ -909,6 +921,12 @@ public static List ulongCases(BigInteger min, BigInteger max, return cases; } + /** + * Generate cases for {@link DataType#DOUBLE}. + *

+ * For multi-row parameters, see {@link MultiRowTestCaseSupplier#doubleCases}. + *

+ */ public static List doubleCases(double min, double max, boolean includeZero) { List cases = new ArrayList<>(); @@ -980,6 +998,12 @@ public static List booleanCases() { ); } + /** + * Generate cases for {@link DataType#DATETIME}. + *

+ * For multi-row parameters, see {@link MultiRowTestCaseSupplier#dateCases}. + *

+ */ public static List dateCases() { return List.of( new TypedDataSupplier("<1970-01-01T00:00:00Z>", () -> 0L, DataType.DATETIME), @@ -1301,8 +1325,8 @@ public List getDataValues() { return data.stream().filter(d -> d.forceLiteral == false).map(TypedData::data).collect(Collectors.toList()); } - public List> getMultiRowDataValues() { - return data.stream().filter(TypedData::isMultiRow).map(TypedData::multiRowData).collect(Collectors.toList()); + public List getMultiRowFields() { + return data.stream().filter(TypedData::isMultiRow).collect(Collectors.toList()); } public boolean canGetDataAsLiterals() { diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AvgTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AvgTests.java new file mode 100644 index 0000000000000..f456bd409059a --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/AvgTests.java @@ -0,0 +1,95 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.aggregate; + +import com.carrotsearch.randomizedtesting.annotations.Name; +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; + +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.expression.function.AbstractAggregationTestCase; +import org.elasticsearch.xpack.esql.expression.function.MultiRowTestCaseSupplier; +import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier; + +import java.util.ArrayList; +import java.util.List; +import java.util.function.Supplier; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import static org.hamcrest.Matchers.equalTo; + +public class AvgTests extends AbstractAggregationTestCase { + public AvgTests(@Name("TestCase") Supplier testCaseSupplier) { + this.testCase = testCaseSupplier.get(); + } + + @ParametersFactory + public static Iterable parameters() { + var suppliers = new ArrayList(); + + Stream.of( + MultiRowTestCaseSupplier.intCases(1, 1000, Integer.MIN_VALUE, Integer.MAX_VALUE, true), + MultiRowTestCaseSupplier.longCases(1, 1000, Long.MIN_VALUE, Long.MAX_VALUE, true), + MultiRowTestCaseSupplier.doubleCases(1, 1000, -Double.MAX_VALUE, Double.MAX_VALUE, true) + ).flatMap(List::stream).map(AvgTests::makeSupplier).collect(Collectors.toCollection(() -> suppliers)); + + suppliers.add( + // Folding + new TestCaseSupplier( + List.of(DataType.INTEGER), + () -> new TestCaseSupplier.TestCase( + List.of(TestCaseSupplier.TypedData.multiRow(List.of(200), DataType.INTEGER, "field")), + "Avg[field=Attribute[channel=0]]", + DataType.DOUBLE, + equalTo(200.) + ) + ) + ); + + return parameterSuppliersFromTypedDataWithDefaultChecks(suppliers); + } + + @Override + protected Expression build(Source source, List args) { + return new Avg(source, args.get(0)); + } + + private static TestCaseSupplier makeSupplier(TestCaseSupplier.TypedDataSupplier fieldSupplier) { + return new TestCaseSupplier(List.of(fieldSupplier.type()), () -> { + var fieldTypedData = fieldSupplier.get(); + + Object expected = switch (fieldTypedData.type().widenSmallNumeric()) { + case INTEGER -> fieldTypedData.multiRowData() + .stream() + .map(v -> (Integer) v) + .collect(Collectors.summarizingInt(Integer::intValue)) + .getAverage(); + case LONG -> fieldTypedData.multiRowData() + .stream() + .map(v -> (Long) v) + .collect(Collectors.summarizingLong(Long::longValue)) + .getAverage(); + case DOUBLE -> fieldTypedData.multiRowData() + .stream() + .map(v -> (Double) v) + .collect(Collectors.summarizingDouble(Double::doubleValue)) + .getAverage(); + default -> throw new IllegalStateException("Unexpected value: " + fieldTypedData.type()); + }; + + return new TestCaseSupplier.TestCase( + List.of(fieldTypedData), + "Avg[field=Attribute[channel=0]]", + DataType.DOUBLE, + equalTo(expected) + ); + }); + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/MaxTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/MaxTests.java new file mode 100644 index 0000000000000..3fddaff226f3e --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/MaxTests.java @@ -0,0 +1,170 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.aggregate; + +import com.carrotsearch.randomizedtesting.annotations.Name; +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; + +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.expression.function.AbstractAggregationTestCase; +import org.elasticsearch.xpack.esql.expression.function.MultiRowTestCaseSupplier; +import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier; + +import java.util.ArrayList; +import java.util.Comparator; +import java.util.List; +import java.util.function.Supplier; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import static org.hamcrest.Matchers.equalTo; + +public class MaxTests extends AbstractAggregationTestCase { + public MaxTests(@Name("TestCase") Supplier testCaseSupplier) { + this.testCase = testCaseSupplier.get(); + } + + @ParametersFactory + public static Iterable parameters() { + var suppliers = new ArrayList(); + + Stream.of( + MultiRowTestCaseSupplier.intCases(1, 1000, Integer.MIN_VALUE, Integer.MAX_VALUE, true), + MultiRowTestCaseSupplier.longCases(1, 1000, Long.MIN_VALUE, Long.MAX_VALUE, true), + MultiRowTestCaseSupplier.doubleCases(1, 1000, -Double.MAX_VALUE, Double.MAX_VALUE, true), + MultiRowTestCaseSupplier.dateCases(1, 1000), + MultiRowTestCaseSupplier.booleanCases(1, 1000) + ).flatMap(List::stream).map(MaxTests::makeSupplier).collect(Collectors.toCollection(() -> suppliers)); + + suppliers.addAll( + List.of( + // Surrogates + new TestCaseSupplier( + List.of(DataType.INTEGER), + () -> new TestCaseSupplier.TestCase( + List.of(TestCaseSupplier.TypedData.multiRow(List.of(5, 8, -2, 0, 200), DataType.INTEGER, "field")), + "Max[field=Attribute[channel=0]]", + DataType.INTEGER, + equalTo(200) + ) + ), + new TestCaseSupplier( + List.of(DataType.LONG), + () -> new TestCaseSupplier.TestCase( + List.of(TestCaseSupplier.TypedData.multiRow(List.of(5L, 8L, -2L, 0L, 200L), DataType.LONG, "field")), + "Max[field=Attribute[channel=0]]", + DataType.LONG, + equalTo(200L) + ) + ), + new TestCaseSupplier( + List.of(DataType.DOUBLE), + () -> new TestCaseSupplier.TestCase( + List.of(TestCaseSupplier.TypedData.multiRow(List.of(5., 8., -2., 0., 200.), DataType.DOUBLE, "field")), + "Max[field=Attribute[channel=0]]", + DataType.DOUBLE, + equalTo(200.) + ) + ), + new TestCaseSupplier( + List.of(DataType.DATETIME), + () -> new TestCaseSupplier.TestCase( + List.of(TestCaseSupplier.TypedData.multiRow(List.of(5L, 8L, 2L, 0L, 200L), DataType.DATETIME, "field")), + "Max[field=Attribute[channel=0]]", + DataType.DATETIME, + equalTo(200L) + ) + ), + new TestCaseSupplier( + List.of(DataType.BOOLEAN), + () -> new TestCaseSupplier.TestCase( + List.of(TestCaseSupplier.TypedData.multiRow(List.of(true, false, false, true), DataType.BOOLEAN, "field")), + "Max[field=Attribute[channel=0]]", + DataType.BOOLEAN, + equalTo(true) + ) + ), + + // Folding + new TestCaseSupplier( + List.of(DataType.INTEGER), + () -> new TestCaseSupplier.TestCase( + List.of(TestCaseSupplier.TypedData.multiRow(List.of(200), DataType.INTEGER, "field")), + "Max[field=Attribute[channel=0]]", + DataType.INTEGER, + equalTo(200) + ) + ), + new TestCaseSupplier( + List.of(DataType.LONG), + () -> new TestCaseSupplier.TestCase( + List.of(TestCaseSupplier.TypedData.multiRow(List.of(200L), DataType.LONG, "field")), + "Max[field=Attribute[channel=0]]", + DataType.LONG, + equalTo(200L) + ) + ), + new TestCaseSupplier( + List.of(DataType.DOUBLE), + () -> new TestCaseSupplier.TestCase( + List.of(TestCaseSupplier.TypedData.multiRow(List.of(200.), DataType.DOUBLE, "field")), + "Max[field=Attribute[channel=0]]", + DataType.DOUBLE, + equalTo(200.) + ) + ), + new TestCaseSupplier( + List.of(DataType.DATETIME), + () -> new TestCaseSupplier.TestCase( + List.of(TestCaseSupplier.TypedData.multiRow(List.of(200L), DataType.DATETIME, "field")), + "Max[field=Attribute[channel=0]]", + DataType.DATETIME, + equalTo(200L) + ) + ), + new TestCaseSupplier( + List.of(DataType.BOOLEAN), + () -> new TestCaseSupplier.TestCase( + List.of(TestCaseSupplier.TypedData.multiRow(List.of(true), DataType.BOOLEAN, "field")), + "Max[field=Attribute[channel=0]]", + DataType.BOOLEAN, + equalTo(true) + ) + ) + ) + ); + + return parameterSuppliersFromTypedDataWithDefaultChecks(suppliers); + } + + @Override + protected Expression build(Source source, List args) { + return new Max(source, args.get(0)); + } + + @SuppressWarnings("unchecked") + private static TestCaseSupplier makeSupplier(TestCaseSupplier.TypedDataSupplier fieldSupplier) { + return new TestCaseSupplier(fieldSupplier.name(), List.of(fieldSupplier.type()), () -> { + var fieldTypedData = fieldSupplier.get(); + var expected = fieldTypedData.multiRowData() + .stream() + .map(v -> (Comparable>) v) + .max(Comparator.naturalOrder()) + .orElse(null); + + return new TestCaseSupplier.TestCase( + List.of(fieldTypedData), + "Max[field=Attribute[channel=0]]", + fieldSupplier.type(), + equalTo(expected) + ); + }); + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/MinTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/MinTests.java new file mode 100644 index 0000000000000..6f59928059bec --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/MinTests.java @@ -0,0 +1,170 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function.aggregate; + +import com.carrotsearch.randomizedtesting.annotations.Name; +import com.carrotsearch.randomizedtesting.annotations.ParametersFactory; + +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.expression.function.AbstractAggregationTestCase; +import org.elasticsearch.xpack.esql.expression.function.MultiRowTestCaseSupplier; +import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier; + +import java.util.ArrayList; +import java.util.Comparator; +import java.util.List; +import java.util.function.Supplier; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import static org.hamcrest.Matchers.equalTo; + +public class MinTests extends AbstractAggregationTestCase { + public MinTests(@Name("TestCase") Supplier testCaseSupplier) { + this.testCase = testCaseSupplier.get(); + } + + @ParametersFactory + public static Iterable parameters() { + var suppliers = new ArrayList(); + + Stream.of( + MultiRowTestCaseSupplier.intCases(1, 1000, Integer.MIN_VALUE, Integer.MAX_VALUE, true), + MultiRowTestCaseSupplier.longCases(1, 1000, Long.MIN_VALUE, Long.MAX_VALUE, true), + MultiRowTestCaseSupplier.doubleCases(1, 1000, -Double.MAX_VALUE, Double.MAX_VALUE, true), + MultiRowTestCaseSupplier.dateCases(1, 1000), + MultiRowTestCaseSupplier.booleanCases(1, 1000) + ).flatMap(List::stream).map(MinTests::makeSupplier).collect(Collectors.toCollection(() -> suppliers)); + + suppliers.addAll( + List.of( + // Surrogates + new TestCaseSupplier( + List.of(DataType.INTEGER), + () -> new TestCaseSupplier.TestCase( + List.of(TestCaseSupplier.TypedData.multiRow(List.of(5, 8, -2, 0, 200), DataType.INTEGER, "field")), + "Min[field=Attribute[channel=0]]", + DataType.INTEGER, + equalTo(-2) + ) + ), + new TestCaseSupplier( + List.of(DataType.LONG), + () -> new TestCaseSupplier.TestCase( + List.of(TestCaseSupplier.TypedData.multiRow(List.of(5L, 8L, -2L, 0L, 200L), DataType.LONG, "field")), + "Min[field=Attribute[channel=0]]", + DataType.LONG, + equalTo(-2L) + ) + ), + new TestCaseSupplier( + List.of(DataType.DOUBLE), + () -> new TestCaseSupplier.TestCase( + List.of(TestCaseSupplier.TypedData.multiRow(List.of(5., 8., -2., 0., 200.), DataType.DOUBLE, "field")), + "Min[field=Attribute[channel=0]]", + DataType.DOUBLE, + equalTo(-2.) + ) + ), + new TestCaseSupplier( + List.of(DataType.DATETIME), + () -> new TestCaseSupplier.TestCase( + List.of(TestCaseSupplier.TypedData.multiRow(List.of(5L, 8L, 2L, 0L, 200L), DataType.DATETIME, "field")), + "Min[field=Attribute[channel=0]]", + DataType.DATETIME, + equalTo(0L) + ) + ), + new TestCaseSupplier( + List.of(DataType.BOOLEAN), + () -> new TestCaseSupplier.TestCase( + List.of(TestCaseSupplier.TypedData.multiRow(List.of(true, false, false, true), DataType.BOOLEAN, "field")), + "Min[field=Attribute[channel=0]]", + DataType.BOOLEAN, + equalTo(false) + ) + ), + + // Folding + new TestCaseSupplier( + List.of(DataType.INTEGER), + () -> new TestCaseSupplier.TestCase( + List.of(TestCaseSupplier.TypedData.multiRow(List.of(200), DataType.INTEGER, "field")), + "Min[field=Attribute[channel=0]]", + DataType.INTEGER, + equalTo(200) + ) + ), + new TestCaseSupplier( + List.of(DataType.LONG), + () -> new TestCaseSupplier.TestCase( + List.of(TestCaseSupplier.TypedData.multiRow(List.of(200L), DataType.LONG, "field")), + "Min[field=Attribute[channel=0]]", + DataType.LONG, + equalTo(200L) + ) + ), + new TestCaseSupplier( + List.of(DataType.DOUBLE), + () -> new TestCaseSupplier.TestCase( + List.of(TestCaseSupplier.TypedData.multiRow(List.of(200.), DataType.DOUBLE, "field")), + "Min[field=Attribute[channel=0]]", + DataType.DOUBLE, + equalTo(200.) + ) + ), + new TestCaseSupplier( + List.of(DataType.DATETIME), + () -> new TestCaseSupplier.TestCase( + List.of(TestCaseSupplier.TypedData.multiRow(List.of(200L), DataType.DATETIME, "field")), + "Min[field=Attribute[channel=0]]", + DataType.DATETIME, + equalTo(200L) + ) + ), + new TestCaseSupplier( + List.of(DataType.BOOLEAN), + () -> new TestCaseSupplier.TestCase( + List.of(TestCaseSupplier.TypedData.multiRow(List.of(true), DataType.BOOLEAN, "field")), + "Min[field=Attribute[channel=0]]", + DataType.BOOLEAN, + equalTo(true) + ) + ) + ) + ); + + return parameterSuppliersFromTypedDataWithDefaultChecks(suppliers); + } + + @Override + protected Expression build(Source source, List args) { + return new Min(source, args.get(0)); + } + + @SuppressWarnings("unchecked") + private static TestCaseSupplier makeSupplier(TestCaseSupplier.TypedDataSupplier fieldSupplier) { + return new TestCaseSupplier(fieldSupplier.name(), List.of(fieldSupplier.type()), () -> { + var fieldTypedData = fieldSupplier.get(); + var expected = fieldTypedData.multiRowData() + .stream() + .map(v -> (Comparable>) v) + .min(Comparator.naturalOrder()) + .orElse(null); + + return new TestCaseSupplier.TestCase( + List.of(fieldTypedData), + "Min[field=Attribute[channel=0]]", + fieldSupplier.type(), + equalTo(expected) + ); + }); + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/TopTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/TopTests.java index 7b77decb560a9..c0c23ce29301e 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/TopTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/aggregate/TopTests.java @@ -15,10 +15,15 @@ import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; import org.elasticsearch.xpack.esql.expression.function.AbstractAggregationTestCase; +import org.elasticsearch.xpack.esql.expression.function.MultiRowTestCaseSupplier; import org.elasticsearch.xpack.esql.expression.function.TestCaseSupplier; +import java.util.ArrayList; +import java.util.Comparator; import java.util.List; import java.util.function.Supplier; +import java.util.stream.Collectors; +import java.util.stream.Stream; import static org.hamcrest.Matchers.equalTo; @@ -29,212 +34,176 @@ public TopTests(@Name("TestCase") Supplier testCaseSu @ParametersFactory public static Iterable parameters() { - var suppliers = List.of( - // All types - new TestCaseSupplier(List.of(DataType.INTEGER, DataType.INTEGER, DataType.KEYWORD), () -> { - var limit = randomIntBetween(2, 4); - return new TestCaseSupplier.TestCase( - List.of( - TestCaseSupplier.TypedData.multiRow(List.of(5, 8, -2, 0, 200), DataType.INTEGER, "field"), - new TestCaseSupplier.TypedData(limit, DataType.INTEGER, "limit").forceLiteral(), - new TestCaseSupplier.TypedData(new BytesRef("desc"), DataType.KEYWORD, "order").forceLiteral() - ), - "Top[field=Attribute[channel=0], limit=Attribute[channel=1], order=Attribute[channel=2]]", - DataType.INTEGER, - equalTo(List.of(200, 8, 5, 0).subList(0, limit)) - ); - }), - new TestCaseSupplier(List.of(DataType.LONG, DataType.INTEGER, DataType.KEYWORD), () -> { - var limit = randomIntBetween(2, 4); - return new TestCaseSupplier.TestCase( - List.of( - TestCaseSupplier.TypedData.multiRow(List.of(5L, 8L, -2L, 0L, 200L), DataType.LONG, "field"), - new TestCaseSupplier.TypedData(limit, DataType.INTEGER, "limit").forceLiteral(), - new TestCaseSupplier.TypedData(new BytesRef("desc"), DataType.KEYWORD, "order").forceLiteral() - ), - "Top[field=Attribute[channel=0], limit=Attribute[channel=1], order=Attribute[channel=2]]", - DataType.LONG, - equalTo(List.of(200L, 8L, 5L, 0L).subList(0, limit)) - ); - }), - new TestCaseSupplier(List.of(DataType.DOUBLE, DataType.INTEGER, DataType.KEYWORD), () -> { - var limit = randomIntBetween(2, 4); - return new TestCaseSupplier.TestCase( - List.of( - TestCaseSupplier.TypedData.multiRow(List.of(5., 8., -2., 0., 200.), DataType.DOUBLE, "field"), - new TestCaseSupplier.TypedData(limit, DataType.INTEGER, "limit").forceLiteral(), - new TestCaseSupplier.TypedData(new BytesRef("desc"), DataType.KEYWORD, "order").forceLiteral() - ), - "Top[field=Attribute[channel=0], limit=Attribute[channel=1], order=Attribute[channel=2]]", - DataType.DOUBLE, - equalTo(List.of(200., 8., 5., 0.).subList(0, limit)) - ); - }), - new TestCaseSupplier(List.of(DataType.DATETIME, DataType.INTEGER, DataType.KEYWORD), () -> { - var limit = randomIntBetween(2, 4); - return new TestCaseSupplier.TestCase( - List.of( - TestCaseSupplier.TypedData.multiRow(List.of(5L, 8L, -2L, 0L, 200L), DataType.DATETIME, "field"), - new TestCaseSupplier.TypedData(limit, DataType.INTEGER, "limit").forceLiteral(), - new TestCaseSupplier.TypedData(new BytesRef("desc"), DataType.KEYWORD, "order").forceLiteral() - ), - "Top[field=Attribute[channel=0], limit=Attribute[channel=1], order=Attribute[channel=2]]", - DataType.DATETIME, - equalTo(List.of(200L, 8L, 5L, 0L).subList(0, limit)) - ); - }), + var suppliers = new ArrayList(); - // Surrogates - new TestCaseSupplier( - List.of(DataType.INTEGER, DataType.INTEGER, DataType.KEYWORD), - () -> new TestCaseSupplier.TestCase( - List.of( - TestCaseSupplier.TypedData.multiRow(List.of(5, 8, -2, 0, 200), DataType.INTEGER, "field"), - new TestCaseSupplier.TypedData(1, DataType.INTEGER, "limit").forceLiteral(), - new TestCaseSupplier.TypedData(new BytesRef("desc"), DataType.KEYWORD, "order").forceLiteral() - ), - "Top[field=Attribute[channel=0], limit=Attribute[channel=1], order=Attribute[channel=2]]", - DataType.INTEGER, - equalTo(200) + for (var limitCaseSupplier : TestCaseSupplier.intCases(1, 1000, false)) { + for (String order : List.of("asc", "desc")) { + Stream.of( + MultiRowTestCaseSupplier.intCases(1, 1000, Integer.MIN_VALUE, Integer.MAX_VALUE, true), + MultiRowTestCaseSupplier.longCases(1, 1000, Long.MIN_VALUE, Long.MAX_VALUE, true), + MultiRowTestCaseSupplier.doubleCases(1, 1000, -Double.MAX_VALUE, Double.MAX_VALUE, true), + MultiRowTestCaseSupplier.dateCases(1, 1000) ) - ), - new TestCaseSupplier( - List.of(DataType.LONG, DataType.INTEGER, DataType.KEYWORD), - () -> new TestCaseSupplier.TestCase( - List.of( - TestCaseSupplier.TypedData.multiRow(List.of(5L, 8L, -2L, 0L, 200L), DataType.LONG, "field"), - new TestCaseSupplier.TypedData(1, DataType.INTEGER, "limit").forceLiteral(), - new TestCaseSupplier.TypedData(new BytesRef("desc"), DataType.KEYWORD, "order").forceLiteral() - ), - "Top[field=Attribute[channel=0], limit=Attribute[channel=1], order=Attribute[channel=2]]", - DataType.LONG, - equalTo(200L) - ) - ), - new TestCaseSupplier( - List.of(DataType.DOUBLE, DataType.INTEGER, DataType.KEYWORD), - () -> new TestCaseSupplier.TestCase( - List.of( - TestCaseSupplier.TypedData.multiRow(List.of(5., 8., -2., 0., 200.), DataType.DOUBLE, "field"), - new TestCaseSupplier.TypedData(1, DataType.INTEGER, "limit").forceLiteral(), - new TestCaseSupplier.TypedData(new BytesRef("desc"), DataType.KEYWORD, "order").forceLiteral() - ), - "Top[field=Attribute[channel=0], limit=Attribute[channel=1], order=Attribute[channel=2]]", - DataType.DOUBLE, - equalTo(200.) - ) - ), - new TestCaseSupplier( - List.of(DataType.DATETIME, DataType.INTEGER, DataType.KEYWORD), - () -> new TestCaseSupplier.TestCase( - List.of( - TestCaseSupplier.TypedData.multiRow(List.of(5L, 8L, 2L, 0L, 200L), DataType.DATETIME, "field"), - new TestCaseSupplier.TypedData(1, DataType.INTEGER, "limit").forceLiteral(), - new TestCaseSupplier.TypedData(new BytesRef("desc"), DataType.KEYWORD, "order").forceLiteral() - ), - "Top[field=Attribute[channel=0], limit=Attribute[channel=1], order=Attribute[channel=2]]", - DataType.DATETIME, - equalTo(200L) - ) - ), + .flatMap(List::stream) + .map(fieldCaseSupplier -> TopTests.makeSupplier(fieldCaseSupplier, limitCaseSupplier, order)) + .collect(Collectors.toCollection(() -> suppliers)); + } + } - // Folding - new TestCaseSupplier( - List.of(DataType.INTEGER, DataType.INTEGER, DataType.KEYWORD), - () -> new TestCaseSupplier.TestCase( - List.of( - TestCaseSupplier.TypedData.multiRow(List.of(200), DataType.INTEGER, "field"), - new TestCaseSupplier.TypedData(1, DataType.INTEGER, "limit").forceLiteral(), - new TestCaseSupplier.TypedData(new BytesRef("desc"), DataType.KEYWORD, "order").forceLiteral() - ), - "Top[field=Attribute[channel=0], limit=Attribute[channel=1], order=Attribute[channel=2]]", - DataType.INTEGER, - equalTo(200) - ) - ), - new TestCaseSupplier( - List.of(DataType.LONG, DataType.INTEGER, DataType.KEYWORD), - () -> new TestCaseSupplier.TestCase( - List.of( - TestCaseSupplier.TypedData.multiRow(List.of(200L), DataType.LONG, "field"), - new TestCaseSupplier.TypedData(1, DataType.INTEGER, "limit").forceLiteral(), - new TestCaseSupplier.TypedData(new BytesRef("desc"), DataType.KEYWORD, "order").forceLiteral() - ), - "Top[field=Attribute[channel=0], limit=Attribute[channel=1], order=Attribute[channel=2]]", - DataType.LONG, - equalTo(200L) - ) - ), - new TestCaseSupplier( - List.of(DataType.DOUBLE, DataType.INTEGER, DataType.KEYWORD), - () -> new TestCaseSupplier.TestCase( - List.of( - TestCaseSupplier.TypedData.multiRow(List.of(200.), DataType.DOUBLE, "field"), - new TestCaseSupplier.TypedData(1, DataType.INTEGER, "limit").forceLiteral(), - new TestCaseSupplier.TypedData(new BytesRef("desc"), DataType.KEYWORD, "order").forceLiteral() - ), - "Top[field=Attribute[channel=0], limit=Attribute[channel=1], order=Attribute[channel=2]]", - DataType.DOUBLE, - equalTo(200.) - ) - ), - new TestCaseSupplier( - List.of(DataType.DATETIME, DataType.INTEGER, DataType.KEYWORD), - () -> new TestCaseSupplier.TestCase( - List.of( - TestCaseSupplier.TypedData.multiRow(List.of(200L), DataType.DATETIME, "field"), - new TestCaseSupplier.TypedData(1, DataType.INTEGER, "limit").forceLiteral(), - new TestCaseSupplier.TypedData(new BytesRef("desc"), DataType.KEYWORD, "order").forceLiteral() - ), - "Top[field=Attribute[channel=0], limit=Attribute[channel=1], order=Attribute[channel=2]]", - DataType.DATETIME, - equalTo(200L) - ) - ), + suppliers.addAll( + List.of( + // Surrogates + new TestCaseSupplier( + List.of(DataType.INTEGER, DataType.INTEGER, DataType.KEYWORD), + () -> new TestCaseSupplier.TestCase( + List.of( + TestCaseSupplier.TypedData.multiRow(List.of(5, 8, -2, 0, 200), DataType.INTEGER, "field"), + new TestCaseSupplier.TypedData(1, DataType.INTEGER, "limit").forceLiteral(), + new TestCaseSupplier.TypedData(new BytesRef("desc"), DataType.KEYWORD, "order").forceLiteral() + ), + "Top[field=Attribute[channel=0], limit=Attribute[channel=1], order=Attribute[channel=2]]", + DataType.INTEGER, + equalTo(200) + ) + ), + new TestCaseSupplier( + List.of(DataType.LONG, DataType.INTEGER, DataType.KEYWORD), + () -> new TestCaseSupplier.TestCase( + List.of( + TestCaseSupplier.TypedData.multiRow(List.of(5L, 8L, -2L, 0L, 200L), DataType.LONG, "field"), + new TestCaseSupplier.TypedData(1, DataType.INTEGER, "limit").forceLiteral(), + new TestCaseSupplier.TypedData(new BytesRef("desc"), DataType.KEYWORD, "order").forceLiteral() + ), + "Top[field=Attribute[channel=0], limit=Attribute[channel=1], order=Attribute[channel=2]]", + DataType.LONG, + equalTo(200L) + ) + ), + new TestCaseSupplier( + List.of(DataType.DOUBLE, DataType.INTEGER, DataType.KEYWORD), + () -> new TestCaseSupplier.TestCase( + List.of( + TestCaseSupplier.TypedData.multiRow(List.of(5., 8., -2., 0., 200.), DataType.DOUBLE, "field"), + new TestCaseSupplier.TypedData(1, DataType.INTEGER, "limit").forceLiteral(), + new TestCaseSupplier.TypedData(new BytesRef("desc"), DataType.KEYWORD, "order").forceLiteral() + ), + "Top[field=Attribute[channel=0], limit=Attribute[channel=1], order=Attribute[channel=2]]", + DataType.DOUBLE, + equalTo(200.) + ) + ), + new TestCaseSupplier( + List.of(DataType.DATETIME, DataType.INTEGER, DataType.KEYWORD), + () -> new TestCaseSupplier.TestCase( + List.of( + TestCaseSupplier.TypedData.multiRow(List.of(5L, 8L, 2L, 0L, 200L), DataType.DATETIME, "field"), + new TestCaseSupplier.TypedData(1, DataType.INTEGER, "limit").forceLiteral(), + new TestCaseSupplier.TypedData(new BytesRef("desc"), DataType.KEYWORD, "order").forceLiteral() + ), + "Top[field=Attribute[channel=0], limit=Attribute[channel=1], order=Attribute[channel=2]]", + DataType.DATETIME, + equalTo(200L) + ) + ), - // Resolution errors - new TestCaseSupplier( - List.of(DataType.LONG, DataType.INTEGER, DataType.KEYWORD), - () -> TestCaseSupplier.TestCase.typeError( - List.of( - TestCaseSupplier.TypedData.multiRow(List.of(5L, 8L, 2L, 0L, 200L), DataType.LONG, "field"), - new TestCaseSupplier.TypedData(0, DataType.INTEGER, "limit").forceLiteral(), - new TestCaseSupplier.TypedData(new BytesRef("desc"), DataType.KEYWORD, "order").forceLiteral() - ), - "Limit must be greater than 0 in [], found [0]" - ) - ), - new TestCaseSupplier( - List.of(DataType.LONG, DataType.INTEGER, DataType.KEYWORD), - () -> TestCaseSupplier.TestCase.typeError( - List.of( - TestCaseSupplier.TypedData.multiRow(List.of(5L, 8L, 2L, 0L, 200L), DataType.LONG, "field"), - new TestCaseSupplier.TypedData(2, DataType.INTEGER, "limit").forceLiteral(), - new TestCaseSupplier.TypedData(new BytesRef("wrong-order"), DataType.KEYWORD, "order").forceLiteral() - ), - "Invalid order value in [], expected [ASC, DESC] but got [wrong-order]" - ) - ), - new TestCaseSupplier( - List.of(DataType.LONG, DataType.INTEGER, DataType.KEYWORD), - () -> TestCaseSupplier.TestCase.typeError( - List.of( - TestCaseSupplier.TypedData.multiRow(List.of(5L, 8L, 2L, 0L, 200L), DataType.LONG, "field"), - new TestCaseSupplier.TypedData(null, DataType.INTEGER, "limit").forceLiteral(), - new TestCaseSupplier.TypedData(new BytesRef("desc"), DataType.KEYWORD, "order").forceLiteral() - ), - "second argument of [] cannot be null, received [limit]" - ) - ), - new TestCaseSupplier( - List.of(DataType.LONG, DataType.INTEGER, DataType.KEYWORD), - () -> TestCaseSupplier.TestCase.typeError( - List.of( - TestCaseSupplier.TypedData.multiRow(List.of(5L, 8L, 2L, 0L, 200L), DataType.LONG, "field"), - new TestCaseSupplier.TypedData(1, DataType.INTEGER, "limit").forceLiteral(), - new TestCaseSupplier.TypedData(null, DataType.KEYWORD, "order").forceLiteral() - ), - "third argument of [] cannot be null, received [order]" + // Folding + new TestCaseSupplier( + List.of(DataType.INTEGER, DataType.INTEGER, DataType.KEYWORD), + () -> new TestCaseSupplier.TestCase( + List.of( + TestCaseSupplier.TypedData.multiRow(List.of(200), DataType.INTEGER, "field"), + new TestCaseSupplier.TypedData(1, DataType.INTEGER, "limit").forceLiteral(), + new TestCaseSupplier.TypedData(new BytesRef("desc"), DataType.KEYWORD, "order").forceLiteral() + ), + "Top[field=Attribute[channel=0], limit=Attribute[channel=1], order=Attribute[channel=2]]", + DataType.INTEGER, + equalTo(200) + ) + ), + new TestCaseSupplier( + List.of(DataType.LONG, DataType.INTEGER, DataType.KEYWORD), + () -> new TestCaseSupplier.TestCase( + List.of( + TestCaseSupplier.TypedData.multiRow(List.of(200L), DataType.LONG, "field"), + new TestCaseSupplier.TypedData(1, DataType.INTEGER, "limit").forceLiteral(), + new TestCaseSupplier.TypedData(new BytesRef("desc"), DataType.KEYWORD, "order").forceLiteral() + ), + "Top[field=Attribute[channel=0], limit=Attribute[channel=1], order=Attribute[channel=2]]", + DataType.LONG, + equalTo(200L) + ) + ), + new TestCaseSupplier( + List.of(DataType.DOUBLE, DataType.INTEGER, DataType.KEYWORD), + () -> new TestCaseSupplier.TestCase( + List.of( + TestCaseSupplier.TypedData.multiRow(List.of(200.), DataType.DOUBLE, "field"), + new TestCaseSupplier.TypedData(1, DataType.INTEGER, "limit").forceLiteral(), + new TestCaseSupplier.TypedData(new BytesRef("desc"), DataType.KEYWORD, "order").forceLiteral() + ), + "Top[field=Attribute[channel=0], limit=Attribute[channel=1], order=Attribute[channel=2]]", + DataType.DOUBLE, + equalTo(200.) + ) + ), + new TestCaseSupplier( + List.of(DataType.DATETIME, DataType.INTEGER, DataType.KEYWORD), + () -> new TestCaseSupplier.TestCase( + List.of( + TestCaseSupplier.TypedData.multiRow(List.of(200L), DataType.DATETIME, "field"), + new TestCaseSupplier.TypedData(1, DataType.INTEGER, "limit").forceLiteral(), + new TestCaseSupplier.TypedData(new BytesRef("desc"), DataType.KEYWORD, "order").forceLiteral() + ), + "Top[field=Attribute[channel=0], limit=Attribute[channel=1], order=Attribute[channel=2]]", + DataType.DATETIME, + equalTo(200L) + ) + ), + + // Resolution errors + new TestCaseSupplier( + List.of(DataType.LONG, DataType.INTEGER, DataType.KEYWORD), + () -> TestCaseSupplier.TestCase.typeError( + List.of( + TestCaseSupplier.TypedData.multiRow(List.of(5L, 8L, 2L, 0L, 200L), DataType.LONG, "field"), + new TestCaseSupplier.TypedData(0, DataType.INTEGER, "limit").forceLiteral(), + new TestCaseSupplier.TypedData(new BytesRef("desc"), DataType.KEYWORD, "order").forceLiteral() + ), + "Limit must be greater than 0 in [], found [0]" + ) + ), + new TestCaseSupplier( + List.of(DataType.LONG, DataType.INTEGER, DataType.KEYWORD), + () -> TestCaseSupplier.TestCase.typeError( + List.of( + TestCaseSupplier.TypedData.multiRow(List.of(5L, 8L, 2L, 0L, 200L), DataType.LONG, "field"), + new TestCaseSupplier.TypedData(2, DataType.INTEGER, "limit").forceLiteral(), + new TestCaseSupplier.TypedData(new BytesRef("wrong-order"), DataType.KEYWORD, "order").forceLiteral() + ), + "Invalid order value in [], expected [ASC, DESC] but got [wrong-order]" + ) + ), + new TestCaseSupplier( + List.of(DataType.LONG, DataType.INTEGER, DataType.KEYWORD), + () -> TestCaseSupplier.TestCase.typeError( + List.of( + TestCaseSupplier.TypedData.multiRow(List.of(5L, 8L, 2L, 0L, 200L), DataType.LONG, "field"), + new TestCaseSupplier.TypedData(null, DataType.INTEGER, "limit").forceLiteral(), + new TestCaseSupplier.TypedData(new BytesRef("desc"), DataType.KEYWORD, "order").forceLiteral() + ), + "second argument of [] cannot be null, received [limit]" + ) + ), + new TestCaseSupplier( + List.of(DataType.LONG, DataType.INTEGER, DataType.KEYWORD), + () -> TestCaseSupplier.TestCase.typeError( + List.of( + TestCaseSupplier.TypedData.multiRow(List.of(5L, 8L, 2L, 0L, 200L), DataType.LONG, "field"), + new TestCaseSupplier.TypedData(1, DataType.INTEGER, "limit").forceLiteral(), + new TestCaseSupplier.TypedData(null, DataType.KEYWORD, "order").forceLiteral() + ), + "third argument of [] cannot be null, received [order]" + ) ) ) ); @@ -246,4 +215,34 @@ public static Iterable parameters() { protected Expression build(Source source, List args) { return new Top(source, args.get(0), args.get(1), args.get(2)); } + + @SuppressWarnings("unchecked") + private static TestCaseSupplier makeSupplier( + TestCaseSupplier.TypedDataSupplier fieldSupplier, + TestCaseSupplier.TypedDataSupplier limitCaseSupplier, + String order + ) { + return new TestCaseSupplier(List.of(fieldSupplier.type(), DataType.INTEGER, DataType.KEYWORD), () -> { + var fieldTypedData = fieldSupplier.get(); + var limitTypedData = limitCaseSupplier.get().forceLiteral(); + var limit = (int) limitTypedData.getValue(); + var expected = fieldTypedData.multiRowData() + .stream() + .map(v -> (Comparable>) v) + .sorted(order.equals("asc") ? Comparator.naturalOrder() : Comparator.reverseOrder()) + .limit(limit) + .toList(); + + return new TestCaseSupplier.TestCase( + List.of( + fieldTypedData, + limitTypedData, + new TestCaseSupplier.TypedData(new BytesRef(order), DataType.KEYWORD, order + " order").forceLiteral() + ), + "Top[field=Attribute[channel=0], limit=Attribute[channel=1], order=Attribute[channel=2]]", + fieldSupplier.type(), + equalTo(expected.size() == 1 ? expected.get(0) : expected) + ); + }); + } } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvSortTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvSortTests.java index a085c0acfa25d..15c81557961f1 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvSortTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/scalar/multivalue/MvSortTests.java @@ -12,7 +12,9 @@ import org.apache.lucene.util.BytesRef; import org.elasticsearch.compute.data.ElementType; +import org.elasticsearch.compute.operator.DriverContext; import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; import org.elasticsearch.xpack.esql.expression.function.AbstractScalarFunctionTestCase; @@ -183,6 +185,22 @@ private static void bytesRefs(List suppliers) { })); } + public void testInvalidOrder() { + String invalidOrder = randomAlphaOfLength(10); + DriverContext driverContext = driverContext(); + IllegalArgumentException e = expectThrows( + IllegalArgumentException.class, + () -> evaluator( + new MvSort( + Source.EMPTY, + field("str", DataType.DATETIME), + new Literal(Source.EMPTY, new BytesRef(invalidOrder), DataType.KEYWORD) + ) + ).get(driverContext) + ); + assertThat(e.getMessage(), equalTo("Invalid order value in [], expected one of [ASC, DESC] but got [" + invalidOrder + "]")); + } + @Override public void testSimpleWithNulls() { assumeFalse("test case is invalid", false); diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/AbstractBinaryOperatorTestCase.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/AbstractBinaryOperatorTestCase.java index a9663f9e37852..974c8703b2a09 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/AbstractBinaryOperatorTestCase.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/predicate/operator/AbstractBinaryOperatorTestCase.java @@ -9,7 +9,7 @@ import org.elasticsearch.compute.data.Block; import org.elasticsearch.xpack.esql.analysis.Verifier; -import org.elasticsearch.xpack.esql.core.common.Failure; +import org.elasticsearch.xpack.esql.common.Failure; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.expression.predicate.BinaryOperator; diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/io/stream/PlanNamedTypesTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/io/stream/PlanNamedTypesTests.java index 5a398ed3e4370..55691526ea428 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/io/stream/PlanNamedTypesTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/io/stream/PlanNamedTypesTests.java @@ -24,10 +24,6 @@ import org.elasticsearch.xpack.esql.core.expression.Nullability; import org.elasticsearch.xpack.esql.core.expression.predicate.operator.arithmetic.ArithmeticOperation; import org.elasticsearch.xpack.esql.core.index.EsIndex; -import org.elasticsearch.xpack.esql.core.plan.logical.Filter; -import org.elasticsearch.xpack.esql.core.plan.logical.Limit; -import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan; -import org.elasticsearch.xpack.esql.core.plan.logical.OrderBy; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; import org.elasticsearch.xpack.esql.core.type.EsField; @@ -50,9 +46,13 @@ import org.elasticsearch.xpack.esql.plan.logical.Enrich; import org.elasticsearch.xpack.esql.plan.logical.EsRelation; import org.elasticsearch.xpack.esql.plan.logical.Eval; +import org.elasticsearch.xpack.esql.plan.logical.Filter; import org.elasticsearch.xpack.esql.plan.logical.Grok; +import org.elasticsearch.xpack.esql.plan.logical.Limit; +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.plan.logical.Lookup; import org.elasticsearch.xpack.esql.plan.logical.MvExpand; +import org.elasticsearch.xpack.esql.plan.logical.OrderBy; import org.elasticsearch.xpack.esql.plan.logical.Project; import org.elasticsearch.xpack.esql.plan.logical.TopN; import org.elasticsearch.xpack.esql.plan.logical.join.Join; diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/io/stream/PlanStreamInputTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/io/stream/PlanStreamInputTests.java index 5788f218564c9..55763d9ec6e7b 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/io/stream/PlanStreamInputTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/io/stream/PlanStreamInputTests.java @@ -10,10 +10,10 @@ import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.NameId; -import org.elasticsearch.xpack.esql.core.plan.logical.Filter; -import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.plan.logical.Eval; +import org.elasticsearch.xpack.esql.plan.logical.Filter; +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.session.EsqlConfiguration; import java.util.ArrayList; diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalLogicalPlanOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalLogicalPlanOptimizerTests.java index af6c065abbeee..2049fd5592d82 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalLogicalPlanOptimizerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LocalLogicalPlanOptimizerTests.java @@ -24,10 +24,6 @@ import org.elasticsearch.xpack.esql.core.expression.predicate.nulls.IsNotNull; import org.elasticsearch.xpack.esql.core.index.EsIndex; import org.elasticsearch.xpack.esql.core.index.IndexResolution; -import org.elasticsearch.xpack.esql.core.plan.logical.Filter; -import org.elasticsearch.xpack.esql.core.plan.logical.Limit; -import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan; -import org.elasticsearch.xpack.esql.core.plan.logical.UnaryPlan; import org.elasticsearch.xpack.esql.core.tree.NodeInfo; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; @@ -39,9 +35,13 @@ import org.elasticsearch.xpack.esql.parser.EsqlParser; import org.elasticsearch.xpack.esql.plan.logical.EsRelation; import org.elasticsearch.xpack.esql.plan.logical.Eval; +import org.elasticsearch.xpack.esql.plan.logical.Filter; +import org.elasticsearch.xpack.esql.plan.logical.Limit; +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.plan.logical.MvExpand; import org.elasticsearch.xpack.esql.plan.logical.Project; import org.elasticsearch.xpack.esql.plan.logical.Row; +import org.elasticsearch.xpack.esql.plan.logical.UnaryPlan; import org.elasticsearch.xpack.esql.plan.logical.local.EsqlProject; import org.elasticsearch.xpack.esql.plan.logical.local.LocalRelation; import org.elasticsearch.xpack.esql.plan.logical.local.LocalSupplier; diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java index 6a9e7a4000734..dea3a974fbd5a 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java @@ -42,11 +42,6 @@ import org.elasticsearch.xpack.esql.core.expression.predicate.regex.WildcardPattern; import org.elasticsearch.xpack.esql.core.index.EsIndex; import org.elasticsearch.xpack.esql.core.index.IndexResolution; -import org.elasticsearch.xpack.esql.core.plan.logical.Filter; -import org.elasticsearch.xpack.esql.core.plan.logical.Limit; -import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan; -import org.elasticsearch.xpack.esql.core.plan.logical.OrderBy; -import org.elasticsearch.xpack.esql.core.plan.logical.UnaryPlan; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; import org.elasticsearch.xpack.esql.core.type.EsField; @@ -122,11 +117,16 @@ import org.elasticsearch.xpack.esql.plan.logical.Enrich; import org.elasticsearch.xpack.esql.plan.logical.EsRelation; import org.elasticsearch.xpack.esql.plan.logical.Eval; +import org.elasticsearch.xpack.esql.plan.logical.Filter; import org.elasticsearch.xpack.esql.plan.logical.Grok; +import org.elasticsearch.xpack.esql.plan.logical.Limit; +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.plan.logical.MvExpand; +import org.elasticsearch.xpack.esql.plan.logical.OrderBy; import org.elasticsearch.xpack.esql.plan.logical.Project; import org.elasticsearch.xpack.esql.plan.logical.Row; import org.elasticsearch.xpack.esql.plan.logical.TopN; +import org.elasticsearch.xpack.esql.plan.logical.UnaryPlan; import org.elasticsearch.xpack.esql.plan.logical.join.Join; import org.elasticsearch.xpack.esql.plan.logical.join.JoinType; import org.elasticsearch.xpack.esql.plan.logical.local.EsqlProject; @@ -5477,6 +5477,99 @@ METRICS k8s avg(round(1.05 * rate(network.total_bytes_in))) BY bucket(@timestamp assertThat(Expressions.attribute(values.field()).name(), equalTo("cluster")); } + public void testMetricsWithoutRate() { + assumeTrue("requires snapshot builds", Build.current().isSnapshot()); + List queries = List.of(""" + METRICS k8s count(to_long(network.total_bytes_in)) BY bucket(@timestamp, 1 minute) + | LIMIT 10 + """, """ + METRICS k8s | STATS count(to_long(network.total_bytes_in)) BY bucket(@timestamp, 1 minute) + | LIMIT 10 + """, """ + FROM k8s | STATS count(to_long(network.total_bytes_in)) BY bucket(@timestamp, 1 minute) + | LIMIT 10 + """); + List plans = new ArrayList<>(); + for (String query : queries) { + var plan = logicalOptimizer.optimize(metricsAnalyzer.analyze(parser.createStatement(query))); + plans.add(plan); + } + for (LogicalPlan plan : plans) { + Limit limit = as(plan, Limit.class); + Aggregate aggregate = as(limit.child(), Aggregate.class); + assertThat(aggregate.aggregateType(), equalTo(Aggregate.AggregateType.STANDARD)); + assertThat(aggregate.aggregates(), hasSize(2)); + assertThat(aggregate.groupings(), hasSize(1)); + Eval eval = as(aggregate.child(), Eval.class); + assertThat(eval.fields(), hasSize(2)); + assertThat(Alias.unwrap(eval.fields().get(0)), instanceOf(Bucket.class)); + assertThat(Alias.unwrap(eval.fields().get(1)), instanceOf(ToLong.class)); + EsRelation relation = as(eval.child(), EsRelation.class); + assertThat(relation.indexMode(), equalTo(IndexMode.STANDARD)); + } + for (int i = 1; i < plans.size(); i++) { + assertThat(plans.get(i), equalTo(plans.get(0))); + } + } + + public void testRateInStats() { + assumeTrue("requires snapshot builds", Build.current().isSnapshot()); + var query = """ + METRICS k8s | STATS max(rate(network.total_bytes_in)) BY bucket(@timestamp, 1 minute) + | LIMIT 10 + """; + VerificationException error = expectThrows( + VerificationException.class, + () -> logicalOptimizer.optimize(metricsAnalyzer.analyze(parser.createStatement(query))) + ); + assertThat(error.getMessage(), equalTo(""" + Found 1 problem + line 1:25: the rate aggregate[rate(network.total_bytes_in)] can only be used within the metrics command""")); + } + + public void testMvSortInvalidOrder() { + VerificationException e = expectThrows(VerificationException.class, () -> plan(""" + from test + | EVAL sd = mv_sort(salary, "ABC") + """)); + assertTrue(e.getMessage().startsWith("Found ")); + final String header = "Found 1 problem\nline "; + assertEquals( + "2:29: Invalid order value in [mv_sort(salary, \"ABC\")], expected one of [ASC, DESC] but got [ABC]", + e.getMessage().substring(header.length()) + ); + + e = expectThrows(VerificationException.class, () -> plan(""" + from test + | EVAL order = "ABC", sd = mv_sort(salary, order) + """)); + assertTrue(e.getMessage().startsWith("Found ")); + assertEquals( + "2:16: Invalid order value in [mv_sort(salary, order)], expected one of [ASC, DESC] but got [ABC]", + e.getMessage().substring(header.length()) + ); + + e = expectThrows(VerificationException.class, () -> plan(""" + from test + | EVAL order = concat("d", "sc"), sd = mv_sort(salary, order) + """)); + assertTrue(e.getMessage().startsWith("Found ")); + assertEquals( + "2:16: Invalid order value in [mv_sort(salary, order)], expected one of [ASC, DESC] but got [dsc]", + e.getMessage().substring(header.length()) + ); + + IllegalArgumentException iae = expectThrows(IllegalArgumentException.class, () -> plan(""" + row v = [1, 2, 3] | EVAL sd = mv_sort(v, "dsc") + """)); + assertEquals("Invalid order value in [mv_sort(v, \"dsc\")], expected one of [ASC, DESC] but got [dsc]", iae.getMessage()); + + iae = expectThrows(IllegalArgumentException.class, () -> plan(""" + row v = [1, 2, 3], o = concat("d", "sc") | EVAL sd = mv_sort(v, o) + """)); + assertEquals("Invalid order value in [mv_sort(v, o)], expected one of [ASC, DESC] but got [dsc]", iae.getMessage()); + } + private Literal nullOf(DataType dataType) { return new Literal(Source.EMPTY, null, dataType); } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/OptimizerRulesTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/OptimizerRulesTests.java deleted file mode 100644 index b550f6e6090da..0000000000000 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/OptimizerRulesTests.java +++ /dev/null @@ -1,860 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.esql.optimizer; - -import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xpack.esql.core.expression.Alias; -import org.elasticsearch.xpack.esql.core.expression.Expression; -import org.elasticsearch.xpack.esql.core.expression.Expressions; -import org.elasticsearch.xpack.esql.core.expression.FieldAttribute; -import org.elasticsearch.xpack.esql.core.expression.Literal; -import org.elasticsearch.xpack.esql.core.expression.Nullability; -import org.elasticsearch.xpack.esql.core.expression.predicate.BinaryOperator; -import org.elasticsearch.xpack.esql.core.expression.predicate.Predicates; -import org.elasticsearch.xpack.esql.core.expression.predicate.Range; -import org.elasticsearch.xpack.esql.core.expression.predicate.logical.And; -import org.elasticsearch.xpack.esql.core.expression.predicate.logical.Not; -import org.elasticsearch.xpack.esql.core.expression.predicate.logical.Or; -import org.elasticsearch.xpack.esql.core.expression.predicate.nulls.IsNotNull; -import org.elasticsearch.xpack.esql.core.expression.predicate.nulls.IsNull; -import org.elasticsearch.xpack.esql.core.expression.predicate.operator.comparison.BinaryComparison; -import org.elasticsearch.xpack.esql.core.expression.predicate.regex.Like; -import org.elasticsearch.xpack.esql.core.expression.predicate.regex.LikePattern; -import org.elasticsearch.xpack.esql.core.expression.predicate.regex.RLike; -import org.elasticsearch.xpack.esql.core.expression.predicate.regex.RLikePattern; -import org.elasticsearch.xpack.esql.core.expression.predicate.regex.WildcardLike; -import org.elasticsearch.xpack.esql.core.expression.predicate.regex.WildcardPattern; -import org.elasticsearch.xpack.esql.core.optimizer.OptimizerRules.FoldNull; -import org.elasticsearch.xpack.esql.core.optimizer.OptimizerRules.PropagateNullable; -import org.elasticsearch.xpack.esql.core.plan.logical.Filter; -import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan; -import org.elasticsearch.xpack.esql.core.tree.Source; -import org.elasticsearch.xpack.esql.core.type.DataType; -import org.elasticsearch.xpack.esql.core.util.StringUtils; -import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Add; -import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Div; -import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Mod; -import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Mul; -import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Sub; -import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.Equals; -import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.GreaterThan; -import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.GreaterThanOrEqual; -import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.In; -import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.LessThan; -import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.LessThanOrEqual; -import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.NotEquals; -import org.elasticsearch.xpack.esql.optimizer.rules.BooleanFunctionEqualsElimination; -import org.elasticsearch.xpack.esql.optimizer.rules.CombineDisjunctionsToIn; -import org.elasticsearch.xpack.esql.optimizer.rules.ConstantFolding; -import org.elasticsearch.xpack.esql.optimizer.rules.LiteralsOnTheRight; -import org.elasticsearch.xpack.esql.optimizer.rules.PropagateEquals; -import org.elasticsearch.xpack.esql.optimizer.rules.ReplaceRegexMatch; - -import java.util.List; - -import static java.util.Arrays.asList; -import static org.elasticsearch.xpack.esql.EsqlTestUtils.FIVE; -import static org.elasticsearch.xpack.esql.EsqlTestUtils.FOUR; -import static org.elasticsearch.xpack.esql.EsqlTestUtils.ONE; -import static org.elasticsearch.xpack.esql.EsqlTestUtils.THREE; -import static org.elasticsearch.xpack.esql.EsqlTestUtils.TWO; -import static org.elasticsearch.xpack.esql.EsqlTestUtils.equalsOf; -import static org.elasticsearch.xpack.esql.EsqlTestUtils.getFieldAttribute; -import static org.elasticsearch.xpack.esql.EsqlTestUtils.greaterThanOf; -import static org.elasticsearch.xpack.esql.EsqlTestUtils.greaterThanOrEqualOf; -import static org.elasticsearch.xpack.esql.EsqlTestUtils.lessThanOf; -import static org.elasticsearch.xpack.esql.EsqlTestUtils.lessThanOrEqualOf; -import static org.elasticsearch.xpack.esql.EsqlTestUtils.notEqualsOf; -import static org.elasticsearch.xpack.esql.EsqlTestUtils.of; -import static org.elasticsearch.xpack.esql.EsqlTestUtils.rangeOf; -import static org.elasticsearch.xpack.esql.EsqlTestUtils.relation; -import static org.elasticsearch.xpack.esql.core.expression.Literal.FALSE; -import static org.elasticsearch.xpack.esql.core.expression.Literal.NULL; -import static org.elasticsearch.xpack.esql.core.expression.Literal.TRUE; -import static org.elasticsearch.xpack.esql.core.tree.Source.EMPTY; -import static org.elasticsearch.xpack.esql.core.type.DataType.BOOLEAN; -import static org.elasticsearch.xpack.esql.core.type.DataType.INTEGER; -import static org.hamcrest.Matchers.contains; - -public class OptimizerRulesTests extends ESTestCase { - private static final Expression DUMMY_EXPRESSION = - new org.elasticsearch.xpack.esql.core.optimizer.OptimizerRulesTests.DummyBooleanExpression(EMPTY, 0); - - // - // Constant folding - // - - public void testConstantFolding() { - Expression exp = new Add(EMPTY, TWO, THREE); - - assertTrue(exp.foldable()); - Expression result = new ConstantFolding().rule(exp); - assertTrue(result instanceof Literal); - assertEquals(5, ((Literal) result).value()); - - // check now with an alias - result = new ConstantFolding().rule(new Alias(EMPTY, "a", exp)); - assertEquals("a", Expressions.name(result)); - assertEquals(Alias.class, result.getClass()); - } - - public void testConstantFoldingBinaryComparison() { - assertEquals(FALSE, new ConstantFolding().rule(greaterThanOf(TWO, THREE)).canonical()); - assertEquals(FALSE, new ConstantFolding().rule(greaterThanOrEqualOf(TWO, THREE)).canonical()); - assertEquals(FALSE, new ConstantFolding().rule(equalsOf(TWO, THREE)).canonical()); - assertEquals(TRUE, new ConstantFolding().rule(notEqualsOf(TWO, THREE)).canonical()); - assertEquals(TRUE, new ConstantFolding().rule(lessThanOrEqualOf(TWO, THREE)).canonical()); - assertEquals(TRUE, new ConstantFolding().rule(lessThanOf(TWO, THREE)).canonical()); - } - - public void testConstantFoldingBinaryLogic() { - assertEquals(FALSE, new ConstantFolding().rule(new And(EMPTY, greaterThanOf(TWO, THREE), TRUE)).canonical()); - assertEquals(TRUE, new ConstantFolding().rule(new Or(EMPTY, greaterThanOrEqualOf(TWO, THREE), TRUE)).canonical()); - } - - public void testConstantFoldingBinaryLogic_WithNullHandling() { - assertEquals(Nullability.TRUE, new ConstantFolding().rule(new And(EMPTY, NULL, TRUE)).canonical().nullable()); - assertEquals(Nullability.TRUE, new ConstantFolding().rule(new And(EMPTY, TRUE, NULL)).canonical().nullable()); - assertEquals(FALSE, new ConstantFolding().rule(new And(EMPTY, NULL, FALSE)).canonical()); - assertEquals(FALSE, new ConstantFolding().rule(new And(EMPTY, FALSE, NULL)).canonical()); - assertEquals(Nullability.TRUE, new ConstantFolding().rule(new And(EMPTY, NULL, NULL)).canonical().nullable()); - - assertEquals(TRUE, new ConstantFolding().rule(new Or(EMPTY, NULL, TRUE)).canonical()); - assertEquals(TRUE, new ConstantFolding().rule(new Or(EMPTY, TRUE, NULL)).canonical()); - assertEquals(Nullability.TRUE, new ConstantFolding().rule(new Or(EMPTY, NULL, FALSE)).canonical().nullable()); - assertEquals(Nullability.TRUE, new ConstantFolding().rule(new Or(EMPTY, FALSE, NULL)).canonical().nullable()); - assertEquals(Nullability.TRUE, new ConstantFolding().rule(new Or(EMPTY, NULL, NULL)).canonical().nullable()); - } - - public void testConstantFoldingRange() { - assertEquals(true, new ConstantFolding().rule(rangeOf(FIVE, FIVE, true, new Literal(EMPTY, 10, DataType.INTEGER), false)).fold()); - assertEquals(false, new ConstantFolding().rule(rangeOf(FIVE, FIVE, false, new Literal(EMPTY, 10, DataType.INTEGER), false)).fold()); - } - - public void testConstantNot() { - assertEquals(FALSE, new ConstantFolding().rule(new Not(EMPTY, TRUE))); - assertEquals(TRUE, new ConstantFolding().rule(new Not(EMPTY, FALSE))); - } - - public void testConstantFoldingLikes() { - assertEquals(TRUE, new ConstantFolding().rule(new Like(EMPTY, of("test_emp"), new LikePattern("test%", (char) 0))).canonical()); - assertEquals(TRUE, new ConstantFolding().rule(new WildcardLike(EMPTY, of("test_emp"), new WildcardPattern("test*"))).canonical()); - assertEquals(TRUE, new ConstantFolding().rule(new RLike(EMPTY, of("test_emp"), new RLikePattern("test.emp"))).canonical()); - } - - public void testArithmeticFolding() { - assertEquals(10, foldOperator(new Add(EMPTY, new Literal(EMPTY, 7, DataType.INTEGER), THREE))); - assertEquals(4, foldOperator(new Sub(EMPTY, new Literal(EMPTY, 7, DataType.INTEGER), THREE))); - assertEquals(21, foldOperator(new Mul(EMPTY, new Literal(EMPTY, 7, DataType.INTEGER), THREE))); - assertEquals(2, foldOperator(new Div(EMPTY, new Literal(EMPTY, 7, DataType.INTEGER), THREE))); - assertEquals(1, foldOperator(new Mod(EMPTY, new Literal(EMPTY, 7, DataType.INTEGER), THREE))); - } - - private static Object foldOperator(BinaryOperator b) { - return ((Literal) new ConstantFolding().rule(b)).value(); - } - - // - // CombineDisjunction in Equals - // - public void testTwoEqualsWithOr() { - FieldAttribute fa = getFieldAttribute(); - - Or or = new Or(EMPTY, equalsOf(fa, ONE), equalsOf(fa, TWO)); - Expression e = new CombineDisjunctionsToIn().rule(or); - assertEquals(In.class, e.getClass()); - In in = (In) e; - assertEquals(fa, in.value()); - assertThat(in.list(), contains(ONE, TWO)); - } - - public void testTwoEqualsWithSameValue() { - FieldAttribute fa = getFieldAttribute(); - - Or or = new Or(EMPTY, equalsOf(fa, ONE), equalsOf(fa, ONE)); - Expression e = new CombineDisjunctionsToIn().rule(or); - assertEquals(Equals.class, e.getClass()); - Equals eq = (Equals) e; - assertEquals(fa, eq.left()); - assertEquals(ONE, eq.right()); - } - - public void testOneEqualsOneIn() { - FieldAttribute fa = getFieldAttribute(); - - Or or = new Or(EMPTY, equalsOf(fa, ONE), new In(EMPTY, fa, List.of(TWO))); - Expression e = new CombineDisjunctionsToIn().rule(or); - assertEquals(In.class, e.getClass()); - In in = (In) e; - assertEquals(fa, in.value()); - assertThat(in.list(), contains(ONE, TWO)); - } - - public void testOneEqualsOneInWithSameValue() { - FieldAttribute fa = getFieldAttribute(); - - Or or = new Or(EMPTY, equalsOf(fa, ONE), new In(EMPTY, fa, asList(ONE, TWO))); - Expression e = new CombineDisjunctionsToIn().rule(or); - assertEquals(In.class, e.getClass()); - In in = (In) e; - assertEquals(fa, in.value()); - assertThat(in.list(), contains(ONE, TWO)); - } - - public void testSingleValueInToEquals() { - FieldAttribute fa = getFieldAttribute(); - - Equals equals = equalsOf(fa, ONE); - Or or = new Or(EMPTY, equals, new In(EMPTY, fa, List.of(ONE))); - Expression e = new CombineDisjunctionsToIn().rule(or); - assertEquals(equals, e); - } - - public void testEqualsBehindAnd() { - FieldAttribute fa = getFieldAttribute(); - - And and = new And(EMPTY, equalsOf(fa, ONE), equalsOf(fa, TWO)); - Filter dummy = new Filter(EMPTY, relation(), and); - LogicalPlan transformed = new CombineDisjunctionsToIn().apply(dummy); - assertSame(dummy, transformed); - assertEquals(and, ((Filter) transformed).condition()); - } - - public void testTwoEqualsDifferentFields() { - FieldAttribute fieldOne = getFieldAttribute("ONE"); - FieldAttribute fieldTwo = getFieldAttribute("TWO"); - - Or or = new Or(EMPTY, equalsOf(fieldOne, ONE), equalsOf(fieldTwo, TWO)); - Expression e = new CombineDisjunctionsToIn().rule(or); - assertEquals(or, e); - } - - public void testMultipleIn() { - FieldAttribute fa = getFieldAttribute(); - - Or firstOr = new Or(EMPTY, new In(EMPTY, fa, List.of(ONE)), new In(EMPTY, fa, List.of(TWO))); - Or secondOr = new Or(EMPTY, firstOr, new In(EMPTY, fa, List.of(THREE))); - Expression e = new CombineDisjunctionsToIn().rule(secondOr); - assertEquals(In.class, e.getClass()); - In in = (In) e; - assertEquals(fa, in.value()); - assertThat(in.list(), contains(ONE, TWO, THREE)); - } - - public void testOrWithNonCombinableExpressions() { - FieldAttribute fa = getFieldAttribute(); - - Or firstOr = new Or(EMPTY, new In(EMPTY, fa, List.of(ONE)), lessThanOf(fa, TWO)); - Or secondOr = new Or(EMPTY, firstOr, new In(EMPTY, fa, List.of(THREE))); - Expression e = new CombineDisjunctionsToIn().rule(secondOr); - assertEquals(Or.class, e.getClass()); - Or or = (Or) e; - assertEquals(or.left(), firstOr.right()); - assertEquals(In.class, or.right().getClass()); - In in = (In) or.right(); - assertEquals(fa, in.value()); - assertThat(in.list(), contains(ONE, THREE)); - } - - // Test BooleanFunctionEqualsElimination - public void testBoolEqualsSimplificationOnExpressions() { - BooleanFunctionEqualsElimination s = new BooleanFunctionEqualsElimination(); - Expression exp = new GreaterThan(EMPTY, getFieldAttribute(), new Literal(EMPTY, 0, DataType.INTEGER), null); - - assertEquals(exp, s.rule(new Equals(EMPTY, exp, TRUE))); - // TODO: Replace use of QL Not with ESQL Not - assertEquals(new Not(EMPTY, exp), s.rule(new Equals(EMPTY, exp, FALSE))); - } - - public void testBoolEqualsSimplificationOnFields() { - BooleanFunctionEqualsElimination s = new BooleanFunctionEqualsElimination(); - - FieldAttribute field = getFieldAttribute(); - - List comparisons = asList( - new Equals(EMPTY, field, TRUE), - new Equals(EMPTY, field, FALSE), - notEqualsOf(field, TRUE), - notEqualsOf(field, FALSE), - new Equals(EMPTY, NULL, TRUE), - new Equals(EMPTY, NULL, FALSE), - notEqualsOf(NULL, TRUE), - notEqualsOf(NULL, FALSE) - ); - - for (BinaryComparison comparison : comparisons) { - assertEquals(comparison, s.rule(comparison)); - } - } - - // Test Propagate Equals - - // a == 1 AND a == 2 -> FALSE - public void testDualEqualsConjunction() { - FieldAttribute fa = getFieldAttribute(); - Equals eq1 = equalsOf(fa, ONE); - Equals eq2 = equalsOf(fa, TWO); - - PropagateEquals rule = new PropagateEquals(); - Expression exp = rule.rule(new And(EMPTY, eq1, eq2)); - assertEquals(FALSE, exp); - } - - // 1 < a < 10 AND a == 10 -> FALSE - public void testEliminateRangeByEqualsOutsideInterval() { - FieldAttribute fa = getFieldAttribute(); - Equals eq1 = equalsOf(fa, new Literal(EMPTY, 10, DataType.INTEGER)); - Range r = rangeOf(fa, ONE, false, new Literal(EMPTY, 10, DataType.INTEGER), false); - - PropagateEquals rule = new PropagateEquals(); - Expression exp = rule.rule(new And(EMPTY, eq1, r)); - assertEquals(FALSE, exp); - } - - // a != 3 AND a = 3 -> FALSE - public void testPropagateEquals_VarNeq3AndVarEq3() { - FieldAttribute fa = getFieldAttribute(); - NotEquals neq = notEqualsOf(fa, THREE); - Equals eq = equalsOf(fa, THREE); - - PropagateEquals rule = new PropagateEquals(); - Expression exp = rule.rule(new And(EMPTY, neq, eq)); - assertEquals(FALSE, exp); - } - - // a != 4 AND a = 3 -> a = 3 - public void testPropagateEquals_VarNeq4AndVarEq3() { - FieldAttribute fa = getFieldAttribute(); - NotEquals neq = notEqualsOf(fa, FOUR); - Equals eq = equalsOf(fa, THREE); - - PropagateEquals rule = new PropagateEquals(); - Expression exp = rule.rule(new And(EMPTY, neq, eq)); - assertEquals(Equals.class, exp.getClass()); - assertEquals(eq, exp); - } - - // a = 2 AND a < 2 -> FALSE - public void testPropagateEquals_VarEq2AndVarLt2() { - FieldAttribute fa = getFieldAttribute(); - Equals eq = equalsOf(fa, TWO); - LessThan lt = lessThanOf(fa, TWO); - - PropagateEquals rule = new PropagateEquals(); - Expression exp = rule.rule(new And(EMPTY, eq, lt)); - assertEquals(FALSE, exp); - } - - // a = 2 AND a <= 2 -> a = 2 - public void testPropagateEquals_VarEq2AndVarLte2() { - FieldAttribute fa = getFieldAttribute(); - Equals eq = equalsOf(fa, TWO); - LessThanOrEqual lt = lessThanOrEqualOf(fa, TWO); - - PropagateEquals rule = new PropagateEquals(); - Expression exp = rule.rule(new And(EMPTY, eq, lt)); - assertEquals(eq, exp); - } - - // a = 2 AND a <= 1 -> FALSE - public void testPropagateEquals_VarEq2AndVarLte1() { - FieldAttribute fa = getFieldAttribute(); - Equals eq = equalsOf(fa, TWO); - LessThanOrEqual lt = lessThanOrEqualOf(fa, ONE); - - PropagateEquals rule = new PropagateEquals(); - Expression exp = rule.rule(new And(EMPTY, eq, lt)); - assertEquals(FALSE, exp); - } - - // a = 2 AND a > 2 -> FALSE - public void testPropagateEquals_VarEq2AndVarGt2() { - FieldAttribute fa = getFieldAttribute(); - Equals eq = equalsOf(fa, TWO); - GreaterThan gt = greaterThanOf(fa, TWO); - - PropagateEquals rule = new PropagateEquals(); - Expression exp = rule.rule(new And(EMPTY, eq, gt)); - assertEquals(FALSE, exp); - } - - // a = 2 AND a >= 2 -> a = 2 - public void testPropagateEquals_VarEq2AndVarGte2() { - FieldAttribute fa = getFieldAttribute(); - Equals eq = equalsOf(fa, TWO); - GreaterThanOrEqual gte = greaterThanOrEqualOf(fa, TWO); - - PropagateEquals rule = new PropagateEquals(); - Expression exp = rule.rule(new And(EMPTY, eq, gte)); - assertEquals(eq, exp); - } - - // a = 2 AND a > 3 -> FALSE - public void testPropagateEquals_VarEq2AndVarLt3() { - FieldAttribute fa = getFieldAttribute(); - Equals eq = equalsOf(fa, TWO); - GreaterThan gt = greaterThanOf(fa, THREE); - - PropagateEquals rule = new PropagateEquals(); - Expression exp = rule.rule(new And(EMPTY, eq, gt)); - assertEquals(FALSE, exp); - } - - // a = 2 AND a < 3 AND a > 1 AND a != 4 -> a = 2 - public void testPropagateEquals_VarEq2AndVarLt3AndVarGt1AndVarNeq4() { - FieldAttribute fa = getFieldAttribute(); - Equals eq = equalsOf(fa, TWO); - LessThan lt = lessThanOf(fa, THREE); - GreaterThan gt = greaterThanOf(fa, ONE); - NotEquals neq = notEqualsOf(fa, FOUR); - - PropagateEquals rule = new PropagateEquals(); - Expression and = Predicates.combineAnd(asList(eq, lt, gt, neq)); - Expression exp = rule.rule((And) and); - assertEquals(eq, exp); - } - - // a = 2 AND 1 < a < 3 AND a > 0 AND a != 4 -> a = 2 - public void testPropagateEquals_VarEq2AndVarRangeGt1Lt3AndVarGt0AndVarNeq4() { - FieldAttribute fa = getFieldAttribute(); - Equals eq = equalsOf(fa, TWO); - Range range = rangeOf(fa, ONE, false, THREE, false); - GreaterThan gt = greaterThanOf(fa, new Literal(EMPTY, 0, DataType.INTEGER)); - NotEquals neq = notEqualsOf(fa, FOUR); - - PropagateEquals rule = new PropagateEquals(); - Expression and = Predicates.combineAnd(asList(eq, range, gt, neq)); - Expression exp = rule.rule((And) and); - assertEquals(eq, exp); - } - - // a = 2 OR a > 1 -> a > 1 - public void testPropagateEquals_VarEq2OrVarGt1() { - FieldAttribute fa = getFieldAttribute(); - Equals eq = equalsOf(fa, TWO); - GreaterThan gt = greaterThanOf(fa, ONE); - - PropagateEquals rule = new PropagateEquals(); - Expression exp = rule.rule(new Or(EMPTY, eq, gt)); - assertEquals(gt, exp); - } - - // a = 2 OR a > 2 -> a >= 2 - public void testPropagateEquals_VarEq2OrVarGte2() { - FieldAttribute fa = getFieldAttribute(); - Equals eq = equalsOf(fa, TWO); - GreaterThan gt = greaterThanOf(fa, TWO); - - PropagateEquals rule = new PropagateEquals(); - Expression exp = rule.rule(new Or(EMPTY, eq, gt)); - assertEquals(GreaterThanOrEqual.class, exp.getClass()); - GreaterThanOrEqual gte = (GreaterThanOrEqual) exp; - assertEquals(TWO, gte.right()); - } - - // a = 2 OR a < 3 -> a < 3 - public void testPropagateEquals_VarEq2OrVarLt3() { - FieldAttribute fa = getFieldAttribute(); - Equals eq = equalsOf(fa, TWO); - LessThan lt = lessThanOf(fa, THREE); - - PropagateEquals rule = new PropagateEquals(); - Expression exp = rule.rule(new Or(EMPTY, eq, lt)); - assertEquals(lt, exp); - } - - // a = 3 OR a < 3 -> a <= 3 - public void testPropagateEquals_VarEq3OrVarLt3() { - FieldAttribute fa = getFieldAttribute(); - Equals eq = equalsOf(fa, THREE); - LessThan lt = lessThanOf(fa, THREE); - - PropagateEquals rule = new PropagateEquals(); - Expression exp = rule.rule(new Or(EMPTY, eq, lt)); - assertEquals(LessThanOrEqual.class, exp.getClass()); - LessThanOrEqual lte = (LessThanOrEqual) exp; - assertEquals(THREE, lte.right()); - } - - // a = 2 OR 1 < a < 3 -> 1 < a < 3 - public void testPropagateEquals_VarEq2OrVarRangeGt1Lt3() { - FieldAttribute fa = getFieldAttribute(); - Equals eq = equalsOf(fa, TWO); - Range range = rangeOf(fa, ONE, false, THREE, false); - - PropagateEquals rule = new PropagateEquals(); - Expression exp = rule.rule(new Or(EMPTY, eq, range)); - assertEquals(range, exp); - } - - // a = 2 OR 2 < a < 3 -> 2 <= a < 3 - public void testPropagateEquals_VarEq2OrVarRangeGt2Lt3() { - FieldAttribute fa = getFieldAttribute(); - Equals eq = equalsOf(fa, TWO); - Range range = rangeOf(fa, TWO, false, THREE, false); - - PropagateEquals rule = new PropagateEquals(); - Expression exp = rule.rule(new Or(EMPTY, eq, range)); - assertEquals(Range.class, exp.getClass()); - Range r = (Range) exp; - assertEquals(TWO, r.lower()); - assertTrue(r.includeLower()); - assertEquals(THREE, r.upper()); - assertFalse(r.includeUpper()); - } - - // a = 3 OR 2 < a < 3 -> 2 < a <= 3 - public void testPropagateEquals_VarEq3OrVarRangeGt2Lt3() { - FieldAttribute fa = getFieldAttribute(); - Equals eq = equalsOf(fa, THREE); - Range range = rangeOf(fa, TWO, false, THREE, false); - - PropagateEquals rule = new PropagateEquals(); - Expression exp = rule.rule(new Or(EMPTY, eq, range)); - assertEquals(Range.class, exp.getClass()); - Range r = (Range) exp; - assertEquals(TWO, r.lower()); - assertFalse(r.includeLower()); - assertEquals(THREE, r.upper()); - assertTrue(r.includeUpper()); - } - - // a = 2 OR a != 2 -> TRUE - public void testPropagateEquals_VarEq2OrVarNeq2() { - FieldAttribute fa = getFieldAttribute(); - Equals eq = equalsOf(fa, TWO); - NotEquals neq = notEqualsOf(fa, TWO); - - PropagateEquals rule = new PropagateEquals(); - Expression exp = rule.rule(new Or(EMPTY, eq, neq)); - assertEquals(TRUE, exp); - } - - // a = 2 OR a != 5 -> a != 5 - public void testPropagateEquals_VarEq2OrVarNeq5() { - FieldAttribute fa = getFieldAttribute(); - Equals eq = equalsOf(fa, TWO); - NotEquals neq = notEqualsOf(fa, FIVE); - - PropagateEquals rule = new PropagateEquals(); - Expression exp = rule.rule(new Or(EMPTY, eq, neq)); - assertEquals(NotEquals.class, exp.getClass()); - NotEquals ne = (NotEquals) exp; - assertEquals(FIVE, ne.right()); - } - - // a = 2 OR 3 < a < 4 OR a > 2 OR a!= 2 -> TRUE - public void testPropagateEquals_VarEq2OrVarRangeGt3Lt4OrVarGt2OrVarNe2() { - FieldAttribute fa = getFieldAttribute(); - Equals eq = equalsOf(fa, TWO); - Range range = rangeOf(fa, THREE, false, FOUR, false); - GreaterThan gt = greaterThanOf(fa, TWO); - NotEquals neq = notEqualsOf(fa, TWO); - - PropagateEquals rule = new PropagateEquals(); - Expression exp = rule.rule((Or) Predicates.combineOr(asList(eq, range, neq, gt))); - assertEquals(TRUE, exp); - } - - // a == 1 AND a == 2 -> nop for date/time fields - public void testPropagateEquals_ignoreDateTimeFields() { - FieldAttribute fa = getFieldAttribute("a", DataType.DATETIME); - Equals eq1 = equalsOf(fa, ONE); - Equals eq2 = equalsOf(fa, TWO); - And and = new And(EMPTY, eq1, eq2); - - PropagateEquals rule = new PropagateEquals(); - Expression exp = rule.rule(and); - assertEquals(and, exp); - } - - // 1 <= a < 10 AND a == 1 -> a == 1 - public void testEliminateRangeByEqualsInInterval() { - FieldAttribute fa = getFieldAttribute(); - Equals eq1 = equalsOf(fa, ONE); - Range r = rangeOf(fa, ONE, true, new Literal(EMPTY, 10, DataType.INTEGER), false); - - PropagateEquals rule = new PropagateEquals(); - Expression exp = rule.rule(new And(EMPTY, eq1, r)); - assertEquals(eq1, exp); - } - // - // Null folding - - public void testNullFoldingIsNull() { - FoldNull foldNull = new FoldNull(); - assertEquals(true, foldNull.rule(new IsNull(EMPTY, NULL)).fold()); - assertEquals(false, foldNull.rule(new IsNull(EMPTY, TRUE)).fold()); - } - - public void testGenericNullableExpression() { - FoldNull rule = new FoldNull(); - // arithmetic - assertNullLiteral(rule.rule(new Add(EMPTY, getFieldAttribute(), NULL))); - // comparison - assertNullLiteral(rule.rule(greaterThanOf(getFieldAttribute(), NULL))); - // regex - assertNullLiteral(rule.rule(new RLike(EMPTY, NULL, new RLikePattern("123")))); - } - - public void testNullFoldingDoesNotApplyOnLogicalExpressions() { - org.elasticsearch.xpack.esql.core.optimizer.OptimizerRules.FoldNull rule = - new org.elasticsearch.xpack.esql.core.optimizer.OptimizerRules.FoldNull(); - - Or or = new Or(EMPTY, NULL, TRUE); - assertEquals(or, rule.rule(or)); - or = new Or(EMPTY, NULL, NULL); - assertEquals(or, rule.rule(or)); - - And and = new And(EMPTY, NULL, TRUE); - assertEquals(and, rule.rule(and)); - and = new And(EMPTY, NULL, NULL); - assertEquals(and, rule.rule(and)); - } - - // - // Propagate nullability (IS NULL / IS NOT NULL) - // - - // a IS NULL AND a IS NOT NULL => false - public void testIsNullAndNotNull() { - FieldAttribute fa = getFieldAttribute(); - - And and = new And(EMPTY, new IsNull(EMPTY, fa), new IsNotNull(EMPTY, fa)); - assertEquals(FALSE, new org.elasticsearch.xpack.esql.core.optimizer.OptimizerRules.PropagateNullable().rule(and)); - } - - // a IS NULL AND b IS NOT NULL AND c IS NULL AND d IS NOT NULL AND e IS NULL AND a IS NOT NULL => false - public void testIsNullAndNotNullMultiField() { - FieldAttribute fa = getFieldAttribute(); - - And andOne = new And(EMPTY, new IsNull(EMPTY, fa), new IsNotNull(EMPTY, getFieldAttribute())); - And andTwo = new And(EMPTY, new IsNull(EMPTY, getFieldAttribute()), new IsNotNull(EMPTY, getFieldAttribute())); - And andThree = new And(EMPTY, new IsNull(EMPTY, getFieldAttribute()), new IsNotNull(EMPTY, fa)); - - And and = new And(EMPTY, andOne, new And(EMPTY, andThree, andTwo)); - - assertEquals(FALSE, new org.elasticsearch.xpack.esql.core.optimizer.OptimizerRules.PropagateNullable().rule(and)); - } - - // a IS NULL AND a > 1 => a IS NULL AND false - public void testIsNullAndComparison() { - FieldAttribute fa = getFieldAttribute(); - IsNull isNull = new IsNull(EMPTY, fa); - - And and = new And(EMPTY, isNull, greaterThanOf(fa, ONE)); - assertEquals(new And(EMPTY, isNull, nullOf(BOOLEAN)), new PropagateNullable().rule(and)); - } - - // a IS NULL AND b < 1 AND c < 1 AND a < 1 => a IS NULL AND b < 1 AND c < 1 => a IS NULL AND b < 1 AND c < 1 - public void testIsNullAndMultipleComparison() { - FieldAttribute fa = getFieldAttribute(); - IsNull isNull = new IsNull(EMPTY, fa); - - And nestedAnd = new And(EMPTY, lessThanOf(getFieldAttribute("b"), ONE), lessThanOf(getFieldAttribute("c"), ONE)); - And and = new And(EMPTY, isNull, nestedAnd); - And top = new And(EMPTY, and, lessThanOf(fa, ONE)); - - Expression optimized = new PropagateNullable().rule(top); - Expression expected = new And(EMPTY, and, nullOf(BOOLEAN)); - assertEquals(Predicates.splitAnd(expected), Predicates.splitAnd(optimized)); - } - - // ((a+1)/2) > 1 AND a + 2 AND a IS NULL AND b < 3 => NULL AND NULL AND a IS NULL AND b < 3 - public void testIsNullAndDeeplyNestedExpression() { - FieldAttribute fa = getFieldAttribute(); - IsNull isNull = new IsNull(EMPTY, fa); - - Expression nullified = new And( - EMPTY, - greaterThanOf(new Div(EMPTY, new Add(EMPTY, fa, ONE), TWO), ONE), - greaterThanOf(new Add(EMPTY, fa, TWO), ONE) - ); - Expression kept = new And(EMPTY, isNull, lessThanOf(getFieldAttribute("b"), THREE)); - And and = new And(EMPTY, nullified, kept); - - Expression optimized = new PropagateNullable().rule(and); - Expression expected = new And(EMPTY, new And(EMPTY, nullOf(BOOLEAN), nullOf(BOOLEAN)), kept); - - assertEquals(Predicates.splitAnd(expected), Predicates.splitAnd(optimized)); - } - - // a IS NULL OR a IS NOT NULL => no change - // a IS NULL OR a > 1 => no change - public void testIsNullInDisjunction() { - FieldAttribute fa = getFieldAttribute(); - - Or or = new Or(EMPTY, new IsNull(EMPTY, fa), new IsNotNull(EMPTY, fa)); - Filter dummy = new Filter(EMPTY, relation(), or); - LogicalPlan transformed = new PropagateNullable().apply(dummy); - assertSame(dummy, transformed); - assertEquals(or, ((Filter) transformed).condition()); - - or = new Or(EMPTY, new IsNull(EMPTY, fa), greaterThanOf(fa, ONE)); - dummy = new Filter(EMPTY, relation(), or); - transformed = new PropagateNullable().apply(dummy); - assertSame(dummy, transformed); - assertEquals(or, ((Filter) transformed).condition()); - } - - // a + 1 AND (a IS NULL OR a > 3) => no change - public void testIsNullDisjunction() { - FieldAttribute fa = getFieldAttribute(); - IsNull isNull = new IsNull(EMPTY, fa); - - Or or = new Or(EMPTY, isNull, greaterThanOf(fa, THREE)); - And and = new And(EMPTY, new Add(EMPTY, fa, ONE), or); - - assertEquals(and, new PropagateNullable().rule(and)); - } - - // - // Like / Regex - // - public void testMatchAllLikeToExist() { - for (String s : asList("%", "%%", "%%%")) { - LikePattern pattern = new LikePattern(s, (char) 0); - FieldAttribute fa = getFieldAttribute(); - Like l = new Like(EMPTY, fa, pattern); - Expression e = new ReplaceRegexMatch().rule(l); - assertEquals(IsNotNull.class, e.getClass()); - IsNotNull inn = (IsNotNull) e; - assertEquals(fa, inn.field()); - } - } - - public void testMatchAllWildcardLikeToExist() { - for (String s : asList("*", "**", "***")) { - WildcardPattern pattern = new WildcardPattern(s); - FieldAttribute fa = getFieldAttribute(); - WildcardLike l = new WildcardLike(EMPTY, fa, pattern); - Expression e = new ReplaceRegexMatch().rule(l); - assertEquals(IsNotNull.class, e.getClass()); - IsNotNull inn = (IsNotNull) e; - assertEquals(fa, inn.field()); - } - } - - public void testMatchAllRLikeToExist() { - RLikePattern pattern = new RLikePattern(".*"); - FieldAttribute fa = getFieldAttribute(); - RLike l = new RLike(EMPTY, fa, pattern); - Expression e = new ReplaceRegexMatch().rule(l); - assertEquals(IsNotNull.class, e.getClass()); - IsNotNull inn = (IsNotNull) e; - assertEquals(fa, inn.field()); - } - - public void testExactMatchLike() { - for (String s : asList("ab", "ab0%", "ab0_c")) { - LikePattern pattern = new LikePattern(s, '0'); - FieldAttribute fa = getFieldAttribute(); - Like l = new Like(EMPTY, fa, pattern); - Expression e = new ReplaceRegexMatch().rule(l); - assertEquals(Equals.class, e.getClass()); - Equals eq = (Equals) e; - assertEquals(fa, eq.left()); - assertEquals(s.replace("0", StringUtils.EMPTY), eq.right().fold()); - } - } - - public void testExactMatchWildcardLike() { - String s = "ab"; - WildcardPattern pattern = new WildcardPattern(s); - FieldAttribute fa = getFieldAttribute(); - WildcardLike l = new WildcardLike(EMPTY, fa, pattern); - Expression e = new ReplaceRegexMatch().rule(l); - assertEquals(Equals.class, e.getClass()); - Equals eq = (Equals) e; - assertEquals(fa, eq.left()); - assertEquals(s, eq.right().fold()); - } - - public void testExactMatchRLike() { - RLikePattern pattern = new RLikePattern("abc"); - FieldAttribute fa = getFieldAttribute(); - RLike l = new RLike(EMPTY, fa, pattern); - Expression e = new ReplaceRegexMatch().rule(l); - assertEquals(Equals.class, e.getClass()); - Equals eq = (Equals) e; - assertEquals(fa, eq.left()); - assertEquals("abc", eq.right().fold()); - } - - private void assertNullLiteral(Expression expression) { - assertEquals(Literal.class, expression.getClass()); - assertNull(expression.fold()); - } - - private IsNotNull isNotNull(Expression field) { - return new IsNotNull(EMPTY, field); - } - - private IsNull isNull(Expression field) { - return new IsNull(EMPTY, field); - } - - private Literal nullOf(DataType dataType) { - return new Literal(Source.EMPTY, null, dataType); - } - // - // Logical simplifications - // - - public void testLiteralsOnTheRight() { - Alias a = new Alias(EMPTY, "a", new Literal(EMPTY, 10, INTEGER)); - Expression result = new LiteralsOnTheRight().rule(equalsOf(FIVE, a)); - assertTrue(result instanceof Equals); - Equals eq = (Equals) result; - assertEquals(a, eq.left()); - assertEquals(FIVE, eq.right()); - - // Note: Null Equals test removed here - } - - public void testBoolSimplifyOr() { - org.elasticsearch.xpack.esql.core.optimizer.OptimizerRules.BooleanSimplification simplification = - new org.elasticsearch.xpack.esql.core.optimizer.OptimizerRules.BooleanSimplification(); - - assertEquals(TRUE, simplification.rule(new Or(EMPTY, TRUE, TRUE))); - assertEquals(TRUE, simplification.rule(new Or(EMPTY, TRUE, DUMMY_EXPRESSION))); - assertEquals(TRUE, simplification.rule(new Or(EMPTY, DUMMY_EXPRESSION, TRUE))); - - assertEquals(FALSE, simplification.rule(new Or(EMPTY, FALSE, FALSE))); - assertEquals(DUMMY_EXPRESSION, simplification.rule(new Or(EMPTY, FALSE, DUMMY_EXPRESSION))); - assertEquals(DUMMY_EXPRESSION, simplification.rule(new Or(EMPTY, DUMMY_EXPRESSION, FALSE))); - } - - public void testBoolSimplifyAnd() { - org.elasticsearch.xpack.esql.core.optimizer.OptimizerRules.BooleanSimplification simplification = - new org.elasticsearch.xpack.esql.core.optimizer.OptimizerRules.BooleanSimplification(); - - assertEquals(TRUE, simplification.rule(new And(EMPTY, TRUE, TRUE))); - assertEquals(DUMMY_EXPRESSION, simplification.rule(new And(EMPTY, TRUE, DUMMY_EXPRESSION))); - assertEquals(DUMMY_EXPRESSION, simplification.rule(new And(EMPTY, DUMMY_EXPRESSION, TRUE))); - - assertEquals(FALSE, simplification.rule(new And(EMPTY, FALSE, FALSE))); - assertEquals(FALSE, simplification.rule(new And(EMPTY, FALSE, DUMMY_EXPRESSION))); - assertEquals(FALSE, simplification.rule(new And(EMPTY, DUMMY_EXPRESSION, FALSE))); - } - - public void testBoolCommonFactorExtraction() { - org.elasticsearch.xpack.esql.core.optimizer.OptimizerRules.BooleanSimplification simplification = - new org.elasticsearch.xpack.esql.core.optimizer.OptimizerRules.BooleanSimplification(); - - Expression a1 = new org.elasticsearch.xpack.esql.core.optimizer.OptimizerRulesTests.DummyBooleanExpression(EMPTY, 1); - Expression a2 = new org.elasticsearch.xpack.esql.core.optimizer.OptimizerRulesTests.DummyBooleanExpression(EMPTY, 1); - Expression b = new org.elasticsearch.xpack.esql.core.optimizer.OptimizerRulesTests.DummyBooleanExpression(EMPTY, 2); - Expression c = new org.elasticsearch.xpack.esql.core.optimizer.OptimizerRulesTests.DummyBooleanExpression(EMPTY, 3); - - Or actual = new Or(EMPTY, new And(EMPTY, a1, b), new And(EMPTY, a2, c)); - And expected = new And(EMPTY, a1, new Or(EMPTY, b, c)); - - assertEquals(expected, simplification.rule(actual)); - } -} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/PhysicalPlanOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/PhysicalPlanOptimizerTests.java index a418670e98eac..a99ce5d873b44 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/PhysicalPlanOptimizerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/PhysicalPlanOptimizerTests.java @@ -48,9 +48,6 @@ import org.elasticsearch.xpack.esql.core.expression.predicate.operator.comparison.BinaryComparison; import org.elasticsearch.xpack.esql.core.index.EsIndex; import org.elasticsearch.xpack.esql.core.index.IndexResolution; -import org.elasticsearch.xpack.esql.core.plan.logical.Filter; -import org.elasticsearch.xpack.esql.core.plan.logical.Limit; -import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.core.type.DataType; import org.elasticsearch.xpack.esql.core.type.EsField; import org.elasticsearch.xpack.esql.enrich.ResolvedEnrichPolicy; @@ -79,6 +76,9 @@ import org.elasticsearch.xpack.esql.plan.logical.Enrich; import org.elasticsearch.xpack.esql.plan.logical.EsRelation; import org.elasticsearch.xpack.esql.plan.logical.Eval; +import org.elasticsearch.xpack.esql.plan.logical.Filter; +import org.elasticsearch.xpack.esql.plan.logical.Limit; +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.plan.logical.Project; import org.elasticsearch.xpack.esql.plan.logical.TopN; import org.elasticsearch.xpack.esql.plan.logical.join.Join; diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/BooleanFunctionEqualsEliminationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/BooleanFunctionEqualsEliminationTests.java new file mode 100644 index 0000000000000..d5d274d0fc62f --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/BooleanFunctionEqualsEliminationTests.java @@ -0,0 +1,62 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.optimizer.rules; + +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FieldAttribute; +import org.elasticsearch.xpack.esql.core.expression.Literal; +import org.elasticsearch.xpack.esql.core.expression.predicate.logical.Not; +import org.elasticsearch.xpack.esql.core.expression.predicate.operator.comparison.BinaryComparison; +import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.Equals; +import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.GreaterThan; + +import java.util.List; + +import static java.util.Arrays.asList; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.getFieldAttribute; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.notEqualsOf; +import static org.elasticsearch.xpack.esql.core.expression.Literal.FALSE; +import static org.elasticsearch.xpack.esql.core.expression.Literal.NULL; +import static org.elasticsearch.xpack.esql.core.expression.Literal.TRUE; +import static org.elasticsearch.xpack.esql.core.tree.Source.EMPTY; + +public class BooleanFunctionEqualsEliminationTests extends ESTestCase { + + public void testBoolEqualsSimplificationOnExpressions() { + BooleanFunctionEqualsElimination s = new BooleanFunctionEqualsElimination(); + Expression exp = new GreaterThan(EMPTY, getFieldAttribute(), new Literal(EMPTY, 0, DataType.INTEGER), null); + + assertEquals(exp, s.rule(new Equals(EMPTY, exp, TRUE))); + // TODO: Replace use of QL Not with ESQL Not + assertEquals(new Not(EMPTY, exp), s.rule(new Equals(EMPTY, exp, FALSE))); + } + + public void testBoolEqualsSimplificationOnFields() { + BooleanFunctionEqualsElimination s = new BooleanFunctionEqualsElimination(); + + FieldAttribute field = getFieldAttribute(); + + List comparisons = asList( + new Equals(EMPTY, field, TRUE), + new Equals(EMPTY, field, FALSE), + notEqualsOf(field, TRUE), + notEqualsOf(field, FALSE), + new Equals(EMPTY, NULL, TRUE), + new Equals(EMPTY, NULL, FALSE), + notEqualsOf(NULL, TRUE), + notEqualsOf(NULL, FALSE) + ); + + for (BinaryComparison comparison : comparisons) { + assertEquals(comparison, s.rule(comparison)); + } + } + +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/BooleanSimplificationTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/BooleanSimplificationTests.java new file mode 100644 index 0000000000000..03cd5921a80e2 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/BooleanSimplificationTests.java @@ -0,0 +1,61 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.optimizer.rules; + +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.predicate.logical.And; +import org.elasticsearch.xpack.esql.core.expression.predicate.logical.Or; + +import static org.elasticsearch.xpack.esql.core.expression.Literal.FALSE; +import static org.elasticsearch.xpack.esql.core.expression.Literal.TRUE; +import static org.elasticsearch.xpack.esql.core.tree.Source.EMPTY; + +public class BooleanSimplificationTests extends ESTestCase { + private static final Expression DUMMY_EXPRESSION = + new org.elasticsearch.xpack.esql.core.optimizer.OptimizerRulesTests.DummyBooleanExpression(EMPTY, 0); + + public void testBoolSimplifyOr() { + OptimizerRules.BooleanSimplification simplification = new OptimizerRules.BooleanSimplification(); + + assertEquals(TRUE, simplification.rule(new Or(EMPTY, TRUE, TRUE))); + assertEquals(TRUE, simplification.rule(new Or(EMPTY, TRUE, DUMMY_EXPRESSION))); + assertEquals(TRUE, simplification.rule(new Or(EMPTY, DUMMY_EXPRESSION, TRUE))); + + assertEquals(FALSE, simplification.rule(new Or(EMPTY, FALSE, FALSE))); + assertEquals(DUMMY_EXPRESSION, simplification.rule(new Or(EMPTY, FALSE, DUMMY_EXPRESSION))); + assertEquals(DUMMY_EXPRESSION, simplification.rule(new Or(EMPTY, DUMMY_EXPRESSION, FALSE))); + } + + public void testBoolSimplifyAnd() { + OptimizerRules.BooleanSimplification simplification = new OptimizerRules.BooleanSimplification(); + + assertEquals(TRUE, simplification.rule(new And(EMPTY, TRUE, TRUE))); + assertEquals(DUMMY_EXPRESSION, simplification.rule(new And(EMPTY, TRUE, DUMMY_EXPRESSION))); + assertEquals(DUMMY_EXPRESSION, simplification.rule(new And(EMPTY, DUMMY_EXPRESSION, TRUE))); + + assertEquals(FALSE, simplification.rule(new And(EMPTY, FALSE, FALSE))); + assertEquals(FALSE, simplification.rule(new And(EMPTY, FALSE, DUMMY_EXPRESSION))); + assertEquals(FALSE, simplification.rule(new And(EMPTY, DUMMY_EXPRESSION, FALSE))); + } + + public void testBoolCommonFactorExtraction() { + OptimizerRules.BooleanSimplification simplification = new OptimizerRules.BooleanSimplification(); + + Expression a1 = new org.elasticsearch.xpack.esql.core.optimizer.OptimizerRulesTests.DummyBooleanExpression(EMPTY, 1); + Expression a2 = new org.elasticsearch.xpack.esql.core.optimizer.OptimizerRulesTests.DummyBooleanExpression(EMPTY, 1); + Expression b = new org.elasticsearch.xpack.esql.core.optimizer.OptimizerRulesTests.DummyBooleanExpression(EMPTY, 2); + Expression c = new org.elasticsearch.xpack.esql.core.optimizer.OptimizerRulesTests.DummyBooleanExpression(EMPTY, 3); + + Or actual = new Or(EMPTY, new And(EMPTY, a1, b), new And(EMPTY, a2, c)); + And expected = new And(EMPTY, a1, new Or(EMPTY, b, c)); + + assertEquals(expected, simplification.rule(actual)); + } + +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/CombineDisjunctionsToInTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/CombineDisjunctionsToInTests.java new file mode 100644 index 0000000000000..7bc2d69cb56e6 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/CombineDisjunctionsToInTests.java @@ -0,0 +1,132 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.optimizer.rules; + +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FieldAttribute; +import org.elasticsearch.xpack.esql.core.expression.predicate.logical.And; +import org.elasticsearch.xpack.esql.core.expression.predicate.logical.Or; +import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.Equals; +import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.In; +import org.elasticsearch.xpack.esql.plan.logical.Filter; +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; + +import java.util.List; + +import static java.util.Arrays.asList; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.ONE; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.THREE; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.TWO; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.equalsOf; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.getFieldAttribute; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.lessThanOf; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.relation; +import static org.elasticsearch.xpack.esql.core.tree.Source.EMPTY; +import static org.hamcrest.Matchers.contains; + +public class CombineDisjunctionsToInTests extends ESTestCase { + public void testTwoEqualsWithOr() { + FieldAttribute fa = getFieldAttribute(); + + Or or = new Or(EMPTY, equalsOf(fa, ONE), equalsOf(fa, TWO)); + Expression e = new CombineDisjunctionsToIn().rule(or); + assertEquals(In.class, e.getClass()); + In in = (In) e; + assertEquals(fa, in.value()); + assertThat(in.list(), contains(ONE, TWO)); + } + + public void testTwoEqualsWithSameValue() { + FieldAttribute fa = getFieldAttribute(); + + Or or = new Or(EMPTY, equalsOf(fa, ONE), equalsOf(fa, ONE)); + Expression e = new CombineDisjunctionsToIn().rule(or); + assertEquals(Equals.class, e.getClass()); + Equals eq = (Equals) e; + assertEquals(fa, eq.left()); + assertEquals(ONE, eq.right()); + } + + public void testOneEqualsOneIn() { + FieldAttribute fa = getFieldAttribute(); + + Or or = new Or(EMPTY, equalsOf(fa, ONE), new In(EMPTY, fa, List.of(TWO))); + Expression e = new CombineDisjunctionsToIn().rule(or); + assertEquals(In.class, e.getClass()); + In in = (In) e; + assertEquals(fa, in.value()); + assertThat(in.list(), contains(ONE, TWO)); + } + + public void testOneEqualsOneInWithSameValue() { + FieldAttribute fa = getFieldAttribute(); + + Or or = new Or(EMPTY, equalsOf(fa, ONE), new In(EMPTY, fa, asList(ONE, TWO))); + Expression e = new CombineDisjunctionsToIn().rule(or); + assertEquals(In.class, e.getClass()); + In in = (In) e; + assertEquals(fa, in.value()); + assertThat(in.list(), contains(ONE, TWO)); + } + + public void testSingleValueInToEquals() { + FieldAttribute fa = getFieldAttribute(); + + Equals equals = equalsOf(fa, ONE); + Or or = new Or(EMPTY, equals, new In(EMPTY, fa, List.of(ONE))); + Expression e = new CombineDisjunctionsToIn().rule(or); + assertEquals(equals, e); + } + + public void testEqualsBehindAnd() { + FieldAttribute fa = getFieldAttribute(); + + And and = new And(EMPTY, equalsOf(fa, ONE), equalsOf(fa, TWO)); + Filter dummy = new Filter(EMPTY, relation(), and); + LogicalPlan transformed = new CombineDisjunctionsToIn().apply(dummy); + assertSame(dummy, transformed); + assertEquals(and, ((Filter) transformed).condition()); + } + + public void testTwoEqualsDifferentFields() { + FieldAttribute fieldOne = getFieldAttribute("ONE"); + FieldAttribute fieldTwo = getFieldAttribute("TWO"); + + Or or = new Or(EMPTY, equalsOf(fieldOne, ONE), equalsOf(fieldTwo, TWO)); + Expression e = new CombineDisjunctionsToIn().rule(or); + assertEquals(or, e); + } + + public void testMultipleIn() { + FieldAttribute fa = getFieldAttribute(); + + Or firstOr = new Or(EMPTY, new In(EMPTY, fa, List.of(ONE)), new In(EMPTY, fa, List.of(TWO))); + Or secondOr = new Or(EMPTY, firstOr, new In(EMPTY, fa, List.of(THREE))); + Expression e = new CombineDisjunctionsToIn().rule(secondOr); + assertEquals(In.class, e.getClass()); + In in = (In) e; + assertEquals(fa, in.value()); + assertThat(in.list(), contains(ONE, TWO, THREE)); + } + + public void testOrWithNonCombinableExpressions() { + FieldAttribute fa = getFieldAttribute(); + + Or firstOr = new Or(EMPTY, new In(EMPTY, fa, List.of(ONE)), lessThanOf(fa, TWO)); + Or secondOr = new Or(EMPTY, firstOr, new In(EMPTY, fa, List.of(THREE))); + Expression e = new CombineDisjunctionsToIn().rule(secondOr); + assertEquals(Or.class, e.getClass()); + Or or = (Or) e; + assertEquals(or.left(), firstOr.right()); + assertEquals(In.class, or.right().getClass()); + In in = (In) or.right(); + assertEquals(fa, in.value()); + assertThat(in.list(), contains(ONE, THREE)); + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/ConstantFoldingTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/ConstantFoldingTests.java new file mode 100644 index 0000000000000..366116d33901f --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/ConstantFoldingTests.java @@ -0,0 +1,121 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.optimizer.rules; + +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.esql.core.expression.Alias; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.Expressions; +import org.elasticsearch.xpack.esql.core.expression.Literal; +import org.elasticsearch.xpack.esql.core.expression.Nullability; +import org.elasticsearch.xpack.esql.core.expression.predicate.BinaryOperator; +import org.elasticsearch.xpack.esql.core.expression.predicate.logical.And; +import org.elasticsearch.xpack.esql.core.expression.predicate.logical.Not; +import org.elasticsearch.xpack.esql.core.expression.predicate.logical.Or; +import org.elasticsearch.xpack.esql.core.expression.predicate.regex.Like; +import org.elasticsearch.xpack.esql.core.expression.predicate.regex.LikePattern; +import org.elasticsearch.xpack.esql.core.expression.predicate.regex.RLike; +import org.elasticsearch.xpack.esql.core.expression.predicate.regex.RLikePattern; +import org.elasticsearch.xpack.esql.core.expression.predicate.regex.WildcardLike; +import org.elasticsearch.xpack.esql.core.expression.predicate.regex.WildcardPattern; +import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Add; +import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Div; +import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Mod; +import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Mul; +import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Sub; + +import static org.elasticsearch.xpack.esql.EsqlTestUtils.FIVE; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.THREE; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.TWO; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.equalsOf; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.greaterThanOf; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.greaterThanOrEqualOf; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.lessThanOf; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.lessThanOrEqualOf; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.notEqualsOf; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.of; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.rangeOf; +import static org.elasticsearch.xpack.esql.core.expression.Literal.FALSE; +import static org.elasticsearch.xpack.esql.core.expression.Literal.NULL; +import static org.elasticsearch.xpack.esql.core.expression.Literal.TRUE; +import static org.elasticsearch.xpack.esql.core.tree.Source.EMPTY; + +public class ConstantFoldingTests extends ESTestCase { + + public void testConstantFolding() { + Expression exp = new Add(EMPTY, TWO, THREE); + + assertTrue(exp.foldable()); + Expression result = new ConstantFolding().rule(exp); + assertTrue(result instanceof Literal); + assertEquals(5, ((Literal) result).value()); + + // check now with an alias + result = new ConstantFolding().rule(new Alias(EMPTY, "a", exp)); + assertEquals("a", Expressions.name(result)); + assertEquals(Alias.class, result.getClass()); + } + + public void testConstantFoldingBinaryComparison() { + assertEquals(FALSE, new ConstantFolding().rule(greaterThanOf(TWO, THREE)).canonical()); + assertEquals(FALSE, new ConstantFolding().rule(greaterThanOrEqualOf(TWO, THREE)).canonical()); + assertEquals(FALSE, new ConstantFolding().rule(equalsOf(TWO, THREE)).canonical()); + assertEquals(TRUE, new ConstantFolding().rule(notEqualsOf(TWO, THREE)).canonical()); + assertEquals(TRUE, new ConstantFolding().rule(lessThanOrEqualOf(TWO, THREE)).canonical()); + assertEquals(TRUE, new ConstantFolding().rule(lessThanOf(TWO, THREE)).canonical()); + } + + public void testConstantFoldingBinaryLogic() { + assertEquals(FALSE, new ConstantFolding().rule(new And(EMPTY, greaterThanOf(TWO, THREE), TRUE)).canonical()); + assertEquals(TRUE, new ConstantFolding().rule(new Or(EMPTY, greaterThanOrEqualOf(TWO, THREE), TRUE)).canonical()); + } + + public void testConstantFoldingBinaryLogic_WithNullHandling() { + assertEquals(Nullability.TRUE, new ConstantFolding().rule(new And(EMPTY, NULL, TRUE)).canonical().nullable()); + assertEquals(Nullability.TRUE, new ConstantFolding().rule(new And(EMPTY, TRUE, NULL)).canonical().nullable()); + assertEquals(FALSE, new ConstantFolding().rule(new And(EMPTY, NULL, FALSE)).canonical()); + assertEquals(FALSE, new ConstantFolding().rule(new And(EMPTY, FALSE, NULL)).canonical()); + assertEquals(Nullability.TRUE, new ConstantFolding().rule(new And(EMPTY, NULL, NULL)).canonical().nullable()); + + assertEquals(TRUE, new ConstantFolding().rule(new Or(EMPTY, NULL, TRUE)).canonical()); + assertEquals(TRUE, new ConstantFolding().rule(new Or(EMPTY, TRUE, NULL)).canonical()); + assertEquals(Nullability.TRUE, new ConstantFolding().rule(new Or(EMPTY, NULL, FALSE)).canonical().nullable()); + assertEquals(Nullability.TRUE, new ConstantFolding().rule(new Or(EMPTY, FALSE, NULL)).canonical().nullable()); + assertEquals(Nullability.TRUE, new ConstantFolding().rule(new Or(EMPTY, NULL, NULL)).canonical().nullable()); + } + + public void testConstantFoldingRange() { + assertEquals(true, new ConstantFolding().rule(rangeOf(FIVE, FIVE, true, new Literal(EMPTY, 10, DataType.INTEGER), false)).fold()); + assertEquals(false, new ConstantFolding().rule(rangeOf(FIVE, FIVE, false, new Literal(EMPTY, 10, DataType.INTEGER), false)).fold()); + } + + public void testConstantNot() { + assertEquals(FALSE, new ConstantFolding().rule(new Not(EMPTY, TRUE))); + assertEquals(TRUE, new ConstantFolding().rule(new Not(EMPTY, FALSE))); + } + + public void testConstantFoldingLikes() { + assertEquals(TRUE, new ConstantFolding().rule(new Like(EMPTY, of("test_emp"), new LikePattern("test%", (char) 0))).canonical()); + assertEquals(TRUE, new ConstantFolding().rule(new WildcardLike(EMPTY, of("test_emp"), new WildcardPattern("test*"))).canonical()); + assertEquals(TRUE, new ConstantFolding().rule(new RLike(EMPTY, of("test_emp"), new RLikePattern("test.emp"))).canonical()); + } + + public void testArithmeticFolding() { + assertEquals(10, foldOperator(new Add(EMPTY, new Literal(EMPTY, 7, DataType.INTEGER), THREE))); + assertEquals(4, foldOperator(new Sub(EMPTY, new Literal(EMPTY, 7, DataType.INTEGER), THREE))); + assertEquals(21, foldOperator(new Mul(EMPTY, new Literal(EMPTY, 7, DataType.INTEGER), THREE))); + assertEquals(2, foldOperator(new Div(EMPTY, new Literal(EMPTY, 7, DataType.INTEGER), THREE))); + assertEquals(1, foldOperator(new Mod(EMPTY, new Literal(EMPTY, 7, DataType.INTEGER), THREE))); + } + + private static Object foldOperator(BinaryOperator b) { + return ((Literal) new ConstantFolding().rule(b)).value(); + } + +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/FoldNullTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/FoldNullTests.java new file mode 100644 index 0000000000000..db5d42f8bb810 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/FoldNullTests.java @@ -0,0 +1,63 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.optimizer.rules; + +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.Literal; +import org.elasticsearch.xpack.esql.core.expression.predicate.logical.And; +import org.elasticsearch.xpack.esql.core.expression.predicate.logical.Or; +import org.elasticsearch.xpack.esql.core.expression.predicate.nulls.IsNull; +import org.elasticsearch.xpack.esql.core.expression.predicate.regex.RLike; +import org.elasticsearch.xpack.esql.core.expression.predicate.regex.RLikePattern; +import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Add; + +import static org.elasticsearch.xpack.esql.EsqlTestUtils.getFieldAttribute; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.greaterThanOf; +import static org.elasticsearch.xpack.esql.core.expression.Literal.NULL; +import static org.elasticsearch.xpack.esql.core.expression.Literal.TRUE; +import static org.elasticsearch.xpack.esql.core.tree.Source.EMPTY; + +public class FoldNullTests extends ESTestCase { + + public void testNullFoldingIsNull() { + OptimizerRules.FoldNull foldNull = new OptimizerRules.FoldNull(); + assertEquals(true, foldNull.rule(new IsNull(EMPTY, NULL)).fold()); + assertEquals(false, foldNull.rule(new IsNull(EMPTY, TRUE)).fold()); + } + + public void testGenericNullableExpression() { + OptimizerRules.FoldNull rule = new OptimizerRules.FoldNull(); + // arithmetic + assertNullLiteral(rule.rule(new Add(EMPTY, getFieldAttribute(), NULL))); + // comparison + assertNullLiteral(rule.rule(greaterThanOf(getFieldAttribute(), NULL))); + // regex + assertNullLiteral(rule.rule(new RLike(EMPTY, NULL, new RLikePattern("123")))); + } + + public void testNullFoldingDoesNotApplyOnLogicalExpressions() { + OptimizerRules.FoldNull rule = new OptimizerRules.FoldNull(); + + Or or = new Or(EMPTY, NULL, TRUE); + assertEquals(or, rule.rule(or)); + or = new Or(EMPTY, NULL, NULL); + assertEquals(or, rule.rule(or)); + + And and = new And(EMPTY, NULL, TRUE); + assertEquals(and, rule.rule(and)); + and = new And(EMPTY, NULL, NULL); + assertEquals(and, rule.rule(and)); + } + + private void assertNullLiteral(Expression expression) { + assertEquals(Literal.class, expression.getClass()); + assertNull(expression.fold()); + } + +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/LiteralsOnTheRightTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/LiteralsOnTheRightTests.java new file mode 100644 index 0000000000000..a884080504db8 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/LiteralsOnTheRightTests.java @@ -0,0 +1,34 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.optimizer.rules; + +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.esql.core.expression.Alias; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.Literal; +import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.Equals; + +import static org.elasticsearch.xpack.esql.EsqlTestUtils.FIVE; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.equalsOf; +import static org.elasticsearch.xpack.esql.core.tree.Source.EMPTY; +import static org.elasticsearch.xpack.esql.core.type.DataType.INTEGER; + +public class LiteralsOnTheRightTests extends ESTestCase { + + public void testLiteralsOnTheRight() { + Alias a = new Alias(EMPTY, "a", new Literal(EMPTY, 10, INTEGER)); + Expression result = new LiteralsOnTheRight().rule(equalsOf(FIVE, a)); + assertTrue(result instanceof Equals); + Equals eq = (Equals) result; + assertEquals(a, eq.left()); + assertEquals(FIVE, eq.right()); + + // Note: Null Equals test removed here + } + +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/PropagateEqualsTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/PropagateEqualsTests.java new file mode 100644 index 0000000000000..99632fa127a3b --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/PropagateEqualsTests.java @@ -0,0 +1,335 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.optimizer.rules; + +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FieldAttribute; +import org.elasticsearch.xpack.esql.core.expression.Literal; +import org.elasticsearch.xpack.esql.core.expression.predicate.Predicates; +import org.elasticsearch.xpack.esql.core.expression.predicate.Range; +import org.elasticsearch.xpack.esql.core.expression.predicate.logical.And; +import org.elasticsearch.xpack.esql.core.expression.predicate.logical.Or; +import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.Equals; +import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.GreaterThan; +import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.GreaterThanOrEqual; +import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.LessThan; +import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.LessThanOrEqual; +import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.NotEquals; + +import static java.util.Arrays.asList; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.FIVE; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.FOUR; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.ONE; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.THREE; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.TWO; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.equalsOf; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.getFieldAttribute; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.greaterThanOf; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.greaterThanOrEqualOf; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.lessThanOf; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.lessThanOrEqualOf; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.notEqualsOf; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.rangeOf; +import static org.elasticsearch.xpack.esql.core.expression.Literal.FALSE; +import static org.elasticsearch.xpack.esql.core.expression.Literal.TRUE; +import static org.elasticsearch.xpack.esql.core.tree.Source.EMPTY; + +public class PropagateEqualsTests extends ESTestCase { + + // a == 1 AND a == 2 -> FALSE + public void testDualEqualsConjunction() { + FieldAttribute fa = getFieldAttribute(); + Equals eq1 = equalsOf(fa, ONE); + Equals eq2 = equalsOf(fa, TWO); + + PropagateEquals rule = new PropagateEquals(); + Expression exp = rule.rule(new And(EMPTY, eq1, eq2)); + assertEquals(FALSE, exp); + } + + // 1 < a < 10 AND a == 10 -> FALSE + public void testEliminateRangeByEqualsOutsideInterval() { + FieldAttribute fa = getFieldAttribute(); + Equals eq1 = equalsOf(fa, new Literal(EMPTY, 10, DataType.INTEGER)); + Range r = rangeOf(fa, ONE, false, new Literal(EMPTY, 10, DataType.INTEGER), false); + + PropagateEquals rule = new PropagateEquals(); + Expression exp = rule.rule(new And(EMPTY, eq1, r)); + assertEquals(FALSE, exp); + } + + // a != 3 AND a = 3 -> FALSE + public void testPropagateEquals_VarNeq3AndVarEq3() { + FieldAttribute fa = getFieldAttribute(); + NotEquals neq = notEqualsOf(fa, THREE); + Equals eq = equalsOf(fa, THREE); + + PropagateEquals rule = new PropagateEquals(); + Expression exp = rule.rule(new And(EMPTY, neq, eq)); + assertEquals(FALSE, exp); + } + + // a != 4 AND a = 3 -> a = 3 + public void testPropagateEquals_VarNeq4AndVarEq3() { + FieldAttribute fa = getFieldAttribute(); + NotEquals neq = notEqualsOf(fa, FOUR); + Equals eq = equalsOf(fa, THREE); + + PropagateEquals rule = new PropagateEquals(); + Expression exp = rule.rule(new And(EMPTY, neq, eq)); + assertEquals(Equals.class, exp.getClass()); + assertEquals(eq, exp); + } + + // a = 2 AND a < 2 -> FALSE + public void testPropagateEquals_VarEq2AndVarLt2() { + FieldAttribute fa = getFieldAttribute(); + Equals eq = equalsOf(fa, TWO); + LessThan lt = lessThanOf(fa, TWO); + + PropagateEquals rule = new PropagateEquals(); + Expression exp = rule.rule(new And(EMPTY, eq, lt)); + assertEquals(FALSE, exp); + } + + // a = 2 AND a <= 2 -> a = 2 + public void testPropagateEquals_VarEq2AndVarLte2() { + FieldAttribute fa = getFieldAttribute(); + Equals eq = equalsOf(fa, TWO); + LessThanOrEqual lt = lessThanOrEqualOf(fa, TWO); + + PropagateEquals rule = new PropagateEquals(); + Expression exp = rule.rule(new And(EMPTY, eq, lt)); + assertEquals(eq, exp); + } + + // a = 2 AND a <= 1 -> FALSE + public void testPropagateEquals_VarEq2AndVarLte1() { + FieldAttribute fa = getFieldAttribute(); + Equals eq = equalsOf(fa, TWO); + LessThanOrEqual lt = lessThanOrEqualOf(fa, ONE); + + PropagateEquals rule = new PropagateEquals(); + Expression exp = rule.rule(new And(EMPTY, eq, lt)); + assertEquals(FALSE, exp); + } + + // a = 2 AND a > 2 -> FALSE + public void testPropagateEquals_VarEq2AndVarGt2() { + FieldAttribute fa = getFieldAttribute(); + Equals eq = equalsOf(fa, TWO); + GreaterThan gt = greaterThanOf(fa, TWO); + + PropagateEquals rule = new PropagateEquals(); + Expression exp = rule.rule(new And(EMPTY, eq, gt)); + assertEquals(FALSE, exp); + } + + // a = 2 AND a >= 2 -> a = 2 + public void testPropagateEquals_VarEq2AndVarGte2() { + FieldAttribute fa = getFieldAttribute(); + Equals eq = equalsOf(fa, TWO); + GreaterThanOrEqual gte = greaterThanOrEqualOf(fa, TWO); + + PropagateEquals rule = new PropagateEquals(); + Expression exp = rule.rule(new And(EMPTY, eq, gte)); + assertEquals(eq, exp); + } + + // a = 2 AND a > 3 -> FALSE + public void testPropagateEquals_VarEq2AndVarLt3() { + FieldAttribute fa = getFieldAttribute(); + Equals eq = equalsOf(fa, TWO); + GreaterThan gt = greaterThanOf(fa, THREE); + + PropagateEquals rule = new PropagateEquals(); + Expression exp = rule.rule(new And(EMPTY, eq, gt)); + assertEquals(FALSE, exp); + } + + // a = 2 AND a < 3 AND a > 1 AND a != 4 -> a = 2 + public void testPropagateEquals_VarEq2AndVarLt3AndVarGt1AndVarNeq4() { + FieldAttribute fa = getFieldAttribute(); + Equals eq = equalsOf(fa, TWO); + LessThan lt = lessThanOf(fa, THREE); + GreaterThan gt = greaterThanOf(fa, ONE); + NotEquals neq = notEqualsOf(fa, FOUR); + + PropagateEquals rule = new PropagateEquals(); + Expression and = Predicates.combineAnd(asList(eq, lt, gt, neq)); + Expression exp = rule.rule((And) and); + assertEquals(eq, exp); + } + + // a = 2 AND 1 < a < 3 AND a > 0 AND a != 4 -> a = 2 + public void testPropagateEquals_VarEq2AndVarRangeGt1Lt3AndVarGt0AndVarNeq4() { + FieldAttribute fa = getFieldAttribute(); + Equals eq = equalsOf(fa, TWO); + Range range = rangeOf(fa, ONE, false, THREE, false); + GreaterThan gt = greaterThanOf(fa, new Literal(EMPTY, 0, DataType.INTEGER)); + NotEquals neq = notEqualsOf(fa, FOUR); + + PropagateEquals rule = new PropagateEquals(); + Expression and = Predicates.combineAnd(asList(eq, range, gt, neq)); + Expression exp = rule.rule((And) and); + assertEquals(eq, exp); + } + + // a = 2 OR a > 1 -> a > 1 + public void testPropagateEquals_VarEq2OrVarGt1() { + FieldAttribute fa = getFieldAttribute(); + Equals eq = equalsOf(fa, TWO); + GreaterThan gt = greaterThanOf(fa, ONE); + + PropagateEquals rule = new PropagateEquals(); + Expression exp = rule.rule(new Or(EMPTY, eq, gt)); + assertEquals(gt, exp); + } + + // a = 2 OR a > 2 -> a >= 2 + public void testPropagateEquals_VarEq2OrVarGte2() { + FieldAttribute fa = getFieldAttribute(); + Equals eq = equalsOf(fa, TWO); + GreaterThan gt = greaterThanOf(fa, TWO); + + PropagateEquals rule = new PropagateEquals(); + Expression exp = rule.rule(new Or(EMPTY, eq, gt)); + assertEquals(GreaterThanOrEqual.class, exp.getClass()); + GreaterThanOrEqual gte = (GreaterThanOrEqual) exp; + assertEquals(TWO, gte.right()); + } + + // a = 2 OR a < 3 -> a < 3 + public void testPropagateEquals_VarEq2OrVarLt3() { + FieldAttribute fa = getFieldAttribute(); + Equals eq = equalsOf(fa, TWO); + LessThan lt = lessThanOf(fa, THREE); + + PropagateEquals rule = new PropagateEquals(); + Expression exp = rule.rule(new Or(EMPTY, eq, lt)); + assertEquals(lt, exp); + } + + // a = 3 OR a < 3 -> a <= 3 + public void testPropagateEquals_VarEq3OrVarLt3() { + FieldAttribute fa = getFieldAttribute(); + Equals eq = equalsOf(fa, THREE); + LessThan lt = lessThanOf(fa, THREE); + + PropagateEquals rule = new PropagateEquals(); + Expression exp = rule.rule(new Or(EMPTY, eq, lt)); + assertEquals(LessThanOrEqual.class, exp.getClass()); + LessThanOrEqual lte = (LessThanOrEqual) exp; + assertEquals(THREE, lte.right()); + } + + // a = 2 OR 1 < a < 3 -> 1 < a < 3 + public void testPropagateEquals_VarEq2OrVarRangeGt1Lt3() { + FieldAttribute fa = getFieldAttribute(); + Equals eq = equalsOf(fa, TWO); + Range range = rangeOf(fa, ONE, false, THREE, false); + + PropagateEquals rule = new PropagateEquals(); + Expression exp = rule.rule(new Or(EMPTY, eq, range)); + assertEquals(range, exp); + } + + // a = 2 OR 2 < a < 3 -> 2 <= a < 3 + public void testPropagateEquals_VarEq2OrVarRangeGt2Lt3() { + FieldAttribute fa = getFieldAttribute(); + Equals eq = equalsOf(fa, TWO); + Range range = rangeOf(fa, TWO, false, THREE, false); + + PropagateEquals rule = new PropagateEquals(); + Expression exp = rule.rule(new Or(EMPTY, eq, range)); + assertEquals(Range.class, exp.getClass()); + Range r = (Range) exp; + assertEquals(TWO, r.lower()); + assertTrue(r.includeLower()); + assertEquals(THREE, r.upper()); + assertFalse(r.includeUpper()); + } + + // a = 3 OR 2 < a < 3 -> 2 < a <= 3 + public void testPropagateEquals_VarEq3OrVarRangeGt2Lt3() { + FieldAttribute fa = getFieldAttribute(); + Equals eq = equalsOf(fa, THREE); + Range range = rangeOf(fa, TWO, false, THREE, false); + + PropagateEquals rule = new PropagateEquals(); + Expression exp = rule.rule(new Or(EMPTY, eq, range)); + assertEquals(Range.class, exp.getClass()); + Range r = (Range) exp; + assertEquals(TWO, r.lower()); + assertFalse(r.includeLower()); + assertEquals(THREE, r.upper()); + assertTrue(r.includeUpper()); + } + + // a = 2 OR a != 2 -> TRUE + public void testPropagateEquals_VarEq2OrVarNeq2() { + FieldAttribute fa = getFieldAttribute(); + Equals eq = equalsOf(fa, TWO); + NotEquals neq = notEqualsOf(fa, TWO); + + PropagateEquals rule = new PropagateEquals(); + Expression exp = rule.rule(new Or(EMPTY, eq, neq)); + assertEquals(TRUE, exp); + } + + // a = 2 OR a != 5 -> a != 5 + public void testPropagateEquals_VarEq2OrVarNeq5() { + FieldAttribute fa = getFieldAttribute(); + Equals eq = equalsOf(fa, TWO); + NotEquals neq = notEqualsOf(fa, FIVE); + + PropagateEquals rule = new PropagateEquals(); + Expression exp = rule.rule(new Or(EMPTY, eq, neq)); + assertEquals(NotEquals.class, exp.getClass()); + NotEquals ne = (NotEquals) exp; + assertEquals(FIVE, ne.right()); + } + + // a = 2 OR 3 < a < 4 OR a > 2 OR a!= 2 -> TRUE + public void testPropagateEquals_VarEq2OrVarRangeGt3Lt4OrVarGt2OrVarNe2() { + FieldAttribute fa = getFieldAttribute(); + Equals eq = equalsOf(fa, TWO); + Range range = rangeOf(fa, THREE, false, FOUR, false); + GreaterThan gt = greaterThanOf(fa, TWO); + NotEquals neq = notEqualsOf(fa, TWO); + + PropagateEquals rule = new PropagateEquals(); + Expression exp = rule.rule((Or) Predicates.combineOr(asList(eq, range, neq, gt))); + assertEquals(TRUE, exp); + } + + // a == 1 AND a == 2 -> nop for date/time fields + public void testPropagateEquals_ignoreDateTimeFields() { + FieldAttribute fa = getFieldAttribute("a", DataType.DATETIME); + Equals eq1 = equalsOf(fa, ONE); + Equals eq2 = equalsOf(fa, TWO); + And and = new And(EMPTY, eq1, eq2); + + PropagateEquals rule = new PropagateEquals(); + Expression exp = rule.rule(and); + assertEquals(and, exp); + } + + // 1 <= a < 10 AND a == 1 -> a == 1 + public void testEliminateRangeByEqualsInInterval() { + FieldAttribute fa = getFieldAttribute(); + Equals eq1 = equalsOf(fa, ONE); + Range r = rangeOf(fa, ONE, true, new Literal(EMPTY, 10, DataType.INTEGER), false); + + PropagateEquals rule = new PropagateEquals(); + Expression exp = rule.rule(new And(EMPTY, eq1, r)); + assertEquals(eq1, exp); + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/PropagateNullableTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/PropagateNullableTests.java new file mode 100644 index 0000000000000..23c0886f1a7d3 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/PropagateNullableTests.java @@ -0,0 +1,134 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.optimizer.rules; + +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FieldAttribute; +import org.elasticsearch.xpack.esql.core.expression.Literal; +import org.elasticsearch.xpack.esql.core.expression.predicate.Predicates; +import org.elasticsearch.xpack.esql.core.expression.predicate.logical.And; +import org.elasticsearch.xpack.esql.core.expression.predicate.logical.Or; +import org.elasticsearch.xpack.esql.core.expression.predicate.nulls.IsNotNull; +import org.elasticsearch.xpack.esql.core.expression.predicate.nulls.IsNull; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Add; +import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Div; +import org.elasticsearch.xpack.esql.plan.logical.Filter; +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; + +import static org.elasticsearch.xpack.esql.EsqlTestUtils.ONE; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.THREE; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.TWO; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.getFieldAttribute; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.greaterThanOf; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.lessThanOf; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.relation; +import static org.elasticsearch.xpack.esql.core.expression.Literal.FALSE; +import static org.elasticsearch.xpack.esql.core.tree.Source.EMPTY; +import static org.elasticsearch.xpack.esql.core.type.DataType.BOOLEAN; + +public class PropagateNullableTests extends ESTestCase { + private Literal nullOf(DataType dataType) { + return new Literal(Source.EMPTY, null, dataType); + } + + // a IS NULL AND a IS NOT NULL => false + public void testIsNullAndNotNull() { + FieldAttribute fa = getFieldAttribute(); + + And and = new And(EMPTY, new IsNull(EMPTY, fa), new IsNotNull(EMPTY, fa)); + assertEquals(FALSE, new OptimizerRules.PropagateNullable().rule(and)); + } + + // a IS NULL AND b IS NOT NULL AND c IS NULL AND d IS NOT NULL AND e IS NULL AND a IS NOT NULL => false + public void testIsNullAndNotNullMultiField() { + FieldAttribute fa = getFieldAttribute(); + + And andOne = new And(EMPTY, new IsNull(EMPTY, fa), new IsNotNull(EMPTY, getFieldAttribute())); + And andTwo = new And(EMPTY, new IsNull(EMPTY, getFieldAttribute()), new IsNotNull(EMPTY, getFieldAttribute())); + And andThree = new And(EMPTY, new IsNull(EMPTY, getFieldAttribute()), new IsNotNull(EMPTY, fa)); + + And and = new And(EMPTY, andOne, new And(EMPTY, andThree, andTwo)); + + assertEquals(FALSE, new OptimizerRules.PropagateNullable().rule(and)); + } + + // a IS NULL AND a > 1 => a IS NULL AND false + public void testIsNullAndComparison() { + FieldAttribute fa = getFieldAttribute(); + IsNull isNull = new IsNull(EMPTY, fa); + + And and = new And(EMPTY, isNull, greaterThanOf(fa, ONE)); + assertEquals(new And(EMPTY, isNull, nullOf(BOOLEAN)), new OptimizerRules.PropagateNullable().rule(and)); + } + + // a IS NULL AND b < 1 AND c < 1 AND a < 1 => a IS NULL AND b < 1 AND c < 1 => a IS NULL AND b < 1 AND c < 1 + public void testIsNullAndMultipleComparison() { + FieldAttribute fa = getFieldAttribute(); + IsNull isNull = new IsNull(EMPTY, fa); + + And nestedAnd = new And(EMPTY, lessThanOf(getFieldAttribute("b"), ONE), lessThanOf(getFieldAttribute("c"), ONE)); + And and = new And(EMPTY, isNull, nestedAnd); + And top = new And(EMPTY, and, lessThanOf(fa, ONE)); + + Expression optimized = new OptimizerRules.PropagateNullable().rule(top); + Expression expected = new And(EMPTY, and, nullOf(BOOLEAN)); + assertEquals(Predicates.splitAnd(expected), Predicates.splitAnd(optimized)); + } + + // ((a+1)/2) > 1 AND a + 2 AND a IS NULL AND b < 3 => NULL AND NULL AND a IS NULL AND b < 3 + public void testIsNullAndDeeplyNestedExpression() { + FieldAttribute fa = getFieldAttribute(); + IsNull isNull = new IsNull(EMPTY, fa); + + Expression nullified = new And( + EMPTY, + greaterThanOf(new Div(EMPTY, new Add(EMPTY, fa, ONE), TWO), ONE), + greaterThanOf(new Add(EMPTY, fa, TWO), ONE) + ); + Expression kept = new And(EMPTY, isNull, lessThanOf(getFieldAttribute("b"), THREE)); + And and = new And(EMPTY, nullified, kept); + + Expression optimized = new OptimizerRules.PropagateNullable().rule(and); + Expression expected = new And(EMPTY, new And(EMPTY, nullOf(BOOLEAN), nullOf(BOOLEAN)), kept); + + assertEquals(Predicates.splitAnd(expected), Predicates.splitAnd(optimized)); + } + + // a IS NULL OR a IS NOT NULL => no change + // a IS NULL OR a > 1 => no change + public void testIsNullInDisjunction() { + FieldAttribute fa = getFieldAttribute(); + + Or or = new Or(EMPTY, new IsNull(EMPTY, fa), new IsNotNull(EMPTY, fa)); + Filter dummy = new Filter(EMPTY, relation(), or); + LogicalPlan transformed = new OptimizerRules.PropagateNullable().apply(dummy); + assertSame(dummy, transformed); + assertEquals(or, ((Filter) transformed).condition()); + + or = new Or(EMPTY, new IsNull(EMPTY, fa), greaterThanOf(fa, ONE)); + dummy = new Filter(EMPTY, relation(), or); + transformed = new OptimizerRules.PropagateNullable().apply(dummy); + assertSame(dummy, transformed); + assertEquals(or, ((Filter) transformed).condition()); + } + + // a + 1 AND (a IS NULL OR a > 3) => no change + public void testIsNullDisjunction() { + FieldAttribute fa = getFieldAttribute(); + IsNull isNull = new IsNull(EMPTY, fa); + + Or or = new Or(EMPTY, isNull, greaterThanOf(fa, THREE)); + And and = new And(EMPTY, new Add(EMPTY, fa, ONE), or); + + assertEquals(and, new OptimizerRules.PropagateNullable().rule(and)); + } + +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/ReplaceRegexMatchTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/ReplaceRegexMatchTests.java new file mode 100644 index 0000000000000..62b13e6c9cc03 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/rules/ReplaceRegexMatchTests.java @@ -0,0 +1,99 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.optimizer.rules; + +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.FieldAttribute; +import org.elasticsearch.xpack.esql.core.expression.predicate.nulls.IsNotNull; +import org.elasticsearch.xpack.esql.core.expression.predicate.regex.Like; +import org.elasticsearch.xpack.esql.core.expression.predicate.regex.LikePattern; +import org.elasticsearch.xpack.esql.core.expression.predicate.regex.RLike; +import org.elasticsearch.xpack.esql.core.expression.predicate.regex.RLikePattern; +import org.elasticsearch.xpack.esql.core.expression.predicate.regex.WildcardLike; +import org.elasticsearch.xpack.esql.core.expression.predicate.regex.WildcardPattern; +import org.elasticsearch.xpack.esql.core.util.StringUtils; +import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.Equals; + +import static java.util.Arrays.asList; +import static org.elasticsearch.xpack.esql.EsqlTestUtils.getFieldAttribute; +import static org.elasticsearch.xpack.esql.core.tree.Source.EMPTY; + +public class ReplaceRegexMatchTests extends ESTestCase { + + public void testMatchAllLikeToExist() { + for (String s : asList("%", "%%", "%%%")) { + LikePattern pattern = new LikePattern(s, (char) 0); + FieldAttribute fa = getFieldAttribute(); + Like l = new Like(EMPTY, fa, pattern); + Expression e = new ReplaceRegexMatch().rule(l); + assertEquals(IsNotNull.class, e.getClass()); + IsNotNull inn = (IsNotNull) e; + assertEquals(fa, inn.field()); + } + } + + public void testMatchAllWildcardLikeToExist() { + for (String s : asList("*", "**", "***")) { + WildcardPattern pattern = new WildcardPattern(s); + FieldAttribute fa = getFieldAttribute(); + WildcardLike l = new WildcardLike(EMPTY, fa, pattern); + Expression e = new ReplaceRegexMatch().rule(l); + assertEquals(IsNotNull.class, e.getClass()); + IsNotNull inn = (IsNotNull) e; + assertEquals(fa, inn.field()); + } + } + + public void testMatchAllRLikeToExist() { + RLikePattern pattern = new RLikePattern(".*"); + FieldAttribute fa = getFieldAttribute(); + RLike l = new RLike(EMPTY, fa, pattern); + Expression e = new ReplaceRegexMatch().rule(l); + assertEquals(IsNotNull.class, e.getClass()); + IsNotNull inn = (IsNotNull) e; + assertEquals(fa, inn.field()); + } + + public void testExactMatchLike() { + for (String s : asList("ab", "ab0%", "ab0_c")) { + LikePattern pattern = new LikePattern(s, '0'); + FieldAttribute fa = getFieldAttribute(); + Like l = new Like(EMPTY, fa, pattern); + Expression e = new ReplaceRegexMatch().rule(l); + assertEquals(Equals.class, e.getClass()); + Equals eq = (Equals) e; + assertEquals(fa, eq.left()); + assertEquals(s.replace("0", StringUtils.EMPTY), eq.right().fold()); + } + } + + public void testExactMatchWildcardLike() { + String s = "ab"; + WildcardPattern pattern = new WildcardPattern(s); + FieldAttribute fa = getFieldAttribute(); + WildcardLike l = new WildcardLike(EMPTY, fa, pattern); + Expression e = new ReplaceRegexMatch().rule(l); + assertEquals(Equals.class, e.getClass()); + Equals eq = (Equals) e; + assertEquals(fa, eq.left()); + assertEquals(s, eq.right().fold()); + } + + public void testExactMatchRLike() { + RLikePattern pattern = new RLikePattern("abc"); + FieldAttribute fa = getFieldAttribute(); + RLike l = new RLike(EMPTY, fa, pattern); + Expression e = new ReplaceRegexMatch().rule(l); + assertEquals(Equals.class, e.getClass()); + Equals eq = (Equals) e; + assertEquals(fa, eq.left()); + assertEquals("abc", eq.right().fold()); + } + +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/parser/AbstractStatementParserTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/parser/AbstractStatementParserTests.java index 545f3efe8ca79..d575ba1fcb55a 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/parser/AbstractStatementParserTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/parser/AbstractStatementParserTests.java @@ -12,8 +12,8 @@ import org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.expression.ReferenceAttribute; import org.elasticsearch.xpack.esql.core.expression.UnresolvedAttribute; -import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; import java.math.BigInteger; import java.util.ArrayList; diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/parser/ExpressionTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/parser/ExpressionTests.java index ac89298ffcfbb..80a2d49d0d94a 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/parser/ExpressionTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/parser/ExpressionTests.java @@ -16,8 +16,6 @@ import org.elasticsearch.xpack.esql.core.expression.predicate.logical.And; import org.elasticsearch.xpack.esql.core.expression.predicate.logical.Not; import org.elasticsearch.xpack.esql.core.expression.predicate.logical.Or; -import org.elasticsearch.xpack.esql.core.plan.logical.Filter; -import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.core.type.DataType; import org.elasticsearch.xpack.esql.expression.UnresolvedNamePattern; import org.elasticsearch.xpack.esql.expression.function.UnresolvedFunction; @@ -31,6 +29,8 @@ import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.GreaterThanOrEqual; import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.LessThanOrEqual; import org.elasticsearch.xpack.esql.plan.logical.Drop; +import org.elasticsearch.xpack.esql.plan.logical.Filter; +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.plan.logical.Project; import org.elasticsearch.xpack.esql.plan.logical.Rename; diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/parser/StatementParserTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/parser/StatementParserTests.java index 2e2ca4feafa41..2f76cb2049820 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/parser/StatementParserTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/parser/StatementParserTests.java @@ -21,10 +21,6 @@ import org.elasticsearch.xpack.esql.core.expression.predicate.logical.Not; import org.elasticsearch.xpack.esql.core.expression.predicate.operator.comparison.BinaryComparison; import org.elasticsearch.xpack.esql.core.plan.TableIdentifier; -import org.elasticsearch.xpack.esql.core.plan.logical.Filter; -import org.elasticsearch.xpack.esql.core.plan.logical.Limit; -import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan; -import org.elasticsearch.xpack.esql.core.plan.logical.OrderBy; import org.elasticsearch.xpack.esql.expression.function.UnresolvedFunction; import org.elasticsearch.xpack.esql.expression.function.scalar.string.RLike; import org.elasticsearch.xpack.esql.expression.function.scalar.string.WildcardLike; @@ -41,10 +37,14 @@ import org.elasticsearch.xpack.esql.plan.logical.EsqlUnresolvedRelation; import org.elasticsearch.xpack.esql.plan.logical.Eval; import org.elasticsearch.xpack.esql.plan.logical.Explain; +import org.elasticsearch.xpack.esql.plan.logical.Filter; import org.elasticsearch.xpack.esql.plan.logical.Grok; import org.elasticsearch.xpack.esql.plan.logical.InlineStats; +import org.elasticsearch.xpack.esql.plan.logical.Limit; +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.plan.logical.Lookup; import org.elasticsearch.xpack.esql.plan.logical.MvExpand; +import org.elasticsearch.xpack.esql.plan.logical.OrderBy; import org.elasticsearch.xpack.esql.plan.logical.Project; import org.elasticsearch.xpack.esql.plan.logical.Row; @@ -758,15 +758,27 @@ public void testDissectPattern() { public void testGrokPattern() { LogicalPlan cmd = processingCommand("grok a \"%{WORD:foo}\""); assertEquals(Grok.class, cmd.getClass()); - Grok dissect = (Grok) cmd; - assertEquals("%{WORD:foo}", dissect.parser().pattern()); - assertEquals(List.of(referenceAttribute("foo", KEYWORD)), dissect.extractedFields()); + Grok grok = (Grok) cmd; + assertEquals("%{WORD:foo}", grok.parser().pattern()); + assertEquals(List.of(referenceAttribute("foo", KEYWORD)), grok.extractedFields()); ParsingException pe = expectThrows(ParsingException.class, () -> statement("row a = \"foo bar\" | grok a \"%{_invalid_:x}\"")); assertThat( pe.getMessage(), containsString("Invalid pattern [%{_invalid_:x}] for grok: Unable to find pattern [_invalid_] in Grok's pattern dictionary") ); + + cmd = processingCommand("grok a \"%{WORD:foo} %{WORD:foo}\""); + assertEquals(Grok.class, cmd.getClass()); + grok = (Grok) cmd; + assertEquals("%{WORD:foo} %{WORD:foo}", grok.parser().pattern()); + assertEquals(List.of(referenceAttribute("foo", KEYWORD)), grok.extractedFields()); + + expectError( + "row a = \"foo bar\" | GROK a \"%{NUMBER:foo} %{WORD:foo}\"", + "line 1:22: Invalid GROK pattern [%{NUMBER:foo} %{WORD:foo}]:" + + " the attribute [foo] is defined multiple times with different types" + ); } public void testLikeRLike() { diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/QueryPlanTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/QueryPlanTests.java index a62a515ee551b..a254207865ad5 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/QueryPlanTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/QueryPlanTests.java @@ -13,11 +13,11 @@ import org.elasticsearch.xpack.esql.core.expression.FieldAttribute; import org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.expression.NamedExpression; -import org.elasticsearch.xpack.esql.core.plan.logical.Filter; -import org.elasticsearch.xpack.esql.core.plan.logical.Limit; -import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan; -import org.elasticsearch.xpack.esql.core.plan.logical.OrderBy; import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Add; +import org.elasticsearch.xpack.esql.plan.logical.Filter; +import org.elasticsearch.xpack.esql.plan.logical.Limit; +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; +import org.elasticsearch.xpack.esql.plan.logical.OrderBy; import org.elasticsearch.xpack.esql.plan.logical.Project; import java.util.ArrayList; diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plugin/ComputeListenerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plugin/ComputeListenerTests.java new file mode 100644 index 0000000000000..c93f3b9e0e350 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plugin/ComputeListenerTests.java @@ -0,0 +1,246 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.plugin; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.ActionRunnable; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.cluster.node.VersionInformation; +import org.elasticsearch.common.breaker.CircuitBreaker; +import org.elasticsearch.common.breaker.CircuitBreakingException; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.util.concurrent.ConcurrentCollections; +import org.elasticsearch.compute.operator.DriverProfile; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.tasks.CancellableTask; +import org.elasticsearch.tasks.TaskCancellationService; +import org.elasticsearch.tasks.TaskCancelledException; +import org.elasticsearch.tasks.TaskId; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.TransportVersionUtils; +import org.elasticsearch.test.transport.MockTransportService; +import org.elasticsearch.threadpool.TestThreadPool; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.transport.TransportService; +import org.junit.After; +import org.junit.Before; +import org.mockito.Mockito; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Queue; +import java.util.Set; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.stream.Collectors; +import java.util.stream.IntStream; + +import static org.elasticsearch.test.tasks.MockTaskManager.SPY_TASK_MANAGER_SETTING; +import static org.hamcrest.Matchers.empty; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.lessThan; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; + +public class ComputeListenerTests extends ESTestCase { + private ThreadPool threadPool; + private TransportService transportService; + + @Before + public void setUpTransportService() { + threadPool = new TestThreadPool(getTestName()); + transportService = MockTransportService.createNewService( + Settings.builder().put(SPY_TASK_MANAGER_SETTING.getKey(), true).build(), + VersionInformation.CURRENT, + TransportVersionUtils.randomVersion(), + threadPool + ); + transportService.start(); + TaskCancellationService cancellationService = new TaskCancellationService(transportService); + transportService.getTaskManager().setTaskCancellationService(cancellationService); + Mockito.clearInvocations(transportService.getTaskManager()); + } + + @After + public void shutdownTransportService() { + transportService.close(); + terminate(threadPool); + } + + private CancellableTask newTask() { + return new CancellableTask( + randomIntBetween(1, 100), + "test-type", + "test-action", + "test-description", + TaskId.EMPTY_TASK_ID, + Map.of() + ); + } + + private ComputeResponse randomResponse() { + int numProfiles = randomIntBetween(0, 2); + List profiles = new ArrayList<>(numProfiles); + for (int i = 0; i < numProfiles; i++) { + profiles.add(new DriverProfile(randomNonNegativeLong(), randomNonNegativeLong(), randomNonNegativeLong(), List.of())); + } + return new ComputeResponse(profiles); + } + + public void testEmpty() { + PlainActionFuture results = new PlainActionFuture<>(); + try (ComputeListener ignored = new ComputeListener(transportService, newTask(), results)) { + assertFalse(results.isDone()); + } + assertTrue(results.isDone()); + assertThat(results.actionGet(10, TimeUnit.SECONDS).getProfiles(), empty()); + } + + public void testCollectComputeResults() { + PlainActionFuture future = new PlainActionFuture<>(); + List allProfiles = new ArrayList<>(); + try (ComputeListener computeListener = new ComputeListener(transportService, newTask(), future)) { + int tasks = randomIntBetween(1, 100); + for (int t = 0; t < tasks; t++) { + if (randomBoolean()) { + ActionListener subListener = computeListener.acquireAvoid(); + threadPool.schedule( + ActionRunnable.wrap(subListener, l -> l.onResponse(null)), + TimeValue.timeValueNanos(between(0, 100)), + threadPool.generic() + ); + } else { + ComputeResponse resp = randomResponse(); + allProfiles.addAll(resp.getProfiles()); + ActionListener subListener = computeListener.acquireCompute(); + threadPool.schedule( + ActionRunnable.wrap(subListener, l -> l.onResponse(resp)), + TimeValue.timeValueNanos(between(0, 100)), + threadPool.generic() + ); + } + } + } + ComputeResponse result = future.actionGet(10, TimeUnit.SECONDS); + assertThat( + result.getProfiles().stream().collect(Collectors.toMap(p -> p, p -> 1, Integer::sum)), + equalTo(allProfiles.stream().collect(Collectors.toMap(p -> p, p -> 1, Integer::sum))) + ); + Mockito.verifyNoInteractions(transportService.getTaskManager()); + } + + public void testCancelOnFailure() throws Exception { + Queue rootCauseExceptions = ConcurrentCollections.newQueue(); + IntStream.range(0, between(1, 100)) + .forEach( + n -> rootCauseExceptions.add(new CircuitBreakingException("breaking exception " + n, CircuitBreaker.Durability.TRANSIENT)) + ); + int successTasks = between(1, 50); + int failedTasks = between(1, 100); + PlainActionFuture rootListener = new PlainActionFuture<>(); + CancellableTask rootTask = newTask(); + try (ComputeListener computeListener = new ComputeListener(transportService, rootTask, rootListener)) { + for (int i = 0; i < successTasks; i++) { + ActionListener subListener = computeListener.acquireCompute(); + threadPool.schedule( + ActionRunnable.wrap(subListener, l -> l.onResponse(randomResponse())), + TimeValue.timeValueNanos(between(0, 100)), + threadPool.generic() + ); + } + for (int i = 0; i < failedTasks; i++) { + ActionListener subListener = randomBoolean() ? computeListener.acquireAvoid() : computeListener.acquireCompute(); + threadPool.schedule(ActionRunnable.wrap(subListener, l -> { + Exception ex = rootCauseExceptions.poll(); + if (ex == null) { + ex = new TaskCancelledException("task was cancelled"); + } + l.onFailure(ex); + }), TimeValue.timeValueNanos(between(0, 100)), threadPool.generic()); + } + } + assertBusy(rootListener::isDone); + ExecutionException failure = expectThrows(ExecutionException.class, () -> rootListener.get(1, TimeUnit.SECONDS)); + Throwable cause = failure.getCause(); + assertNotNull(failure); + assertThat(cause, instanceOf(CircuitBreakingException.class)); + assertThat(failure.getSuppressed().length, lessThan(10)); + Mockito.verify(transportService.getTaskManager(), Mockito.times(1)) + .cancelTaskAndDescendants(eq(rootTask), eq("cancelled on failure"), eq(false), any()); + } + + public void testCollectWarnings() throws Exception { + List allProfiles = new ArrayList<>(); + Map> allWarnings = new HashMap<>(); + ActionListener rootListener = new ActionListener<>() { + @Override + public void onResponse(ComputeResponse result) { + assertThat( + result.getProfiles().stream().collect(Collectors.toMap(p -> p, p -> 1, Integer::sum)), + equalTo(allProfiles.stream().collect(Collectors.toMap(p -> p, p -> 1, Integer::sum))) + ); + Map> responseHeaders = threadPool.getThreadContext() + .getResponseHeaders() + .entrySet() + .stream() + .collect(Collectors.toMap(Map.Entry::getKey, e -> new HashSet<>(e.getValue()))); + assertThat(responseHeaders, equalTo(allWarnings)); + } + + @Override + public void onFailure(Exception e) { + throw new AssertionError(e); + } + }; + CountDownLatch latch = new CountDownLatch(1); + try ( + ComputeListener computeListener = new ComputeListener( + transportService, + newTask(), + ActionListener.runAfter(rootListener, latch::countDown) + ) + ) { + int tasks = randomIntBetween(1, 100); + for (int t = 0; t < tasks; t++) { + if (randomBoolean()) { + ActionListener subListener = computeListener.acquireAvoid(); + threadPool.schedule( + ActionRunnable.wrap(subListener, l -> l.onResponse(null)), + TimeValue.timeValueNanos(between(0, 100)), + threadPool.generic() + ); + } else { + ComputeResponse resp = randomResponse(); + allProfiles.addAll(resp.getProfiles()); + int numWarnings = randomIntBetween(1, 5); + Map warnings = new HashMap<>(); + for (int i = 0; i < numWarnings; i++) { + warnings.put("key" + between(1, 10), "value" + between(1, 10)); + } + for (Map.Entry e : warnings.entrySet()) { + allWarnings.computeIfAbsent(e.getKey(), v -> new HashSet<>()).add(e.getValue()); + } + ActionListener subListener = computeListener.acquireCompute(); + threadPool.schedule(ActionRunnable.wrap(subListener, l -> { + for (Map.Entry e : warnings.entrySet()) { + threadPool.getThreadContext().addResponseHeader(e.getKey(), e.getValue()); + } + l.onResponse(resp); + }), TimeValue.timeValueNanos(between(0, 100)), threadPool.generic()); + } + } + } + assertTrue(latch.await(10, TimeUnit.SECONDS)); + Mockito.verifyNoInteractions(transportService.getTaskManager()); + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plugin/DataNodeRequestTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plugin/DataNodeRequestTests.java index 7454b25377594..06c6b5de3cdea 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plugin/DataNodeRequestTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plugin/DataNodeRequestTests.java @@ -21,7 +21,6 @@ import org.elasticsearch.xpack.esql.analysis.AnalyzerContext; import org.elasticsearch.xpack.esql.core.index.EsIndex; import org.elasticsearch.xpack.esql.core.index.IndexResolution; -import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.core.type.EsField; import org.elasticsearch.xpack.esql.expression.function.EsqlFunctionRegistry; import org.elasticsearch.xpack.esql.optimizer.LogicalOptimizerContext; @@ -29,6 +28,7 @@ import org.elasticsearch.xpack.esql.optimizer.PhysicalOptimizerContext; import org.elasticsearch.xpack.esql.optimizer.PhysicalPlanOptimizer; import org.elasticsearch.xpack.esql.parser.EsqlParser; +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan; import org.elasticsearch.xpack.esql.planner.Mapper; import org.elasticsearch.xpack.esql.session.EsqlConfigurationSerializationTests; diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/stats/PlanExecutorMetricsTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/stats/PlanExecutorMetricsTests.java index 5883d41f32125..427c30311df0b 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/stats/PlanExecutorMetricsTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/stats/PlanExecutorMetricsTests.java @@ -24,6 +24,7 @@ import org.elasticsearch.xpack.esql.execution.PlanExecutor; import org.elasticsearch.xpack.esql.plan.physical.PhysicalPlan; import org.elasticsearch.xpack.esql.session.IndexResolver; +import org.elasticsearch.xpack.esql.session.Result; import org.elasticsearch.xpack.esql.type.EsqlDataTypeRegistry; import org.junit.After; import org.junit.Before; @@ -33,6 +34,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.function.BiConsumer; import static org.elasticsearch.xpack.esql.EsqlTestUtils.withDefaultLimitWarning; import static org.hamcrest.Matchers.instanceOf; @@ -100,9 +102,10 @@ public void testFailedMetric() { var request = new EsqlQueryRequest(); // test a failed query: xyz field doesn't exist request.query("from test | stats m = max(xyz)"); - planExecutor.esql(request, randomAlphaOfLength(10), EsqlTestUtils.TEST_CFG, enrichResolver, new ActionListener<>() { + BiConsumer> runPhase = (p, r) -> fail("this shouldn't happen"); + planExecutor.esql(request, randomAlphaOfLength(10), EsqlTestUtils.TEST_CFG, enrichResolver, runPhase, new ActionListener<>() { @Override - public void onResponse(PhysicalPlan physicalPlan) { + public void onResponse(Result result) { fail("this shouldn't happen"); } @@ -119,9 +122,10 @@ public void onFailure(Exception e) { // fix the failing query: foo field does exist request.query("from test | stats m = max(foo)"); - planExecutor.esql(request, randomAlphaOfLength(10), EsqlTestUtils.TEST_CFG, enrichResolver, new ActionListener<>() { + runPhase = (p, r) -> r.onResponse(null); + planExecutor.esql(request, randomAlphaOfLength(10), EsqlTestUtils.TEST_CFG, enrichResolver, runPhase, new ActionListener<>() { @Override - public void onResponse(PhysicalPlan physicalPlan) {} + public void onResponse(Result result) {} @Override public void onFailure(Exception e) { diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/tree/EsqlNodeSubclassTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/tree/EsqlNodeSubclassTests.java index 50fe272caa076..fa20cfdec0ca0 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/tree/EsqlNodeSubclassTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/tree/EsqlNodeSubclassTests.java @@ -28,7 +28,6 @@ import org.elasticsearch.xpack.esql.core.expression.predicate.fulltext.FullTextPredicate; import org.elasticsearch.xpack.esql.core.expression.predicate.regex.Like; import org.elasticsearch.xpack.esql.core.expression.predicate.regex.LikePattern; -import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.core.tree.AbstractNodeTestCase; import org.elasticsearch.xpack.esql.core.tree.Node; import org.elasticsearch.xpack.esql.core.tree.NodeInfo; @@ -44,6 +43,7 @@ import org.elasticsearch.xpack.esql.expression.function.scalar.string.Concat; import org.elasticsearch.xpack.esql.plan.logical.Dissect; import org.elasticsearch.xpack.esql.plan.logical.Grok; +import org.elasticsearch.xpack.esql.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.plan.logical.join.JoinType; import org.elasticsearch.xpack.esql.plan.physical.EsQueryExec; import org.elasticsearch.xpack.esql.plan.physical.EsStatsQueryExec.Stat; diff --git a/x-pack/plugin/inference/build.gradle b/x-pack/plugin/inference/build.gradle index 41ca9966c1336..beeec94f21ebf 100644 --- a/x-pack/plugin/inference/build.gradle +++ b/x-pack/plugin/inference/build.gradle @@ -27,6 +27,10 @@ base { archivesName = 'x-pack-inference' } +versions << [ + 'awsbedrockruntime': '1.12.740' +] + dependencies { implementation project(path: ':libs:elasticsearch-logging') compileOnly project(":server") @@ -53,10 +57,19 @@ dependencies { implementation 'com.google.http-client:google-http-client-appengine:1.42.3' implementation 'com.google.http-client:google-http-client-jackson2:1.42.3' implementation "com.fasterxml.jackson.core:jackson-core:${versions.jackson}" + implementation "com.fasterxml.jackson.core:jackson-databind:${versions.jackson}" + implementation "com.fasterxml.jackson.core:jackson-annotations:${versions.jackson}" + implementation "com.fasterxml.jackson.dataformat:jackson-dataformat-cbor:${versions.jackson}" + implementation "com.fasterxml.jackson:jackson-bom:${versions.jackson}" implementation 'com.google.api:gax-httpjson:0.105.1' implementation 'io.grpc:grpc-context:1.49.2' implementation 'io.opencensus:opencensus-api:0.31.1' implementation 'io.opencensus:opencensus-contrib-http-util:0.31.1' + implementation "com.amazonaws:aws-java-sdk-bedrockruntime:${versions.awsbedrockruntime}" + implementation "com.amazonaws:aws-java-sdk-core:${versions.aws}" + implementation "com.amazonaws:jmespath-java:${versions.aws}" + implementation "joda-time:joda-time:2.10.10" + implementation 'javax.xml.bind:jaxb-api:2.2.2' } tasks.named("dependencyLicenses").configure { @@ -66,6 +79,9 @@ tasks.named("dependencyLicenses").configure { mapping from: /protobuf.*/, to: 'protobuf' mapping from: /proto-google.*/, to: 'proto-google' mapping from: /jackson.*/, to: 'jackson' + mapping from: /aws-java-sdk-.*/, to: 'aws-java-sdk' + mapping from: /jmespath-java.*/, to: 'aws-java-sdk' + mapping from: /jaxb-.*/, to: 'jaxb' } tasks.named("thirdPartyAudit").configure { @@ -199,10 +215,21 @@ tasks.named("thirdPartyAudit").configure { 'com.google.appengine.api.urlfetch.HTTPRequest', 'com.google.appengine.api.urlfetch.HTTPResponse', 'com.google.appengine.api.urlfetch.URLFetchService', - 'com.google.appengine.api.urlfetch.URLFetchServiceFactory' + 'com.google.appengine.api.urlfetch.URLFetchServiceFactory', + 'software.amazon.ion.IonReader', + 'software.amazon.ion.IonSystem', + 'software.amazon.ion.IonType', + 'software.amazon.ion.IonWriter', + 'software.amazon.ion.Timestamp', + 'software.amazon.ion.system.IonBinaryWriterBuilder', + 'software.amazon.ion.system.IonSystemBuilder', + 'software.amazon.ion.system.IonTextWriterBuilder', + 'software.amazon.ion.system.IonWriterBuilder', + 'javax.activation.DataHandler' ) } tasks.named('yamlRestTest') { usesDefaultDistribution() } + diff --git a/x-pack/plugin/inference/licenses/aws-java-sdk-LICENSE.txt b/x-pack/plugin/inference/licenses/aws-java-sdk-LICENSE.txt new file mode 100644 index 0000000000000..98d1f9319f374 --- /dev/null +++ b/x-pack/plugin/inference/licenses/aws-java-sdk-LICENSE.txt @@ -0,0 +1,63 @@ +Apache License +Version 2.0, January 2004 + +TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + +1. Definitions. + +"License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. + +"Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. + +"Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. + +"You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. + +"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. + +"Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. + +"Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). + +"Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. + +"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." + +"Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. + +2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. + +3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. + +4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: + + 1. You must give any other recipients of the Work or Derivative Works a copy of this License; and + 2. You must cause any modified files to carry prominent notices stating that You changed the files; and + 3. You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and + 4. If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. + +You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. + +5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. + +6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. + +7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. + +8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. + +9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. + +END OF TERMS AND CONDITIONS + +Note: Other license terms may apply to certain, identified software files contained within or distributed with the accompanying software if such terms are included in the directory containing the accompanying software. Such other license terms will then apply in lieu of the terms of the software license above. + +JSON processing code subject to the JSON License from JSON.org: + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +The Software shall be used for Good, not Evil. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/x-pack/plugin/inference/licenses/aws-java-sdk-NOTICE.txt b/x-pack/plugin/inference/licenses/aws-java-sdk-NOTICE.txt new file mode 100644 index 0000000000000..565bd6085c71a --- /dev/null +++ b/x-pack/plugin/inference/licenses/aws-java-sdk-NOTICE.txt @@ -0,0 +1,15 @@ +AWS SDK for Java +Copyright 2010-2014 Amazon.com, Inc. or its affiliates. All Rights Reserved. + +This product includes software developed by +Amazon Technologies, Inc (http://www.amazon.com/). + +********************** +THIRD PARTY COMPONENTS +********************** +This software includes third party software subject to the following copyrights: +- XML parsing and utility functions from JetS3t - Copyright 2006-2009 James Murty. +- JSON parsing and utility functions from JSON.org - Copyright 2002 JSON.org. +- PKCS#1 PEM encoded private key parsing and utility functions from oauth.googlecode.com - Copyright 1998-2010 AOL Inc. + +The licenses for these third party components are included in LICENSE.txt diff --git a/x-pack/plugin/inference/licenses/jaxb-LICENSE.txt b/x-pack/plugin/inference/licenses/jaxb-LICENSE.txt new file mode 100644 index 0000000000000..833a843cfeee1 --- /dev/null +++ b/x-pack/plugin/inference/licenses/jaxb-LICENSE.txt @@ -0,0 +1,274 @@ +COMMON DEVELOPMENT AND DISTRIBUTION LICENSE (CDDL)Version 1.1 + +1. Definitions. + + 1.1. "Contributor" means each individual or entity that creates or contributes to the creation of Modifications. + + 1.2. "Contributor Version" means the combination of the Original Software, prior Modifications used by a Contributor (if any), and the Modifications made by that particular Contributor. + + 1.3. "Covered Software" means (a) the Original Software, or (b) Modifications, or (c) the combination of files containing Original Software with files containing Modifications, in each case including portions thereof. + + 1.4. "Executable" means the Covered Software in any form other than Source Code. + + 1.5. "Initial Developer" means the individual or entity that first makes Original Software available under this License. + + 1.6. "Larger Work" means a work which combines Covered Software or portions thereof with code not governed by the terms of this License. + + 1.7. "License" means this document. + + 1.8. "Licensable" means having the right to grant, to the maximum extent possible, whether at the time of the initial grant or subsequently acquired, any and all of the rights conveyed herein. + + 1.9. "Modifications" means the Source Code and Executable form of any of the following: + + A. Any file that results from an addition to, deletion from or modification of the contents of a file containing Original Software or previous Modifications; + + B. Any new file that contains any part of the Original Software or previous Modification; or + + C. Any new file that is contributed or otherwise made available under the terms of this License. + + 1.10. "Original Software" means the Source Code and Executable form of computer software code that is originally released under this License. + + 1.11. "Patent Claims" means any patent claim(s), now owned or hereafter acquired, including without limitation, method, process, and apparatus claims, in any patent Licensable by grantor. + + 1.12. "Source Code" means (a) the common form of computer software code in which modifications are made and (b) associated documentation included in or with such code. + + 1.13. "You" (or "Your") means an individual or a legal entity exercising rights under, and complying with all of the terms of, this License. For legal entities, "You" includes any entity which controls, is controlled by, or is under common control with You. For purposes of this definition, "control" means (a) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (b) ownership of more than fifty percent (50%) of the outstanding shares or beneficial ownership of such entity. + +2. License Grants. + + 2.1. The Initial Developer Grant. + + Conditioned upon Your compliance with Section 3.1 below and subject to third party intellectual property claims, the Initial Developer hereby grants You a world-wide, royalty-free, non-exclusive license: + + (a) under intellectual property rights (other than patent or trademark) Licensable by Initial Developer, to use, reproduce, modify, display, perform, sublicense and distribute the Original Software (or portions thereof), with or without Modifications, and/or as part of a Larger Work; and + + (b) under Patent Claims infringed by the making, using or selling of Original Software, to make, have made, use, practice, sell, and offer for sale, and/or otherwise dispose of the Original Software (or portions thereof). + + (c) The licenses granted in Sections 2.1(a) and (b) are effective on the date Initial Developer first distributes or otherwise makes the Original Software available to a third party under the terms of this License. + + (d) Notwithstanding Section 2.1(b) above, no patent license is granted: (1) for code that You delete from the Original Software, or (2) for infringements caused by: (i) the modification of the Original Software, or (ii) the combination of the Original Software with other software or devices. + + 2.2. Contributor Grant. + + Conditioned upon Your compliance with Section 3.1 below and subject to third party intellectual property claims, each Contributor hereby grants You a world-wide, royalty-free, non-exclusive license: + + (a) under intellectual property rights (other than patent or trademark) Licensable by Contributor to use, reproduce, modify, display, perform, sublicense and distribute the Modifications created by such Contributor (or portions thereof), either on an unmodified basis, with other Modifications, as Covered Software and/or as part of a Larger Work; and + + (b) under Patent Claims infringed by the making, using, or selling of Modifications made by that Contributor either alone and/or in combination with its Contributor Version (or portions of such combination), to make, use, sell, offer for sale, have made, and/or otherwise dispose of: (1) Modifications made by that Contributor (or portions thereof); and (2) the combination of Modifications made by that Contributor with its Contributor Version (or portions of such combination). + + (c) The licenses granted in Sections 2.2(a) and 2.2(b) are effective on the date Contributor first distributes or otherwise makes the Modifications available to a third party. + + (d) Notwithstanding Section 2.2(b) above, no patent license is granted: (1) for any code that Contributor has deleted from the Contributor Version; (2) for infringements caused by: (i) third party modifications of Contributor Version, or (ii) the combination of Modifications made by that Contributor with other software (except as part of the Contributor Version) or other devices; or (3) under Patent Claims infringed by Covered Software in the absence of Modifications made by that Contributor. + +3. Distribution Obligations. + + 3.1. Availability of Source Code. + + Any Covered Software that You distribute or otherwise make available in Executable form must also be made available in Source Code form and that Source Code form must be distributed only under the terms of this License. You must include a copy of this License with every copy of the Source Code form of the Covered Software You distribute or otherwise make available. You must inform recipients of any such Covered Software in Executable form as to how they can obtain such Covered Software in Source Code form in a reasonable manner on or through a medium customarily used for software exchange. + + 3.2. Modifications. + + The Modifications that You create or to which You contribute are governed by the terms of this License. You represent that You believe Your Modifications are Your original creation(s) and/or You have sufficient rights to grant the rights conveyed by this License. + + 3.3. Required Notices. + + You must include a notice in each of Your Modifications that identifies You as the Contributor of the Modification. You may not remove or alter any copyright, patent or trademark notices contained within the Covered Software, or any notices of licensing or any descriptive text giving attribution to any Contributor or the Initial Developer. + + 3.4. Application of Additional Terms. + + You may not offer or impose any terms on any Covered Software in Source Code form that alters or restricts the applicable version of this License or the recipients' rights hereunder. You may choose to offer, and to charge a fee for, warranty, support, indemnity or liability obligations to one or more recipients of Covered Software. However, you may do so only on Your own behalf, and not on behalf of the Initial Developer or any Contributor. You must make it absolutely clear that any such warranty, support, indemnity or liability obligation is offered by You alone, and You hereby agree to indemnify the Initial Developer and every Contributor for any liability incurred by the Initial Developer or such Contributor as a result of warranty, support, indemnity or liability terms You offer. + + 3.5. Distribution of Executable Versions. + + You may distribute the Executable form of the Covered Software under the terms of this License or under the terms of a license of Your choice, which may contain terms different from this License, provided that You are in compliance with the terms of this License and that the license for the Executable form does not attempt to limit or alter the recipient's rights in the Source Code form from the rights set forth in this License. If You distribute the Covered Software in Executable form under a different license, You must make it absolutely clear that any terms which differ from this License are offered by You alone, not by the Initial Developer or Contributor. You hereby agree to indemnify the Initial Developer and every Contributor for any liability incurred by the Initial Developer or such Contributor as a result of any such terms You offer. + + 3.6. Larger Works. + + You may create a Larger Work by combining Covered Software with other code not governed by the terms of this License and distribute the Larger Work as a single product. In such a case, You must make sure the requirements of this License are fulfilled for the Covered Software. + +4. Versions of the License. + + 4.1. New Versions. + + Oracle is the initial license steward and may publish revised and/or new versions of this License from time to time. Each version will be given a distinguishing version number. Except as provided in Section 4.3, no one other than the license steward has the right to modify this License. + + 4.2. Effect of New Versions. + + You may always continue to use, distribute or otherwise make the Covered Software available under the terms of the version of the License under which You originally received the Covered Software. If the Initial Developer includes a notice in the Original Software prohibiting it from being distributed or otherwise made available under any subsequent version of the License, You must distribute and make the Covered Software available under the terms of the version of the License under which You originally received the Covered Software. Otherwise, You may also choose to use, distribute or otherwise make the Covered Software available under the terms of any subsequent version of the License published by the license steward. + + 4.3. Modified Versions. + + When You are an Initial Developer and You want to create a new license for Your Original Software, You may create and use a modified version of this License if You: (a) rename the license and remove any references to the name of the license steward (except to note that the license differs from this License); and (b) otherwise make it clear that the license contains terms which differ from this License. + +5. DISCLAIMER OF WARRANTY. + + COVERED SOFTWARE IS PROVIDED UNDER THIS LICENSE ON AN "AS IS" BASIS, WITHOUT WARRANTY OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, WITHOUT LIMITATION, WARRANTIES THAT THE COVERED SOFTWARE IS FREE OF DEFECTS, MERCHANTABLE, FIT FOR A PARTICULAR PURPOSE OR NON-INFRINGING. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE COVERED SOFTWARE IS WITH YOU. SHOULD ANY COVERED SOFTWARE PROVE DEFECTIVE IN ANY RESPECT, YOU (NOT THE INITIAL DEVELOPER OR ANY OTHER CONTRIBUTOR) ASSUME THE COST OF ANY NECESSARY SERVICING, REPAIR OR CORRECTION. THIS DISCLAIMER OF WARRANTY CONSTITUTES AN ESSENTIAL PART OF THIS LICENSE. NO USE OF ANY COVERED SOFTWARE IS AUTHORIZED HEREUNDER EXCEPT UNDER THIS DISCLAIMER. + +6. TERMINATION. + + 6.1. This License and the rights granted hereunder will terminate automatically if You fail to comply with terms herein and fail to cure such breach within 30 days of becoming aware of the breach. Provisions which, by their nature, must remain in effect beyond the termination of this License shall survive. + + 6.2. If You assert a patent infringement claim (excluding declaratory judgment actions) against Initial Developer or a Contributor (the Initial Developer or Contributor against whom You assert such claim is referred to as "Participant") alleging that the Participant Software (meaning the Contributor Version where the Participant is a Contributor or the Original Software where the Participant is the Initial Developer) directly or indirectly infringes any patent, then any and all rights granted directly or indirectly to You by such Participant, the Initial Developer (if the Initial Developer is not the Participant) and all Contributors under Sections 2.1 and/or 2.2 of this License shall, upon 60 days notice from Participant terminate prospectively and automatically at the expiration of such 60 day notice period, unless if within such 60 day period You withdraw Your claim with respect to the Participant Software against such Participant either unilaterally or pursuant to a written agreement with Participant. + + 6.3. If You assert a patent infringement claim against Participant alleging that the Participant Software directly or indirectly infringes any patent where such claim is resolved (such as by license or settlement) prior to the initiation of patent infringement litigation, then the reasonable value of the licenses granted by such Participant under Sections 2.1 or 2.2 shall be taken into account in determining the amount or value of any payment or license. + + 6.4. In the event of termination under Sections 6.1 or 6.2 above, all end user licenses that have been validly granted by You or any distributor hereunder prior to termination (excluding licenses granted to You by any distributor) shall survive termination. + +7. LIMITATION OF LIABILITY. + + UNDER NO CIRCUMSTANCES AND UNDER NO LEGAL THEORY, WHETHER TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE, SHALL YOU, THE INITIAL DEVELOPER, ANY OTHER CONTRIBUTOR, OR ANY DISTRIBUTOR OF COVERED SOFTWARE, OR ANY SUPPLIER OF ANY OF SUCH PARTIES, BE LIABLE TO ANY PERSON FOR ANY INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES OF ANY CHARACTER INCLUDING, WITHOUT LIMITATION, DAMAGES FOR LOSS OF GOODWILL, WORK STOPPAGE, COMPUTER FAILURE OR MALFUNCTION, OR ANY AND ALL OTHER COMMERCIAL DAMAGES OR LOSSES, EVEN IF SUCH PARTY SHALL HAVE BEEN INFORMED OF THE POSSIBILITY OF SUCH DAMAGES. THIS LIMITATION OF LIABILITY SHALL NOT APPLY TO LIABILITY FOR DEATH OR PERSONAL INJURY RESULTING FROM SUCH PARTY'S NEGLIGENCE TO THE EXTENT APPLICABLE LAW PROHIBITS SUCH LIMITATION. SOME JURISDICTIONS DO NOT ALLOW THE EXCLUSION OR LIMITATION OF INCIDENTAL OR CONSEQUENTIAL DAMAGES, SO THIS EXCLUSION AND LIMITATION MAY NOT APPLY TO YOU. + +8. U.S. GOVERNMENT END USERS. + + The Covered Software is a "commercial item," as that term is defined in 48 C.F.R. 2.101 (Oct. 1995), consisting of "commercial computer software" (as that term is defined at 48 C.F.R. ? 252.227-7014(a)(1)) and "commercial computer software documentation" as such terms are used in 48 C.F.R. 12.212 (Sept. 1995). Consistent with 48 C.F.R. 12.212 and 48 C.F.R. 227.7202-1 through 227.7202-4 (June 1995), all U.S. Government End Users acquire Covered Software with only those rights set forth herein. This U.S. Government Rights clause is in lieu of, and supersedes, any other FAR, DFAR, or other clause or provision that addresses Government rights in computer software under this License. + +9. MISCELLANEOUS. + + This License represents the complete agreement concerning subject matter hereof. If any provision of this License is held to be unenforceable, such provision shall be reformed only to the extent necessary to make it enforceable. This License shall be governed by the law of the jurisdiction specified in a notice contained within the Original Software (except to the extent applicable law, if any, provides otherwise), excluding such jurisdiction's conflict-of-law provisions. Any litigation relating to this License shall be subject to the jurisdiction of the courts located in the jurisdiction and venue specified in a notice contained within the Original Software, with the losing party responsible for costs, including, without limitation, court costs and reasonable attorneys' fees and expenses. The application of the United Nations Convention on Contracts for the International Sale of Goods is expressly excluded. Any law or regulation which provides that the language of a contract shall be construed against the drafter shall not apply to this License. You agree that You alone are responsible for compliance with the United States export administration regulations (and the export control laws and regulation of any other countries) when You use, distribute or otherwise make available any Covered Software. + +10. RESPONSIBILITY FOR CLAIMS. + + As between Initial Developer and the Contributors, each party is responsible for claims and damages arising, directly or indirectly, out of its utilization of rights under this License and You agree to work with Initial Developer and Contributors to distribute such responsibility on an equitable basis. Nothing herein is intended or shall be deemed to constitute any admission of liability. + +---------- +NOTICE PURSUANT TO SECTION 9 OF THE COMMON DEVELOPMENT AND DISTRIBUTION LICENSE (CDDL) +The code released under the CDDL shall be governed by the laws of the State of California (excluding conflict-of-law provisions). Any litigation relating to this License shall be subject to the jurisdiction of the Federal Courts of the Northern District of California and the state courts of the State of California, with venue lying in Santa Clara County, California. + + + + +The GNU General Public License (GPL) Version 2, June 1991 + + +Copyright (C) 1989, 1991 Free Software Foundation, Inc. 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA + +Everyone is permitted to copy and distribute verbatim copies of this license document, but changing it is not allowed. + +Preamble + +The licenses for most software are designed to take away your freedom to share and change it. By contrast, the GNU General Public License is intended to guarantee your freedom to share and change free software--to make sure the software is free for all its users. This General Public License applies to most of the Free Software Foundation's software and to any other program whose authors commit to using it. (Some other Free Software Foundation software is covered by the GNU Library General Public License instead.) You can apply it to your programs, too. + +When we speak of free software, we are referring to freedom, not price. Our General Public Licenses are designed to make sure that you have the freedom to distribute copies of free software (and charge for this service if you wish), that you receive source code or can get it if you want it, that you can change the software or use pieces of it in new free programs; and that you know you can do these things. + +To protect your rights, we need to make restrictions that forbid anyone to deny you these rights or to ask you to surrender the rights. These restrictions translate to certain responsibilities for you if you distribute copies of the software, or if you modify it. + +For example, if you distribute copies of such a program, whether gratis or for a fee, you must give the recipients all the rights that you have. You must make sure that they, too, receive or can get the source code. And you must show them these terms so they know their rights. + +We protect your rights with two steps: (1) copyright the software, and (2) offer you this license which gives you legal permission to copy, distribute and/or modify the software. + +Also, for each author's protection and ours, we want to make certain that everyone understands that there is no warranty for this free software. If the software is modified by someone else and passed on, we want its recipients to know that what they have is not the original, so that any problems introduced by others will not reflect on the original authors' reputations. + +Finally, any free program is threatened constantly by software patents. We wish to avoid the danger that redistributors of a free program will individually obtain patent licenses, in effect making the program proprietary. To prevent this, we have made it clear that any patent must be licensed for everyone's free use or not licensed at all. + +The precise terms and conditions for copying, distribution and modification follow. + + +TERMS AND CONDITIONS FOR COPYING, DISTRIBUTION AND MODIFICATION + +0. This License applies to any program or other work which contains a notice placed by the copyright holder saying it may be distributed under the terms of this General Public License. The "Program", below, refers to any such program or work, and a "work based on the Program" means either the Program or any derivative work under copyright law: that is to say, a work containing the Program or a portion of it, either verbatim or with modifications and/or translated into another language. (Hereinafter, translation is included without limitation in the term "modification".) Each licensee is addressed as "you". + +Activities other than copying, distribution and modification are not covered by this License; they are outside its scope. The act of running the Program is not restricted, and the output from the Program is covered only if its contents constitute a work based on the Program (independent of having been made by running the Program). Whether that is true depends on what the Program does. + +1. You may copy and distribute verbatim copies of the Program's source code as you receive it, in any medium, provided that you conspicuously and appropriately publish on each copy an appropriate copyright notice and disclaimer of warranty; keep intact all the notices that refer to this License and to the absence of any warranty; and give any other recipients of the Program a copy of this License along with the Program. + +You may charge a fee for the physical act of transferring a copy, and you may at your option offer warranty protection in exchange for a fee. + +2. You may modify your copy or copies of the Program or any portion of it, thus forming a work based on the Program, and copy and distribute such modifications or work under the terms of Section 1 above, provided that you also meet all of these conditions: + + a) You must cause the modified files to carry prominent notices stating that you changed the files and the date of any change. + + b) You must cause any work that you distribute or publish, that in whole or in part contains or is derived from the Program or any part thereof, to be licensed as a whole at no charge to all third parties under the terms of this License. + + c) If the modified program normally reads commands interactively when run, you must cause it, when started running for such interactive use in the most ordinary way, to print or display an announcement including an appropriate copyright notice and a notice that there is no warranty (or else, saying that you provide a warranty) and that users may redistribute the program under these conditions, and telling the user how to view a copy of this License. (Exception: if the Program itself is interactive but does not normally print such an announcement, your work based on the Program is not required to print an announcement.) + +These requirements apply to the modified work as a whole. If identifiable sections of that work are not derived from the Program, and can be reasonably considered independent and separate works in themselves, then this License, and its terms, do not apply to those sections when you distribute them as separate works. But when you distribute the same sections as part of a whole which is a work based on the Program, the distribution of the whole must be on the terms of this License, whose permissions for other licensees extend to the entire whole, and thus to each and every part regardless of who wrote it. + +Thus, it is not the intent of this section to claim rights or contest your rights to work written entirely by you; rather, the intent is to exercise the right to control the distribution of derivative or collective works based on the Program. + +In addition, mere aggregation of another work not based on the Program with the Program (or with a work based on the Program) on a volume of a storage or distribution medium does not bring the other work under the scope of this License. + +3. You may copy and distribute the Program (or a work based on it, under Section 2) in object code or executable form under the terms of Sections 1 and 2 above provided that you also do one of the following: + + a) Accompany it with the complete corresponding machine-readable source code, which must be distributed under the terms of Sections 1 and 2 above on a medium customarily used for software interchange; or, + + b) Accompany it with a written offer, valid for at least three years, to give any third party, for a charge no more than your cost of physically performing source distribution, a complete machine-readable copy of the corresponding source code, to be distributed under the terms of Sections 1 and 2 above on a medium customarily used for software interchange; or, + + c) Accompany it with the information you received as to the offer to distribute corresponding source code. (This alternative is allowed only for noncommercial distribution and only if you received the program in object code or executable form with such an offer, in accord with Subsection b above.) + +The source code for a work means the preferred form of the work for making modifications to it. For an executable work, complete source code means all the source code for all modules it contains, plus any associated interface definition files, plus the scripts used to control compilation and installation of the executable. However, as a special exception, the source code distributed need not include anything that is normally distributed (in either source or binary form) with the major components (compiler, kernel, and so on) of the operating system on which the executable runs, unless that component itself accompanies the executable. + +If distribution of executable or object code is made by offering access to copy from a designated place, then offering equivalent access to copy the source code from the same place counts as distribution of the source code, even though third parties are not compelled to copy the source along with the object code. + +4. You may not copy, modify, sublicense, or distribute the Program except as expressly provided under this License. Any attempt otherwise to copy, modify, sublicense or distribute the Program is void, and will automatically terminate your rights under this License. However, parties who have received copies, or rights, from you under this License will not have their licenses terminated so long as such parties remain in full compliance. + +5. You are not required to accept this License, since you have not signed it. However, nothing else grants you permission to modify or distribute the Program or its derivative works. These actions are prohibited by law if you do not accept this License. Therefore, by modifying or distributing the Program (or any work based on the Program), you indicate your acceptance of this License to do so, and all its terms and conditions for copying, distributing or modifying the Program or works based on it. + +6. Each time you redistribute the Program (or any work based on the Program), the recipient automatically receives a license from the original licensor to copy, distribute or modify the Program subject to these terms and conditions. You may not impose any further restrictions on the recipients' exercise of the rights granted herein. You are not responsible for enforcing compliance by third parties to this License. + +7. If, as a consequence of a court judgment or allegation of patent infringement or for any other reason (not limited to patent issues), conditions are imposed on you (whether by court order, agreement or otherwise) that contradict the conditions of this License, they do not excuse you from the conditions of this License. If you cannot distribute so as to satisfy simultaneously your obligations under this License and any other pertinent obligations, then as a consequence you may not distribute the Program at all. For example, if a patent license would not permit royalty-free redistribution of the Program by all those who receive copies directly or indirectly through you, then the only way you could satisfy both it and this License would be to refrain entirely from distribution of the Program. + +If any portion of this section is held invalid or unenforceable under any particular circumstance, the balance of the section is intended to apply and the section as a whole is intended to apply in other circumstances. + +It is not the purpose of this section to induce you to infringe any patents or other property right claims or to contest validity of any such claims; this section has the sole purpose of protecting the integrity of the free software distribution system, which is implemented by public license practices. Many people have made generous contributions to the wide range of software distributed through that system in reliance on consistent application of that system; it is up to the author/donor to decide if he or she is willing to distribute software through any other system and a licensee cannot impose that choice. + +This section is intended to make thoroughly clear what is believed to be a consequence of the rest of this License. + +8. If the distribution and/or use of the Program is restricted in certain countries either by patents or by copyrighted interfaces, the original copyright holder who places the Program under this License may add an explicit geographical distribution limitation excluding those countries, so that distribution is permitted only in or among countries not thus excluded. In such case, this License incorporates the limitation as if written in the body of this License. + +9. The Free Software Foundation may publish revised and/or new versions of the General Public License from time to time. Such new versions will be similar in spirit to the present version, but may differ in detail to address new problems or concerns. + +Each version is given a distinguishing version number. If the Program specifies a version number of this License which applies to it and "any later version", you have the option of following the terms and conditions either of that version or of any later version published by the Free Software Foundation. If the Program does not specify a version number of this License, you may choose any version ever published by the Free Software Foundation. + +10. If you wish to incorporate parts of the Program into other free programs whose distribution conditions are different, write to the author to ask for permission. For software which is copyrighted by the Free Software Foundation, write to the Free Software Foundation; we sometimes make exceptions for this. Our decision will be guided by the two goals of preserving the free status of all derivatives of our free software and of promoting the sharing and reuse of software generally. + +NO WARRANTY + +11. BECAUSE THE PROGRAM IS LICENSED FREE OF CHARGE, THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF ALL NECESSARY SERVICING, REPAIR OR CORRECTION. + +12. IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MAY MODIFY AND/OR REDISTRIBUTE THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. + +END OF TERMS AND CONDITIONS + + +How to Apply These Terms to Your New Programs + +If you develop a new program, and you want it to be of the greatest possible use to the public, the best way to achieve this is to make it free software which everyone can redistribute and change under these terms. + +To do so, attach the following notices to the program. It is safest to attach them to the start of each source file to most effectively convey the exclusion of warranty; and each file should have at least the "copyright" line and a pointer to where the full notice is found. + + One line to give the program's name and a brief idea of what it does. + + Copyright (C) + + This program is free software; you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation; either version 2 of the License, or (at your option) any later version. + + This program is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details. + + You should have received a copy of the GNU General Public License along with this program; if not, write to the Free Software Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA + +Also add information on how to contact you by electronic and paper mail. + +If the program is interactive, make it output a short notice like this when it starts in an interactive mode: + + Gnomovision version 69, Copyright (C) year name of author + Gnomovision comes with ABSOLUTELY NO WARRANTY; for details type `show w'. This is free software, and you are welcome to redistribute it under certain conditions; type `show c' for details. + +The hypothetical commands `show w' and `show c' should show the appropriate parts of the General Public License. Of course, the commands you use may be called something other than `show w' and `show c'; they could even be mouse-clicks or menu items--whatever suits your program. + +You should also get your employer (if you work as a programmer) or your school, if any, to sign a "copyright disclaimer" for the program, if necessary. Here is a sample; alter the names: + + Yoyodyne, Inc., hereby disclaims all copyright interest in the program `Gnomovision' (which makes passes at compilers) written by James Hacker. + + signature of Ty Coon, 1 April 1989 + Ty Coon, President of Vice + +This General Public License does not permit incorporating your program into proprietary programs. If your program is a subroutine library, you may consider it more useful to permit linking proprietary applications with the library. If this is what you want to do, use the GNU Library General Public License instead of this License. + + +"CLASSPATH" EXCEPTION TO THE GPL VERSION 2 + +Certain source files distributed by Oracle are subject to the following clarification and special exception to the GPL Version 2, but only where Oracle has expressly included in the particular source file's header the words "Oracle designates this particular file as subject to the "Classpath" exception as provided by Oracle in the License file that accompanied this code." + +Linking this library statically or dynamically with other modules is making a combined work based on this library. Thus, the terms and conditions of the GNU General Public License Version 2 cover the whole combination. + +As a special exception, the copyright holders of this library give you permission to link this library with independent modules to produce an executable, regardless of the license terms of these independent modules, and to copy and distribute the resulting executable under terms of your choice, provided that you also meet, for each linked independent module, the terms and conditions of the license of that module. An independent module is a module which is not derived from or based on this library. If you modify this library, you may extend this exception to your version of the library, but you are not obligated to do so. If you do not wish to do so, delete this exception statement from your version. diff --git a/x-pack/plugin/inference/licenses/jaxb-NOTICE.txt b/x-pack/plugin/inference/licenses/jaxb-NOTICE.txt new file mode 100644 index 0000000000000..8d1c8b69c3fce --- /dev/null +++ b/x-pack/plugin/inference/licenses/jaxb-NOTICE.txt @@ -0,0 +1 @@ + diff --git a/x-pack/plugin/core/licenses/nimbus-jose-jwt-LICENSE.txt b/x-pack/plugin/inference/licenses/joda-time-LICENSE.txt similarity index 100% rename from x-pack/plugin/core/licenses/nimbus-jose-jwt-LICENSE.txt rename to x-pack/plugin/inference/licenses/joda-time-LICENSE.txt diff --git a/x-pack/plugin/inference/licenses/joda-time-NOTICE.txt b/x-pack/plugin/inference/licenses/joda-time-NOTICE.txt new file mode 100644 index 0000000000000..dffbcf31cacf6 --- /dev/null +++ b/x-pack/plugin/inference/licenses/joda-time-NOTICE.txt @@ -0,0 +1,5 @@ +============================================================================= += NOTICE file corresponding to section 4d of the Apache License Version 2.0 = +============================================================================= +This product includes software developed by +Joda.org (http://www.joda.org/). diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java index 419869c0c4a5e..f30f2e8fe201a 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java @@ -126,6 +126,25 @@ protected void deleteModel(String modelId, TaskType taskType) throws IOException assertOkOrCreated(response); } + protected void putSemanticText(String endpointId, String indexName) throws IOException { + var request = new Request("PUT", Strings.format("%s", indexName)); + String body = Strings.format(""" + { + "mappings": { + "properties": { + "inference_field": { + "type": "semantic_text", + "inference_id": "%s" + } + } + } + } + """, endpointId); + request.setJsonEntity(body); + var response = client().performRequest(request); + assertOkOrCreated(response); + } + protected Map putModel(String modelId, String modelConfig, TaskType taskType) throws IOException { String endpoint = Strings.format("_inference/%s/%s", taskType, modelId); return putRequest(endpoint, modelConfig); diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java index 75e392b6d155f..242f786e95364 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java @@ -16,6 +16,7 @@ import java.io.IOException; import java.util.List; +import java.util.Set; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.hasSize; @@ -124,14 +125,15 @@ public void testDeleteEndpointWhileReferencedByPipeline() throws IOException { putPipeline(pipelineId, endpointId); { + var errorString = new StringBuilder().append("Inference endpoint ") + .append(endpointId) + .append(" is referenced by pipelines: ") + .append(Set.of(pipelineId)) + .append(". ") + .append("Ensure that no pipelines are using this inference endpoint, ") + .append("or use force to ignore this warning and delete the inference endpoint."); var e = expectThrows(ResponseException.class, () -> deleteModel(endpointId)); - assertThat( - e.getMessage(), - containsString( - "Inference endpoint endpoint_referenced_by_pipeline is referenced by pipelines and cannot be deleted. " - + "Use `force` to delete it anyway, or use `dry_run` to list the pipelines that reference it." - ) - ); + assertThat(e.getMessage(), containsString(errorString.toString())); } { var response = deleteModel(endpointId, "dry_run=true"); @@ -146,4 +148,78 @@ public void testDeleteEndpointWhileReferencedByPipeline() throws IOException { } deletePipeline(pipelineId); } + + public void testDeleteEndpointWhileReferencedBySemanticText() throws IOException { + String endpointId = "endpoint_referenced_by_semantic_text"; + putModel(endpointId, mockSparseServiceModelConfig(), TaskType.SPARSE_EMBEDDING); + String indexName = randomAlphaOfLength(10).toLowerCase(); + putSemanticText(endpointId, indexName); + { + + var errorString = new StringBuilder().append(" Inference endpoint ") + .append(endpointId) + .append(" is being used in the mapping for indexes: ") + .append(Set.of(indexName)) + .append(". ") + .append("Ensure that no index mappings are using this inference endpoint, ") + .append("or use force to ignore this warning and delete the inference endpoint."); + var e = expectThrows(ResponseException.class, () -> deleteModel(endpointId)); + assertThat(e.getMessage(), containsString(errorString.toString())); + } + { + var response = deleteModel(endpointId, "dry_run=true"); + var entityString = EntityUtils.toString(response.getEntity()); + assertThat(entityString, containsString("\"acknowledged\":false")); + assertThat(entityString, containsString(indexName)); + } + { + var response = deleteModel(endpointId, "force=true"); + var entityString = EntityUtils.toString(response.getEntity()); + assertThat(entityString, containsString("\"acknowledged\":true")); + } + deleteIndex(indexName); + } + + public void testDeleteEndpointWhileReferencedBySemanticTextAndPipeline() throws IOException { + String endpointId = "endpoint_referenced_by_semantic_text"; + putModel(endpointId, mockSparseServiceModelConfig(), TaskType.SPARSE_EMBEDDING); + String indexName = randomAlphaOfLength(10).toLowerCase(); + putSemanticText(endpointId, indexName); + var pipelineId = "pipeline_referencing_model"; + putPipeline(pipelineId, endpointId); + { + + var errorString = new StringBuilder().append("Inference endpoint ") + .append(endpointId) + .append(" is referenced by pipelines: ") + .append(Set.of(pipelineId)) + .append(". ") + .append("Ensure that no pipelines are using this inference endpoint, ") + .append("or use force to ignore this warning and delete the inference endpoint.") + .append(" Inference endpoint ") + .append(endpointId) + .append(" is being used in the mapping for indexes: ") + .append(Set.of(indexName)) + .append(". ") + .append("Ensure that no index mappings are using this inference endpoint, ") + .append("or use force to ignore this warning and delete the inference endpoint."); + + var e = expectThrows(ResponseException.class, () -> deleteModel(endpointId)); + assertThat(e.getMessage(), containsString(errorString.toString())); + } + { + var response = deleteModel(endpointId, "dry_run=true"); + var entityString = EntityUtils.toString(response.getEntity()); + assertThat(entityString, containsString("\"acknowledged\":false")); + assertThat(entityString, containsString(indexName)); + assertThat(entityString, containsString(pipelineId)); + } + { + var response = deleteModel(endpointId, "force=true"); + var entityString = EntityUtils.toString(response.getEntity()); + assertThat(entityString, containsString("\"acknowledged\":true")); + } + deletePipeline(pipelineId); + deleteIndex(indexName); + } } diff --git a/x-pack/plugin/inference/qa/rolling-upgrade/src/javaRestTest/java/org/elasticsearch/xpack/application/AzureOpenAiServiceUpgradeIT.java b/x-pack/plugin/inference/qa/rolling-upgrade/src/javaRestTest/java/org/elasticsearch/xpack/application/AzureOpenAiServiceUpgradeIT.java index d475fd099d4ac..f0196834b9175 100644 --- a/x-pack/plugin/inference/qa/rolling-upgrade/src/javaRestTest/java/org/elasticsearch/xpack/application/AzureOpenAiServiceUpgradeIT.java +++ b/x-pack/plugin/inference/qa/rolling-upgrade/src/javaRestTest/java/org/elasticsearch/xpack/application/AzureOpenAiServiceUpgradeIT.java @@ -49,35 +49,39 @@ public static void shutdown() { @AwaitsFix(bugUrl = "Cannot set the URL in the tests") public void testOpenAiEmbeddings() throws IOException { var openAiEmbeddingsSupported = getOldClusterTestVersion().onOrAfter(OPEN_AI_AZURE_EMBEDDINGS_ADDED); + // `gte_v` indicates that the cluster version is Greater Than or Equal to MODELS_RENAMED_TO_ENDPOINTS + String oldClusterEndpointIdentifier = oldClusterHasFeature("gte_v" + MODELS_RENAMED_TO_ENDPOINTS) ? "endpoints" : "models"; assumeTrue("Azure OpenAI embedding service added in " + OPEN_AI_AZURE_EMBEDDINGS_ADDED, openAiEmbeddingsSupported); final String oldClusterId = "old-cluster-embeddings"; final String upgradedClusterId = "upgraded-cluster-embeddings"; + var testTaskType = TaskType.TEXT_EMBEDDING; + if (isOldCluster()) { // queue a response as PUT will call the service openAiEmbeddingsServer.enqueue(new MockResponse().setResponseCode(200).setBody(OpenAiServiceUpgradeIT.embeddingResponse())); - put(oldClusterId, embeddingConfig(getUrl(openAiEmbeddingsServer)), TaskType.TEXT_EMBEDDING); + put(oldClusterId, embeddingConfig(getUrl(openAiEmbeddingsServer)), testTaskType); - var configs = (List>) get(TaskType.TEXT_EMBEDDING, oldClusterId).get("endpoints"); + var configs = (List>) get(testTaskType, oldClusterId).get(oldClusterEndpointIdentifier); assertThat(configs, hasSize(1)); } else if (isMixedCluster()) { - var configs = (List>) get(TaskType.TEXT_EMBEDDING, oldClusterId).get("endpoints"); + var configs = getConfigsWithBreakingChangeHandling(testTaskType, oldClusterId); assertEquals("azureopenai", configs.get(0).get("service")); assertEmbeddingInference(oldClusterId); } else if (isUpgradedCluster()) { // check old cluster model - var configs = (List>) get(TaskType.TEXT_EMBEDDING, oldClusterId).get("endpoints"); + var configs = (List>) get(testTaskType, oldClusterId).get("endpoints"); var serviceSettings = (Map) configs.get(0).get("service_settings"); // Inference on old cluster model assertEmbeddingInference(oldClusterId); openAiEmbeddingsServer.enqueue(new MockResponse().setResponseCode(200).setBody(OpenAiServiceUpgradeIT.embeddingResponse())); - put(upgradedClusterId, embeddingConfig(getUrl(openAiEmbeddingsServer)), TaskType.TEXT_EMBEDDING); + put(upgradedClusterId, embeddingConfig(getUrl(openAiEmbeddingsServer)), testTaskType); - configs = (List>) get(TaskType.TEXT_EMBEDDING, upgradedClusterId).get("endpoints"); + configs = (List>) get(testTaskType, upgradedClusterId).get("endpoints"); assertThat(configs, hasSize(1)); // Inference on the new config diff --git a/x-pack/plugin/inference/qa/rolling-upgrade/src/javaRestTest/java/org/elasticsearch/xpack/application/CohereServiceUpgradeIT.java b/x-pack/plugin/inference/qa/rolling-upgrade/src/javaRestTest/java/org/elasticsearch/xpack/application/CohereServiceUpgradeIT.java index c889d8f9b312a..c7d95f1f512b2 100644 --- a/x-pack/plugin/inference/qa/rolling-upgrade/src/javaRestTest/java/org/elasticsearch/xpack/application/CohereServiceUpgradeIT.java +++ b/x-pack/plugin/inference/qa/rolling-upgrade/src/javaRestTest/java/org/elasticsearch/xpack/application/CohereServiceUpgradeIT.java @@ -15,6 +15,8 @@ import org.elasticsearch.test.http.MockWebServer; import org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingType; import org.hamcrest.Matchers; +import org.junit.AfterClass; +import org.junit.BeforeClass; import java.io.IOException; import java.util.List; @@ -39,7 +41,7 @@ public CohereServiceUpgradeIT(@Name("upgradedNodes") int upgradedNodes) { super(upgradedNodes); } - // @BeforeClass + @BeforeClass public static void startWebServer() throws IOException { cohereEmbeddingsServer = new MockWebServer(); cohereEmbeddingsServer.start(); @@ -48,58 +50,74 @@ public static void startWebServer() throws IOException { cohereRerankServer.start(); } - // @AfterClass // for the awaitsfix + @AfterClass public static void shutdown() { cohereEmbeddingsServer.close(); cohereRerankServer.close(); } @SuppressWarnings("unchecked") - @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/107887") public void testCohereEmbeddings() throws IOException { var embeddingsSupported = getOldClusterTestVersion().onOrAfter(COHERE_EMBEDDINGS_ADDED); + // `gte_v` indicates that the cluster version is Greater Than or Equal to MODELS_RENAMED_TO_ENDPOINTS + String oldClusterEndpointIdentifier = oldClusterHasFeature("gte_v" + MODELS_RENAMED_TO_ENDPOINTS) ? "endpoints" : "models"; assumeTrue("Cohere embedding service added in " + COHERE_EMBEDDINGS_ADDED, embeddingsSupported); final String oldClusterIdInt8 = "old-cluster-embeddings-int8"; final String oldClusterIdFloat = "old-cluster-embeddings-float"; + var testTaskType = TaskType.TEXT_EMBEDDING; + if (isOldCluster()) { // queue a response as PUT will call the service cohereEmbeddingsServer.enqueue(new MockResponse().setResponseCode(200).setBody(embeddingResponseByte())); - put(oldClusterIdInt8, embeddingConfigInt8(getUrl(cohereEmbeddingsServer)), TaskType.TEXT_EMBEDDING); + put(oldClusterIdInt8, embeddingConfigInt8(getUrl(cohereEmbeddingsServer)), testTaskType); // float model cohereEmbeddingsServer.enqueue(new MockResponse().setResponseCode(200).setBody(embeddingResponseFloat())); - put(oldClusterIdFloat, embeddingConfigFloat(getUrl(cohereEmbeddingsServer)), TaskType.TEXT_EMBEDDING); + put(oldClusterIdFloat, embeddingConfigFloat(getUrl(cohereEmbeddingsServer)), testTaskType); - var configs = (List>) get(TaskType.TEXT_EMBEDDING, oldClusterIdInt8).get("endpoints"); - assertThat(configs, hasSize(1)); - assertEquals("cohere", configs.get(0).get("service")); - var serviceSettings = (Map) configs.get(0).get("service_settings"); - assertThat(serviceSettings, hasEntry("model_id", "embed-english-light-v3.0")); - var embeddingType = serviceSettings.get("embedding_type"); - // An upgraded node will report the embedding type as byte, the old node int8 - assertThat(embeddingType, Matchers.is(oneOf("int8", "byte"))); - - assertEmbeddingInference(oldClusterIdInt8, CohereEmbeddingType.BYTE); - assertEmbeddingInference(oldClusterIdFloat, CohereEmbeddingType.FLOAT); + { + var configs = (List>) get(testTaskType, oldClusterIdInt8).get(oldClusterEndpointIdentifier); + assertThat(configs, hasSize(1)); + assertEquals("cohere", configs.get(0).get("service")); + var serviceSettings = (Map) configs.get(0).get("service_settings"); + assertThat(serviceSettings, hasEntry("model_id", "embed-english-light-v3.0")); + var embeddingType = serviceSettings.get("embedding_type"); + // An upgraded node will report the embedding type as byte, the old node int8 + assertThat(embeddingType, Matchers.is(oneOf("int8", "byte"))); + assertEmbeddingInference(oldClusterIdInt8, CohereEmbeddingType.BYTE); + } + { + var configs = (List>) get(testTaskType, oldClusterIdFloat).get(oldClusterEndpointIdentifier); + assertThat(configs, hasSize(1)); + assertEquals("cohere", configs.get(0).get("service")); + var serviceSettings = (Map) configs.get(0).get("service_settings"); + assertThat(serviceSettings, hasEntry("model_id", "embed-english-light-v3.0")); + assertThat(serviceSettings, hasEntry("embedding_type", "float")); + assertEmbeddingInference(oldClusterIdFloat, CohereEmbeddingType.FLOAT); + } } else if (isMixedCluster()) { - var configs = (List>) get(TaskType.TEXT_EMBEDDING, oldClusterIdInt8).get("endpoints"); - assertEquals("cohere", configs.get(0).get("service")); - var serviceSettings = (Map) configs.get(0).get("service_settings"); - assertThat(serviceSettings, hasEntry("model_id", "embed-english-light-v3.0")); - var embeddingType = serviceSettings.get("embedding_type"); - // An upgraded node will report the embedding type as byte, an old node int8 - assertThat(embeddingType, Matchers.is(oneOf("int8", "byte"))); - - configs = (List>) get(TaskType.TEXT_EMBEDDING, oldClusterIdFloat).get("endpoints"); - serviceSettings = (Map) configs.get(0).get("service_settings"); - assertThat(serviceSettings, hasEntry("embedding_type", "float")); - - assertEmbeddingInference(oldClusterIdInt8, CohereEmbeddingType.BYTE); - assertEmbeddingInference(oldClusterIdFloat, CohereEmbeddingType.FLOAT); + { + var configs = getConfigsWithBreakingChangeHandling(testTaskType, oldClusterIdInt8); + assertEquals("cohere", configs.get(0).get("service")); + var serviceSettings = (Map) configs.get(0).get("service_settings"); + assertThat(serviceSettings, hasEntry("model_id", "embed-english-light-v3.0")); + var embeddingType = serviceSettings.get("embedding_type"); + // An upgraded node will report the embedding type as byte, an old node int8 + assertThat(embeddingType, Matchers.is(oneOf("int8", "byte"))); + assertEmbeddingInference(oldClusterIdInt8, CohereEmbeddingType.BYTE); + } + { + var configs = getConfigsWithBreakingChangeHandling(testTaskType, oldClusterIdFloat); + assertEquals("cohere", configs.get(0).get("service")); + var serviceSettings = (Map) configs.get(0).get("service_settings"); + assertThat(serviceSettings, hasEntry("model_id", "embed-english-light-v3.0")); + assertThat(serviceSettings, hasEntry("embedding_type", "float")); + assertEmbeddingInference(oldClusterIdFloat, CohereEmbeddingType.FLOAT); + } } else if (isUpgradedCluster()) { // check old cluster model - var configs = (List>) get(TaskType.TEXT_EMBEDDING, oldClusterIdInt8).get("endpoints"); + var configs = (List>) get(testTaskType, oldClusterIdInt8).get("endpoints"); var serviceSettings = (Map) configs.get(0).get("service_settings"); assertThat(serviceSettings, hasEntry("model_id", "embed-english-light-v3.0")); assertThat(serviceSettings, hasEntry("embedding_type", "byte")); @@ -114,9 +132,9 @@ public void testCohereEmbeddings() throws IOException { final String upgradedClusterIdByte = "upgraded-cluster-embeddings-byte"; cohereEmbeddingsServer.enqueue(new MockResponse().setResponseCode(200).setBody(embeddingResponseByte())); - put(upgradedClusterIdByte, embeddingConfigByte(getUrl(cohereEmbeddingsServer)), TaskType.TEXT_EMBEDDING); + put(upgradedClusterIdByte, embeddingConfigByte(getUrl(cohereEmbeddingsServer)), testTaskType); - configs = (List>) get(TaskType.TEXT_EMBEDDING, upgradedClusterIdByte).get("endpoints"); + configs = (List>) get(testTaskType, upgradedClusterIdByte).get("endpoints"); serviceSettings = (Map) configs.get(0).get("service_settings"); assertThat(serviceSettings, hasEntry("embedding_type", "byte")); @@ -127,9 +145,9 @@ public void testCohereEmbeddings() throws IOException { final String upgradedClusterIdInt8 = "upgraded-cluster-embeddings-int8"; cohereEmbeddingsServer.enqueue(new MockResponse().setResponseCode(200).setBody(embeddingResponseByte())); - put(upgradedClusterIdInt8, embeddingConfigInt8(getUrl(cohereEmbeddingsServer)), TaskType.TEXT_EMBEDDING); + put(upgradedClusterIdInt8, embeddingConfigInt8(getUrl(cohereEmbeddingsServer)), testTaskType); - configs = (List>) get(TaskType.TEXT_EMBEDDING, upgradedClusterIdInt8).get("endpoints"); + configs = (List>) get(testTaskType, upgradedClusterIdInt8).get("endpoints"); serviceSettings = (Map) configs.get(0).get("service_settings"); assertThat(serviceSettings, hasEntry("embedding_type", "byte")); // int8 rewritten to byte @@ -139,9 +157,9 @@ public void testCohereEmbeddings() throws IOException { { final String upgradedClusterIdFloat = "upgraded-cluster-embeddings-float"; cohereEmbeddingsServer.enqueue(new MockResponse().setResponseCode(200).setBody(embeddingResponseFloat())); - put(upgradedClusterIdFloat, embeddingConfigFloat(getUrl(cohereEmbeddingsServer)), TaskType.TEXT_EMBEDDING); + put(upgradedClusterIdFloat, embeddingConfigFloat(getUrl(cohereEmbeddingsServer)), testTaskType); - configs = (List>) get(TaskType.TEXT_EMBEDDING, upgradedClusterIdFloat).get("endpoints"); + configs = (List>) get(testTaskType, upgradedClusterIdFloat).get("endpoints"); serviceSettings = (Map) configs.get(0).get("service_settings"); assertThat(serviceSettings, hasEntry("embedding_type", "float")); @@ -169,22 +187,25 @@ void assertEmbeddingInference(String inferenceId, CohereEmbeddingType type) thro } @SuppressWarnings("unchecked") - @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/107887") public void testRerank() throws IOException { var rerankSupported = getOldClusterTestVersion().onOrAfter(COHERE_RERANK_ADDED); + String old_cluster_endpoint_identifier = oldClusterHasFeature("gte_v" + MODELS_RENAMED_TO_ENDPOINTS) ? "endpoints" : "models"; assumeTrue("Cohere rerank service added in " + COHERE_RERANK_ADDED, rerankSupported); final String oldClusterId = "old-cluster-rerank"; final String upgradedClusterId = "upgraded-cluster-rerank"; + var testTaskType = TaskType.RERANK; + if (isOldCluster()) { - put(oldClusterId, rerankConfig(getUrl(cohereRerankServer)), TaskType.RERANK); - var configs = (List>) get(TaskType.RERANK, oldClusterId).get("endpoints"); + put(oldClusterId, rerankConfig(getUrl(cohereRerankServer)), testTaskType); + var configs = (List>) get(testTaskType, oldClusterId).get(old_cluster_endpoint_identifier); assertThat(configs, hasSize(1)); assertRerank(oldClusterId); } else if (isMixedCluster()) { - var configs = (List>) get(TaskType.RERANK, oldClusterId).get("endpoints"); + var configs = getConfigsWithBreakingChangeHandling(testTaskType, oldClusterId); + assertEquals("cohere", configs.get(0).get("service")); var serviceSettings = (Map) configs.get(0).get("service_settings"); assertThat(serviceSettings, hasEntry("model_id", "rerank-english-v3.0")); @@ -195,7 +216,7 @@ public void testRerank() throws IOException { } else if (isUpgradedCluster()) { // check old cluster model - var configs = (List>) get(TaskType.RERANK, oldClusterId).get("endpoints"); + var configs = (List>) get(testTaskType, oldClusterId).get("endpoints"); assertEquals("cohere", configs.get(0).get("service")); var serviceSettings = (Map) configs.get(0).get("service_settings"); assertThat(serviceSettings, hasEntry("model_id", "rerank-english-v3.0")); @@ -205,7 +226,7 @@ public void testRerank() throws IOException { assertRerank(oldClusterId); // New endpoint - put(upgradedClusterId, rerankConfig(getUrl(cohereRerankServer)), TaskType.RERANK); + put(upgradedClusterId, rerankConfig(getUrl(cohereRerankServer)), testTaskType); configs = (List>) get(upgradedClusterId).get("endpoints"); assertThat(configs, hasSize(1)); diff --git a/x-pack/plugin/inference/qa/rolling-upgrade/src/javaRestTest/java/org/elasticsearch/xpack/application/HuggingFaceServiceUpgradeIT.java b/x-pack/plugin/inference/qa/rolling-upgrade/src/javaRestTest/java/org/elasticsearch/xpack/application/HuggingFaceServiceUpgradeIT.java index 899a02776195d..36ee472cc0a13 100644 --- a/x-pack/plugin/inference/qa/rolling-upgrade/src/javaRestTest/java/org/elasticsearch/xpack/application/HuggingFaceServiceUpgradeIT.java +++ b/x-pack/plugin/inference/qa/rolling-upgrade/src/javaRestTest/java/org/elasticsearch/xpack/application/HuggingFaceServiceUpgradeIT.java @@ -13,6 +13,8 @@ import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.http.MockResponse; import org.elasticsearch.test.http.MockWebServer; +import org.junit.AfterClass; +import org.junit.BeforeClass; import java.io.IOException; import java.util.List; @@ -34,7 +36,7 @@ public HuggingFaceServiceUpgradeIT(@Name("upgradedNodes") int upgradedNodes) { super(upgradedNodes); } - // @BeforeClass + @BeforeClass public static void startWebServer() throws IOException { embeddingsServer = new MockWebServer(); embeddingsServer.start(); @@ -43,47 +45,51 @@ public static void startWebServer() throws IOException { elserServer.start(); } - // @AfterClass for the awaits fix + @AfterClass public static void shutdown() { embeddingsServer.close(); elserServer.close(); } @SuppressWarnings("unchecked") - @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/107887") public void testHFEmbeddings() throws IOException { var embeddingsSupported = getOldClusterTestVersion().onOrAfter(HF_EMBEDDINGS_ADDED); + // `gte_v` indicates that the cluster version is Greater Than or Equal to MODELS_RENAMED_TO_ENDPOINTS + String oldClusterEndpointIdentifier = oldClusterHasFeature("gte_v" + MODELS_RENAMED_TO_ENDPOINTS) ? "endpoints" : "models"; assumeTrue("Hugging Face embedding service added in " + HF_EMBEDDINGS_ADDED, embeddingsSupported); final String oldClusterId = "old-cluster-embeddings"; final String upgradedClusterId = "upgraded-cluster-embeddings"; + var testTaskType = TaskType.TEXT_EMBEDDING; + if (isOldCluster()) { // queue a response as PUT will call the service embeddingsServer.enqueue(new MockResponse().setResponseCode(200).setBody(embeddingResponse())); - put(oldClusterId, embeddingConfig(getUrl(embeddingsServer)), TaskType.TEXT_EMBEDDING); + put(oldClusterId, embeddingConfig(getUrl(embeddingsServer)), testTaskType); - var configs = (List>) get(TaskType.TEXT_EMBEDDING, oldClusterId).get("endpoints"); + var configs = (List>) get(testTaskType, oldClusterId).get(oldClusterEndpointIdentifier); assertThat(configs, hasSize(1)); assertEmbeddingInference(oldClusterId); } else if (isMixedCluster()) { - var configs = (List>) get(TaskType.TEXT_EMBEDDING, oldClusterId).get("endpoints"); + var configs = getConfigsWithBreakingChangeHandling(testTaskType, oldClusterId); + assertEquals("hugging_face", configs.get(0).get("service")); assertEmbeddingInference(oldClusterId); } else if (isUpgradedCluster()) { // check old cluster model - var configs = (List>) get(TaskType.TEXT_EMBEDDING, oldClusterId).get("endpoints"); + var configs = (List>) get(testTaskType, oldClusterId).get("endpoints"); assertEquals("hugging_face", configs.get(0).get("service")); // Inference on old cluster model assertEmbeddingInference(oldClusterId); embeddingsServer.enqueue(new MockResponse().setResponseCode(200).setBody(embeddingResponse())); - put(upgradedClusterId, embeddingConfig(getUrl(embeddingsServer)), TaskType.TEXT_EMBEDDING); + put(upgradedClusterId, embeddingConfig(getUrl(embeddingsServer)), testTaskType); - configs = (List>) get(TaskType.TEXT_EMBEDDING, upgradedClusterId).get("endpoints"); + configs = (List>) get(testTaskType, upgradedClusterId).get("endpoints"); assertThat(configs, hasSize(1)); assertEmbeddingInference(upgradedClusterId); @@ -100,27 +106,29 @@ void assertEmbeddingInference(String inferenceId) throws IOException { } @SuppressWarnings("unchecked") - @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/107887") public void testElser() throws IOException { var supported = getOldClusterTestVersion().onOrAfter(HF_ELSER_ADDED); + String old_cluster_endpoint_identifier = oldClusterHasFeature("gte_v" + MODELS_RENAMED_TO_ENDPOINTS) ? "endpoints" : "models"; assumeTrue("HF elser service added in " + HF_ELSER_ADDED, supported); final String oldClusterId = "old-cluster-elser"; final String upgradedClusterId = "upgraded-cluster-elser"; + var testTaskType = TaskType.SPARSE_EMBEDDING; + if (isOldCluster()) { - put(oldClusterId, elserConfig(getUrl(elserServer)), TaskType.SPARSE_EMBEDDING); - var configs = (List>) get(TaskType.SPARSE_EMBEDDING, oldClusterId).get("endpoints"); + put(oldClusterId, elserConfig(getUrl(elserServer)), testTaskType); + var configs = (List>) get(testTaskType, oldClusterId).get(old_cluster_endpoint_identifier); assertThat(configs, hasSize(1)); assertElser(oldClusterId); } else if (isMixedCluster()) { - var configs = (List>) get(TaskType.SPARSE_EMBEDDING, oldClusterId).get("endpoints"); + var configs = getConfigsWithBreakingChangeHandling(testTaskType, oldClusterId); assertEquals("hugging_face", configs.get(0).get("service")); assertElser(oldClusterId); } else if (isUpgradedCluster()) { // check old cluster model - var configs = (List>) get(TaskType.SPARSE_EMBEDDING, oldClusterId).get("endpoints"); + var configs = (List>) get(testTaskType, oldClusterId).get("endpoints"); assertEquals("hugging_face", configs.get(0).get("service")); var taskSettings = (Map) configs.get(0).get("task_settings"); assertThat(taskSettings.keySet(), empty()); @@ -128,7 +136,7 @@ public void testElser() throws IOException { assertElser(oldClusterId); // New endpoint - put(upgradedClusterId, elserConfig(getUrl(elserServer)), TaskType.SPARSE_EMBEDDING); + put(upgradedClusterId, elserConfig(getUrl(elserServer)), testTaskType); configs = (List>) get(upgradedClusterId).get("endpoints"); assertThat(configs, hasSize(1)); diff --git a/x-pack/plugin/inference/qa/rolling-upgrade/src/javaRestTest/java/org/elasticsearch/xpack/application/InferenceUpgradeTestCase.java b/x-pack/plugin/inference/qa/rolling-upgrade/src/javaRestTest/java/org/elasticsearch/xpack/application/InferenceUpgradeTestCase.java index ecfec2304c8a1..58335eb53b366 100644 --- a/x-pack/plugin/inference/qa/rolling-upgrade/src/javaRestTest/java/org/elasticsearch/xpack/application/InferenceUpgradeTestCase.java +++ b/x-pack/plugin/inference/qa/rolling-upgrade/src/javaRestTest/java/org/elasticsearch/xpack/application/InferenceUpgradeTestCase.java @@ -16,13 +16,17 @@ import org.elasticsearch.upgrades.AbstractRollingUpgradeTestCase; import java.io.IOException; +import java.util.LinkedList; import java.util.List; import java.util.Map; +import java.util.Objects; import static org.elasticsearch.core.Strings.format; public class InferenceUpgradeTestCase extends AbstractRollingUpgradeTestCase { + static final String MODELS_RENAMED_TO_ENDPOINTS = "8.15.0"; + public InferenceUpgradeTestCase(@Name("upgradedNodes") int upgradedNodes) { super(upgradedNodes); } @@ -104,4 +108,17 @@ protected void put(String inferenceId, String modelConfig, TaskType taskType) th var response = client().performRequest(request); assertOKAndConsume(response); } + + @SuppressWarnings("unchecked") + // in version 8.15, there was a breaking change where "models" was renamed to "endpoints" + LinkedList> getConfigsWithBreakingChangeHandling(TaskType testTaskType, String oldClusterId) throws IOException { + + LinkedList> configs; + configs = new LinkedList<>( + (List>) Objects.requireNonNullElse((get(testTaskType, oldClusterId).get("endpoints")), List.of()) + ); + configs.addAll(Objects.requireNonNullElse((List>) get(testTaskType, oldClusterId).get("models"), List.of())); + + return configs; + } } diff --git a/x-pack/plugin/inference/qa/rolling-upgrade/src/javaRestTest/java/org/elasticsearch/xpack/application/OpenAiServiceUpgradeIT.java b/x-pack/plugin/inference/qa/rolling-upgrade/src/javaRestTest/java/org/elasticsearch/xpack/application/OpenAiServiceUpgradeIT.java index bfdcb0e0d5ed4..df995c6f5e620 100644 --- a/x-pack/plugin/inference/qa/rolling-upgrade/src/javaRestTest/java/org/elasticsearch/xpack/application/OpenAiServiceUpgradeIT.java +++ b/x-pack/plugin/inference/qa/rolling-upgrade/src/javaRestTest/java/org/elasticsearch/xpack/application/OpenAiServiceUpgradeIT.java @@ -12,6 +12,8 @@ import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.http.MockResponse; import org.elasticsearch.test.http.MockWebServer; +import org.junit.AfterClass; +import org.junit.BeforeClass; import java.io.IOException; import java.util.List; @@ -35,7 +37,7 @@ public OpenAiServiceUpgradeIT(@Name("upgradedNodes") int upgradedNodes) { super(upgradedNodes); } - // @BeforeClass + @BeforeClass public static void startWebServer() throws IOException { openAiEmbeddingsServer = new MockWebServer(); openAiEmbeddingsServer.start(); @@ -44,33 +46,37 @@ public static void startWebServer() throws IOException { openAiChatCompletionsServer.start(); } - // @AfterClass for the awaits fix + @AfterClass public static void shutdown() { openAiEmbeddingsServer.close(); openAiChatCompletionsServer.close(); } @SuppressWarnings("unchecked") - @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/107887") public void testOpenAiEmbeddings() throws IOException { var openAiEmbeddingsSupported = getOldClusterTestVersion().onOrAfter(OPEN_AI_EMBEDDINGS_ADDED); + // `gte_v` indicates that the cluster version is Greater Than or Equal to MODELS_RENAMED_TO_ENDPOINTS + String oldClusterEndpointIdentifier = oldClusterHasFeature("gte_v" + MODELS_RENAMED_TO_ENDPOINTS) ? "endpoints" : "models"; assumeTrue("OpenAI embedding service added in " + OPEN_AI_EMBEDDINGS_ADDED, openAiEmbeddingsSupported); final String oldClusterId = "old-cluster-embeddings"; final String upgradedClusterId = "upgraded-cluster-embeddings"; + var testTaskType = TaskType.TEXT_EMBEDDING; + if (isOldCluster()) { String inferenceConfig = oldClusterVersionCompatibleEmbeddingConfig(); // queue a response as PUT will call the service openAiEmbeddingsServer.enqueue(new MockResponse().setResponseCode(200).setBody(embeddingResponse())); - put(oldClusterId, inferenceConfig, TaskType.TEXT_EMBEDDING); + put(oldClusterId, inferenceConfig, testTaskType); - var configs = (List>) get(TaskType.TEXT_EMBEDDING, oldClusterId).get("endpoints"); + var configs = (List>) get(testTaskType, oldClusterId).get(oldClusterEndpointIdentifier); assertThat(configs, hasSize(1)); assertEmbeddingInference(oldClusterId); } else if (isMixedCluster()) { - var configs = (List>) get(TaskType.TEXT_EMBEDDING, oldClusterId).get("endpoints"); + var configs = getConfigsWithBreakingChangeHandling(testTaskType, oldClusterId); + assertEquals("openai", configs.get(0).get("service")); var serviceSettings = (Map) configs.get(0).get("service_settings"); var taskSettings = (Map) configs.get(0).get("task_settings"); @@ -80,7 +86,7 @@ public void testOpenAiEmbeddings() throws IOException { assertEmbeddingInference(oldClusterId); } else if (isUpgradedCluster()) { // check old cluster model - var configs = (List>) get(TaskType.TEXT_EMBEDDING, oldClusterId).get("endpoints"); + var configs = (List>) get(testTaskType, oldClusterId).get("endpoints"); var serviceSettings = (Map) configs.get(0).get("service_settings"); // model id is moved to service settings assertThat(serviceSettings, hasEntry("model_id", "text-embedding-ada-002")); @@ -92,9 +98,9 @@ public void testOpenAiEmbeddings() throws IOException { String inferenceConfig = embeddingConfigWithModelInServiceSettings(getUrl(openAiEmbeddingsServer)); openAiEmbeddingsServer.enqueue(new MockResponse().setResponseCode(200).setBody(embeddingResponse())); - put(upgradedClusterId, inferenceConfig, TaskType.TEXT_EMBEDDING); + put(upgradedClusterId, inferenceConfig, testTaskType); - configs = (List>) get(TaskType.TEXT_EMBEDDING, upgradedClusterId).get("endpoints"); + configs = (List>) get(testTaskType, upgradedClusterId).get("endpoints"); assertThat(configs, hasSize(1)); assertEmbeddingInference(upgradedClusterId); @@ -111,23 +117,29 @@ void assertEmbeddingInference(String inferenceId) throws IOException { } @SuppressWarnings("unchecked") - @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/107887") public void testOpenAiCompletions() throws IOException { var openAiEmbeddingsSupported = getOldClusterTestVersion().onOrAfter(OPEN_AI_COMPLETIONS_ADDED); + String old_cluster_endpoint_identifier = oldClusterHasFeature("gte_v" + MODELS_RENAMED_TO_ENDPOINTS) ? "endpoints" : "models"; assumeTrue("OpenAI completions service added in " + OPEN_AI_COMPLETIONS_ADDED, openAiEmbeddingsSupported); final String oldClusterId = "old-cluster-completions"; final String upgradedClusterId = "upgraded-cluster-completions"; + var testTaskType = TaskType.COMPLETION; + if (isOldCluster()) { - put(oldClusterId, chatCompletionsConfig(getUrl(openAiChatCompletionsServer)), TaskType.COMPLETION); + put(oldClusterId, chatCompletionsConfig(getUrl(openAiChatCompletionsServer)), testTaskType); - var configs = (List>) get(TaskType.COMPLETION, oldClusterId).get("endpoints"); + var configs = (List>) get(testTaskType, oldClusterId).get(old_cluster_endpoint_identifier); assertThat(configs, hasSize(1)); assertCompletionInference(oldClusterId); } else if (isMixedCluster()) { - var configs = (List>) get(TaskType.COMPLETION, oldClusterId).get("endpoints"); + var configs = (List>) get(testTaskType, oldClusterId).get("endpoints"); + if (oldClusterHasFeature("gte_v" + MODELS_RENAMED_TO_ENDPOINTS) == false) { + configs.addAll((List>) get(testTaskType, oldClusterId).get(old_cluster_endpoint_identifier)); + // in version 8.15, there was a breaking change where "models" was renamed to "endpoints" + } assertEquals("openai", configs.get(0).get("service")); var serviceSettings = (Map) configs.get(0).get("service_settings"); assertThat(serviceSettings, hasEntry("model_id", "gpt-4")); @@ -137,7 +149,7 @@ public void testOpenAiCompletions() throws IOException { assertCompletionInference(oldClusterId); } else if (isUpgradedCluster()) { // check old cluster model - var configs = (List>) get(TaskType.COMPLETION, oldClusterId).get("endpoints"); + var configs = (List>) get(testTaskType, oldClusterId).get("endpoints"); var serviceSettings = (Map) configs.get(0).get("service_settings"); assertThat(serviceSettings, hasEntry("model_id", "gpt-4")); var taskSettings = (Map) configs.get(0).get("task_settings"); @@ -145,8 +157,8 @@ public void testOpenAiCompletions() throws IOException { assertCompletionInference(oldClusterId); - put(upgradedClusterId, chatCompletionsConfig(getUrl(openAiChatCompletionsServer)), TaskType.COMPLETION); - configs = (List>) get(TaskType.COMPLETION, upgradedClusterId).get("endpoints"); + put(upgradedClusterId, chatCompletionsConfig(getUrl(openAiChatCompletionsServer)), testTaskType); + configs = (List>) get(testTaskType, upgradedClusterId).get("endpoints"); assertThat(configs, hasSize(1)); // Inference on the new config diff --git a/x-pack/plugin/inference/src/main/java/module-info.java b/x-pack/plugin/inference/src/main/java/module-info.java index aa907a236884a..a7e5718a0920e 100644 --- a/x-pack/plugin/inference/src/main/java/module-info.java +++ b/x-pack/plugin/inference/src/main/java/module-info.java @@ -20,8 +20,13 @@ requires org.apache.lucene.join; requires com.ibm.icu; requires com.google.auth.oauth2; + requires com.google.auth; requires com.google.api.client; requires com.google.gson; + requires aws.java.sdk.bedrockruntime; + requires aws.java.sdk.core; + requires com.fasterxml.jackson.databind; + requires org.joda.time; exports org.elasticsearch.xpack.inference.action; exports org.elasticsearch.xpack.inference.registry; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java index f3799b824fc0e..f8ce9df1fb194 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java @@ -24,6 +24,10 @@ import org.elasticsearch.xpack.core.inference.results.LegacyTextEmbeddingResults; import org.elasticsearch.xpack.core.inference.results.RankedDocsResults; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; +import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockSecretSettings; +import org.elasticsearch.xpack.inference.services.amazonbedrock.completion.AmazonBedrockChatCompletionServiceSettings; +import org.elasticsearch.xpack.inference.services.amazonbedrock.completion.AmazonBedrockChatCompletionTaskSettings; +import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.services.anthropic.completion.AnthropicChatCompletionServiceSettings; import org.elasticsearch.xpack.inference.services.anthropic.completion.AnthropicChatCompletionTaskSettings; import org.elasticsearch.xpack.inference.services.azureaistudio.completion.AzureAiStudioChatCompletionServiceSettings; @@ -122,10 +126,46 @@ public static List getNamedWriteables() { addMistralNamedWriteables(namedWriteables); addCustomElandWriteables(namedWriteables); addAnthropicNamedWritables(namedWriteables); + addAmazonBedrockNamedWriteables(namedWriteables); return namedWriteables; } + private static void addAmazonBedrockNamedWriteables(List namedWriteables) { + namedWriteables.add( + new NamedWriteableRegistry.Entry( + AmazonBedrockSecretSettings.class, + AmazonBedrockSecretSettings.NAME, + AmazonBedrockSecretSettings::new + ) + ); + + namedWriteables.add( + new NamedWriteableRegistry.Entry( + ServiceSettings.class, + AmazonBedrockEmbeddingsServiceSettings.NAME, + AmazonBedrockEmbeddingsServiceSettings::new + ) + ); + + // no task settings for Amazon Bedrock Embeddings + + namedWriteables.add( + new NamedWriteableRegistry.Entry( + ServiceSettings.class, + AmazonBedrockChatCompletionServiceSettings.NAME, + AmazonBedrockChatCompletionServiceSettings::new + ) + ); + namedWriteables.add( + new NamedWriteableRegistry.Entry( + TaskSettings.class, + AmazonBedrockChatCompletionTaskSettings.NAME, + AmazonBedrockChatCompletionTaskSettings::new + ) + ); + } + private static void addMistralNamedWriteables(List namedWriteables) { namedWriteables.add( new NamedWriteableRegistry.Entry( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index 1db5b4135ee94..1c388f7399260 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -53,6 +53,7 @@ import org.elasticsearch.xpack.inference.action.TransportPutInferenceModelAction; import org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter; import org.elasticsearch.xpack.inference.common.Truncator; +import org.elasticsearch.xpack.inference.external.amazonbedrock.AmazonBedrockRequestSender; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.HttpSettings; import org.elasticsearch.xpack.inference.external.http.retry.RetrySettings; @@ -70,6 +71,7 @@ import org.elasticsearch.xpack.inference.rest.RestInferenceAction; import org.elasticsearch.xpack.inference.rest.RestPutInferenceModelAction; import org.elasticsearch.xpack.inference.services.ServiceComponents; +import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockService; import org.elasticsearch.xpack.inference.services.anthropic.AnthropicService; import org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioService; import org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiService; @@ -117,6 +119,7 @@ public class InferencePlugin extends Plugin implements ActionPlugin, ExtensibleP private final Settings settings; private final SetOnce httpFactory = new SetOnce<>(); + private final SetOnce amazonBedrockFactory = new SetOnce<>(); private final SetOnce serviceComponents = new SetOnce<>(); private final SetOnce inferenceServiceRegistry = new SetOnce<>(); @@ -170,6 +173,9 @@ public Collection createComponents(PluginServices services) { var httpRequestSenderFactory = new HttpRequestSender.Factory(serviceComponents.get(), httpClientManager, services.clusterService()); httpFactory.set(httpRequestSenderFactory); + var amazonBedrockRequestSenderFactory = new AmazonBedrockRequestSender.Factory(serviceComponents.get(), services.clusterService()); + amazonBedrockFactory.set(amazonBedrockRequestSenderFactory); + ModelRegistry modelRegistry = new ModelRegistry(services.client()); if (inferenceServiceExtensions == null) { @@ -209,6 +215,7 @@ public List getInferenceServiceFactories() { context -> new GoogleVertexAiService(httpFactory.get(), serviceComponents.get()), context -> new MistralService(httpFactory.get(), serviceComponents.get()), context -> new AnthropicService(httpFactory.get(), serviceComponents.get()), + context -> new AmazonBedrockService(httpFactory.get(), amazonBedrockFactory.get(), serviceComponents.get()), ElasticsearchInternalService::new ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointAction.java index 07d5e1e618578..e59ac4e1356f0 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportDeleteInferenceEndpointAction.java @@ -3,6 +3,8 @@ * or more contributor license agreements. Licensed under the Elastic License * 2.0; you may not use this file except in compliance with the Elastic License * 2.0. + * + * this file was contributed to by a Generative AI */ package org.elasticsearch.xpack.inference.action; @@ -11,6 +13,7 @@ import org.apache.logging.log4j.Logger; import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.ActionRunnable; import org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.action.support.SubscribableListener; import org.elasticsearch.action.support.master.TransportMasterNodeAction; @@ -18,12 +21,10 @@ import org.elasticsearch.cluster.block.ClusterBlockException; import org.elasticsearch.cluster.block.ClusterBlockLevel; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; -import org.elasticsearch.cluster.metadata.Metadata; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.inject.Inject; import org.elasticsearch.common.util.concurrent.EsExecutors; import org.elasticsearch.inference.InferenceServiceRegistry; -import org.elasticsearch.ingest.IngestMetadata; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.Task; import org.elasticsearch.threadpool.ThreadPool; @@ -34,6 +35,10 @@ import org.elasticsearch.xpack.inference.registry.ModelRegistry; import java.util.Set; +import java.util.concurrent.Executor; + +import static org.elasticsearch.xpack.core.ml.utils.SemanticTextInfoExtractor.extractIndexesReferencingInferenceEndpoints; +import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_POOL_NAME; public class TransportDeleteInferenceEndpointAction extends TransportMasterNodeAction< DeleteInferenceEndpointAction.Request, @@ -42,6 +47,7 @@ public class TransportDeleteInferenceEndpointAction extends TransportMasterNodeA private final ModelRegistry modelRegistry; private final InferenceServiceRegistry serviceRegistry; private static final Logger logger = LogManager.getLogger(TransportDeleteInferenceEndpointAction.class); + private final Executor executor; @Inject public TransportDeleteInferenceEndpointAction( @@ -66,6 +72,7 @@ public TransportDeleteInferenceEndpointAction( ); this.modelRegistry = modelRegistry; this.serviceRegistry = serviceRegistry; + this.executor = threadPool.executor(UTILITY_THREAD_POOL_NAME); } @Override @@ -74,6 +81,15 @@ protected void masterOperation( DeleteInferenceEndpointAction.Request request, ClusterState state, ActionListener masterListener + ) { + // workaround for https://github.com/elastic/elasticsearch/issues/97916 - TODO remove this when we can + executor.execute(ActionRunnable.wrap(masterListener, l -> doExecuteForked(request, state, l))); + } + + private void doExecuteForked( + DeleteInferenceEndpointAction.Request request, + ClusterState state, + ActionListener masterListener ) { SubscribableListener.newForked(modelConfigListener -> { // Get the model from the registry @@ -89,17 +105,15 @@ protected void masterOperation( } if (request.isDryRun()) { - masterListener.onResponse( - new DeleteInferenceEndpointAction.Response( - false, - InferenceProcessorInfoExtractor.pipelineIdsForResource(state, Set.of(request.getInferenceEndpointId())) - ) - ); + handleDryRun(request, state, masterListener); return; - } else if (request.isForceDelete() == false - && endpointIsReferencedInPipelines(state, request.getInferenceEndpointId(), listener)) { + } else if (request.isForceDelete() == false) { + var errorString = endpointIsReferencedInPipelinesOrIndexes(state, request.getInferenceEndpointId()); + if (errorString != null) { + listener.onFailure(new ElasticsearchStatusException(errorString, RestStatus.CONFLICT)); return; } + } var service = serviceRegistry.getService(unparsedModel.service()); if (service.isPresent()) { @@ -126,47 +140,83 @@ && endpointIsReferencedInPipelines(state, request.getInferenceEndpointId(), list }) .addListener( masterListener.delegateFailure( - (l3, didDeleteModel) -> masterListener.onResponse(new DeleteInferenceEndpointAction.Response(didDeleteModel, Set.of())) + (l3, didDeleteModel) -> masterListener.onResponse( + new DeleteInferenceEndpointAction.Response(didDeleteModel, Set.of(), Set.of(), null) + ) ) ); } - private static boolean endpointIsReferencedInPipelines( - final ClusterState state, - final String inferenceEndpointId, - ActionListener listener + private static void handleDryRun( + DeleteInferenceEndpointAction.Request request, + ClusterState state, + ActionListener masterListener ) { - Metadata metadata = state.getMetadata(); - if (metadata == null) { - listener.onFailure( - new ElasticsearchStatusException( - " Could not determine if the endpoint is referenced in a pipeline as cluster state metadata was unexpectedly null. " - + "Use `force` to delete it anyway", - RestStatus.INTERNAL_SERVER_ERROR - ) - ); - // Unsure why the ClusterState metadata would ever be null, but in this case it seems safer to assume the endpoint is referenced - return true; + Set pipelines = InferenceProcessorInfoExtractor.pipelineIdsForResource(state, Set.of(request.getInferenceEndpointId())); + + Set indexesReferencedBySemanticText = extractIndexesReferencingInferenceEndpoints( + state.getMetadata(), + Set.of(request.getInferenceEndpointId()) + ); + + masterListener.onResponse( + new DeleteInferenceEndpointAction.Response( + false, + pipelines, + indexesReferencedBySemanticText, + buildErrorString(request.getInferenceEndpointId(), pipelines, indexesReferencedBySemanticText) + ) + ); + } + + private static String endpointIsReferencedInPipelinesOrIndexes(final ClusterState state, final String inferenceEndpointId) { + + var pipelines = endpointIsReferencedInPipelines(state, inferenceEndpointId); + var indexes = endpointIsReferencedInIndex(state, inferenceEndpointId); + + if (pipelines.isEmpty() == false || indexes.isEmpty() == false) { + return buildErrorString(inferenceEndpointId, pipelines, indexes); } - IngestMetadata ingestMetadata = metadata.custom(IngestMetadata.TYPE); - if (ingestMetadata == null) { - logger.debug("No ingest metadata found in cluster state while attempting to delete inference endpoint"); - } else { - Set modelIdsReferencedByPipelines = InferenceProcessorInfoExtractor.getModelIdsFromInferenceProcessors(ingestMetadata); - if (modelIdsReferencedByPipelines.contains(inferenceEndpointId)) { - listener.onFailure( - new ElasticsearchStatusException( - "Inference endpoint " - + inferenceEndpointId - + " is referenced by pipelines and cannot be deleted. " - + "Use `force` to delete it anyway, or use `dry_run` to list the pipelines that reference it.", - RestStatus.CONFLICT - ) - ); - return true; - } + return null; + } + + private static String buildErrorString(String inferenceEndpointId, Set pipelines, Set indexes) { + StringBuilder errorString = new StringBuilder(); + + if (pipelines.isEmpty() == false) { + errorString.append("Inference endpoint ") + .append(inferenceEndpointId) + .append(" is referenced by pipelines: ") + .append(pipelines) + .append(". ") + .append("Ensure that no pipelines are using this inference endpoint, ") + .append("or use force to ignore this warning and delete the inference endpoint."); } - return false; + + if (indexes.isEmpty() == false) { + errorString.append(" Inference endpoint ") + .append(inferenceEndpointId) + .append(" is being used in the mapping for indexes: ") + .append(indexes) + .append(". ") + .append("Ensure that no index mappings are using this inference endpoint, ") + .append("or use force to ignore this warning and delete the inference endpoint."); + } + + return errorString.toString(); + } + + private static Set endpointIsReferencedInIndex(final ClusterState state, final String inferenceEndpointId) { + Set indexes = extractIndexesReferencingInferenceEndpoints(state.getMetadata(), Set.of(inferenceEndpointId)); + return indexes; + } + + private static Set endpointIsReferencedInPipelines(final ClusterState state, final String inferenceEndpointId) { + Set modelIdsReferencedByPipelines = InferenceProcessorInfoExtractor.pipelineIdsForResource( + state, + Set.of(inferenceEndpointId) + ); + return modelIdsReferencedByPipelines; } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/amazonbedrock/AmazonBedrockActionCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/amazonbedrock/AmazonBedrockActionCreator.java new file mode 100644 index 0000000000000..5f9fc532e33b2 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/amazonbedrock/AmazonBedrockActionCreator.java @@ -0,0 +1,56 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.action.amazonbedrock; + +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.external.http.sender.AmazonBedrockChatCompletionRequestManager; +import org.elasticsearch.xpack.inference.external.http.sender.AmazonBedrockEmbeddingsRequestManager; +import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.services.ServiceComponents; +import org.elasticsearch.xpack.inference.services.amazonbedrock.completion.AmazonBedrockChatCompletionModel; +import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsModel; + +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage; + +public class AmazonBedrockActionCreator implements AmazonBedrockActionVisitor { + private final Sender sender; + private final ServiceComponents serviceComponents; + private final TimeValue timeout; + + public AmazonBedrockActionCreator(Sender sender, ServiceComponents serviceComponents, @Nullable TimeValue timeout) { + this.sender = Objects.requireNonNull(sender); + this.serviceComponents = Objects.requireNonNull(serviceComponents); + this.timeout = timeout; + } + + @Override + public ExecutableAction create(AmazonBedrockEmbeddingsModel embeddingsModel, Map taskSettings) { + var overriddenModel = AmazonBedrockEmbeddingsModel.of(embeddingsModel, taskSettings); + var requestManager = new AmazonBedrockEmbeddingsRequestManager( + overriddenModel, + serviceComponents.truncator(), + serviceComponents.threadPool(), + timeout + ); + var errorMessage = constructFailedToSendRequestMessage(null, "Amazon Bedrock embeddings"); + return new AmazonBedrockEmbeddingsAction(sender, requestManager, errorMessage); + } + + @Override + public ExecutableAction create(AmazonBedrockChatCompletionModel completionModel, Map taskSettings) { + var overriddenModel = AmazonBedrockChatCompletionModel.of(completionModel, taskSettings); + var requestManager = new AmazonBedrockChatCompletionRequestManager(overriddenModel, serviceComponents.threadPool(), timeout); + var errorMessage = constructFailedToSendRequestMessage(null, "Amazon Bedrock completion"); + return new AmazonBedrockChatCompletionAction(sender, requestManager, errorMessage); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/amazonbedrock/AmazonBedrockActionVisitor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/amazonbedrock/AmazonBedrockActionVisitor.java new file mode 100644 index 0000000000000..b540d030eb3f7 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/amazonbedrock/AmazonBedrockActionVisitor.java @@ -0,0 +1,20 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.action.amazonbedrock; + +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.services.amazonbedrock.completion.AmazonBedrockChatCompletionModel; +import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsModel; + +import java.util.Map; + +public interface AmazonBedrockActionVisitor { + ExecutableAction create(AmazonBedrockEmbeddingsModel embeddingsModel, Map taskSettings); + + ExecutableAction create(AmazonBedrockChatCompletionModel completionModel, Map taskSettings); +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/amazonbedrock/AmazonBedrockChatCompletionAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/amazonbedrock/AmazonBedrockChatCompletionAction.java new file mode 100644 index 0000000000000..9d3c39d3ac4d9 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/amazonbedrock/AmazonBedrockChatCompletionAction.java @@ -0,0 +1,47 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.action.amazonbedrock; + +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; +import org.elasticsearch.xpack.inference.external.http.sender.RequestManager; +import org.elasticsearch.xpack.inference.external.http.sender.Sender; + +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.external.action.ActionUtils.createInternalServerError; +import static org.elasticsearch.xpack.inference.external.action.ActionUtils.wrapFailuresInElasticsearchException; + +public class AmazonBedrockChatCompletionAction implements ExecutableAction { + private final Sender sender; + private final RequestManager requestManager; + private final String errorMessage; + + public AmazonBedrockChatCompletionAction(Sender sender, RequestManager requestManager, String errorMessage) { + this.sender = Objects.requireNonNull(sender); + this.requestManager = Objects.requireNonNull(requestManager); + this.errorMessage = Objects.requireNonNull(errorMessage); + } + + @Override + public void execute(InferenceInputs inferenceInputs, TimeValue timeout, ActionListener listener) { + try { + ActionListener wrappedListener = wrapFailuresInElasticsearchException(errorMessage, listener); + + sender.send(requestManager, inferenceInputs, timeout, wrappedListener); + } catch (ElasticsearchException e) { + listener.onFailure(e); + } catch (Exception e) { + listener.onFailure(createInternalServerError(e, errorMessage)); + } + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/amazonbedrock/AmazonBedrockEmbeddingsAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/amazonbedrock/AmazonBedrockEmbeddingsAction.java new file mode 100644 index 0000000000000..3f8be0c3cccbe --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/amazonbedrock/AmazonBedrockEmbeddingsAction.java @@ -0,0 +1,48 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.action.amazonbedrock; + +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; +import org.elasticsearch.xpack.inference.external.http.sender.RequestManager; +import org.elasticsearch.xpack.inference.external.http.sender.Sender; + +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.external.action.ActionUtils.createInternalServerError; +import static org.elasticsearch.xpack.inference.external.action.ActionUtils.wrapFailuresInElasticsearchException; + +public class AmazonBedrockEmbeddingsAction implements ExecutableAction { + + private final Sender sender; + private final RequestManager requestManager; + private final String errorMessage; + + public AmazonBedrockEmbeddingsAction(Sender sender, RequestManager requestManager, String errorMessage) { + this.sender = Objects.requireNonNull(sender); + this.requestManager = Objects.requireNonNull(requestManager); + this.errorMessage = Objects.requireNonNull(errorMessage); + } + + @Override + public void execute(InferenceInputs inferenceInputs, TimeValue timeout, ActionListener listener) { + try { + ActionListener wrappedListener = wrapFailuresInElasticsearchException(errorMessage, listener); + + sender.send(requestManager, inferenceInputs, timeout, wrappedListener); + } catch (ElasticsearchException e) { + listener.onFailure(e); + } catch (Exception e) { + listener.onFailure(createInternalServerError(e, errorMessage)); + } + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockBaseClient.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockBaseClient.java new file mode 100644 index 0000000000000..f9e403582a0ec --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockBaseClient.java @@ -0,0 +1,37 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.amazonbedrock; + +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockModel; + +import java.time.Clock; +import java.util.Objects; + +public abstract class AmazonBedrockBaseClient implements AmazonBedrockClient { + protected final Integer modelKeysAndRegionHashcode; + protected Clock clock = Clock.systemUTC(); + + protected AmazonBedrockBaseClient(AmazonBedrockModel model, @Nullable TimeValue timeout) { + Objects.requireNonNull(model); + this.modelKeysAndRegionHashcode = getModelKeysAndRegionHashcode(model, timeout); + } + + public static Integer getModelKeysAndRegionHashcode(AmazonBedrockModel model, @Nullable TimeValue timeout) { + var secretSettings = model.getSecretSettings(); + var serviceSettings = model.getServiceSettings(); + return Objects.hash(secretSettings.accessKey, secretSettings.secretKey, serviceSettings.region(), timeout); + } + + public final void setClock(Clock clock) { + this.clock = clock; + } + + abstract void close(); +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockChatCompletionExecutor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockChatCompletionExecutor.java new file mode 100644 index 0000000000000..a4e0c399517c1 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockChatCompletionExecutor.java @@ -0,0 +1,43 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.amazonbedrock; + +import org.apache.logging.log4j.Logger; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockChatCompletionRequest; +import org.elasticsearch.xpack.inference.external.response.amazonbedrock.AmazonBedrockResponseHandler; +import org.elasticsearch.xpack.inference.external.response.amazonbedrock.completion.AmazonBedrockChatCompletionResponseListener; + +import java.util.function.Supplier; + +public class AmazonBedrockChatCompletionExecutor extends AmazonBedrockExecutor { + private final AmazonBedrockChatCompletionRequest chatCompletionRequest; + + protected AmazonBedrockChatCompletionExecutor( + AmazonBedrockChatCompletionRequest request, + AmazonBedrockResponseHandler responseHandler, + Logger logger, + Supplier hasRequestCompletedFunction, + ActionListener inferenceResultsListener, + AmazonBedrockClientCache clientCache + ) { + super(request, responseHandler, logger, hasRequestCompletedFunction, inferenceResultsListener, clientCache); + this.chatCompletionRequest = request; + } + + @Override + protected void executeClientRequest(AmazonBedrockBaseClient awsBedrockClient) { + var chatCompletionResponseListener = new AmazonBedrockChatCompletionResponseListener( + chatCompletionRequest, + responseHandler, + inferenceResultsListener + ); + chatCompletionRequest.executeChatCompletionRequest(awsBedrockClient, chatCompletionResponseListener); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockClient.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockClient.java new file mode 100644 index 0000000000000..812e76129c420 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockClient.java @@ -0,0 +1,29 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.amazonbedrock; + +import com.amazonaws.services.bedrockruntime.model.ConverseRequest; +import com.amazonaws.services.bedrockruntime.model.ConverseResult; +import com.amazonaws.services.bedrockruntime.model.InvokeModelRequest; +import com.amazonaws.services.bedrockruntime.model.InvokeModelResult; + +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.action.ActionListener; + +import java.time.Instant; + +public interface AmazonBedrockClient { + void converse(ConverseRequest converseRequest, ActionListener responseListener) throws ElasticsearchException; + + void invokeModel(InvokeModelRequest invokeModelRequest, ActionListener responseListener) + throws ElasticsearchException; + + boolean isExpired(Instant currentTimestampMs); + + void resetExpiration(); +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockClientCache.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockClientCache.java new file mode 100644 index 0000000000000..e6bb99620b581 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockClientCache.java @@ -0,0 +1,19 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.amazonbedrock; + +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockModel; + +import java.io.Closeable; +import java.io.IOException; + +public interface AmazonBedrockClientCache extends Closeable { + AmazonBedrockBaseClient getOrCreateClient(AmazonBedrockModel model, @Nullable TimeValue timeout) throws IOException; +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockEmbeddingsExecutor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockEmbeddingsExecutor.java new file mode 100644 index 0000000000000..6da3f86e0909a --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockEmbeddingsExecutor.java @@ -0,0 +1,44 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.amazonbedrock; + +import org.apache.logging.log4j.Logger; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.xpack.inference.external.request.amazonbedrock.embeddings.AmazonBedrockEmbeddingsRequest; +import org.elasticsearch.xpack.inference.external.response.amazonbedrock.AmazonBedrockResponseHandler; +import org.elasticsearch.xpack.inference.external.response.amazonbedrock.embeddings.AmazonBedrockEmbeddingsResponseListener; + +import java.util.function.Supplier; + +public class AmazonBedrockEmbeddingsExecutor extends AmazonBedrockExecutor { + + private final AmazonBedrockEmbeddingsRequest embeddingsRequest; + + protected AmazonBedrockEmbeddingsExecutor( + AmazonBedrockEmbeddingsRequest request, + AmazonBedrockResponseHandler responseHandler, + Logger logger, + Supplier hasRequestCompletedFunction, + ActionListener inferenceResultsListener, + AmazonBedrockClientCache clientCache + ) { + super(request, responseHandler, logger, hasRequestCompletedFunction, inferenceResultsListener, clientCache); + this.embeddingsRequest = request; + } + + @Override + protected void executeClientRequest(AmazonBedrockBaseClient awsBedrockClient) { + var embeddingsResponseListener = new AmazonBedrockEmbeddingsResponseListener( + embeddingsRequest, + responseHandler, + inferenceResultsListener + ); + embeddingsRequest.executeEmbeddingsRequest(awsBedrockClient, embeddingsResponseListener); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockExecuteOnlyRequestSender.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockExecuteOnlyRequestSender.java new file mode 100644 index 0000000000000..a08acab655936 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockExecuteOnlyRequestSender.java @@ -0,0 +1,124 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.amazonbedrock; + +import org.apache.http.client.protocol.HttpClientContext; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.ExceptionsHelper; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.external.request.amazonbedrock.AmazonBedrockRequest; +import org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockChatCompletionRequest; +import org.elasticsearch.xpack.inference.external.request.amazonbedrock.embeddings.AmazonBedrockEmbeddingsRequest; +import org.elasticsearch.xpack.inference.external.response.amazonbedrock.AmazonBedrockResponseHandler; +import org.elasticsearch.xpack.inference.logging.ThrottlerManager; + +import java.io.IOException; +import java.util.Objects; +import java.util.function.Supplier; + +import static org.elasticsearch.core.Strings.format; + +/** + * The AWS SDK uses its own internal retrier and timeout values on the client + */ +public class AmazonBedrockExecuteOnlyRequestSender implements RequestSender { + + protected final AmazonBedrockClientCache clientCache; + private final ThrottlerManager throttleManager; + + public AmazonBedrockExecuteOnlyRequestSender(AmazonBedrockClientCache clientCache, ThrottlerManager throttlerManager) { + this.clientCache = Objects.requireNonNull(clientCache); + this.throttleManager = Objects.requireNonNull(throttlerManager); + } + + @Override + public void send( + Logger logger, + Request request, + HttpClientContext context, + Supplier hasRequestTimedOutFunction, + ResponseHandler responseHandler, + ActionListener listener + ) { + if (request instanceof AmazonBedrockRequest awsRequest && responseHandler instanceof AmazonBedrockResponseHandler awsResponse) { + try { + var executor = createExecutor(awsRequest, awsResponse, logger, hasRequestTimedOutFunction, listener); + + // the run method will call the listener to return the proper value + executor.run(); + return; + } catch (Exception e) { + logException(logger, request, e); + listener.onFailure(wrapWithElasticsearchException(e, request.getInferenceEntityId())); + } + } + + listener.onFailure(new ElasticsearchException("Amazon Bedrock request was not the correct type")); + } + + // allow this to be overridden for testing + protected AmazonBedrockExecutor createExecutor( + AmazonBedrockRequest awsRequest, + AmazonBedrockResponseHandler awsResponse, + Logger logger, + Supplier hasRequestTimedOutFunction, + ActionListener listener + ) { + switch (awsRequest.taskType()) { + case COMPLETION -> { + return new AmazonBedrockChatCompletionExecutor( + (AmazonBedrockChatCompletionRequest) awsRequest, + awsResponse, + logger, + hasRequestTimedOutFunction, + listener, + clientCache + ); + } + case TEXT_EMBEDDING -> { + return new AmazonBedrockEmbeddingsExecutor( + (AmazonBedrockEmbeddingsRequest) awsRequest, + awsResponse, + logger, + hasRequestTimedOutFunction, + listener, + clientCache + ); + } + default -> { + throw new UnsupportedOperationException("Unsupported task type [" + awsRequest.taskType() + "] for Amazon Bedrock request"); + } + } + } + + private void logException(Logger logger, Request request, Exception exception) { + var causeException = ExceptionsHelper.unwrapCause(exception); + + throttleManager.warn( + logger, + format("Failed while sending request from inference entity id [%s] of type [amazonbedrock]", request.getInferenceEntityId()), + causeException + ); + } + + private Exception wrapWithElasticsearchException(Exception e, String inferenceEntityId) { + return new ElasticsearchException( + format("Amazon Bedrock client failed to send request from inference entity id [%s]", inferenceEntityId), + e + ); + } + + public void shutdown() throws IOException { + this.clientCache.close(); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockExecutor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockExecutor.java new file mode 100644 index 0000000000000..fa220ee5d2831 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockExecutor.java @@ -0,0 +1,68 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.amazonbedrock; + +import org.apache.logging.log4j.Logger; +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.Strings; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.xpack.inference.external.request.amazonbedrock.AmazonBedrockRequest; +import org.elasticsearch.xpack.inference.external.response.amazonbedrock.AmazonBedrockResponseHandler; +import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockModel; + +import java.util.Objects; +import java.util.function.Supplier; + +public abstract class AmazonBedrockExecutor implements Runnable { + protected final AmazonBedrockModel baseModel; + protected final AmazonBedrockResponseHandler responseHandler; + protected final Logger logger; + protected final AmazonBedrockRequest request; + protected final Supplier hasRequestCompletedFunction; + protected final ActionListener inferenceResultsListener; + protected final AmazonBedrockClientCache clientCache; + + protected AmazonBedrockExecutor( + AmazonBedrockRequest request, + AmazonBedrockResponseHandler responseHandler, + Logger logger, + Supplier hasRequestCompletedFunction, + ActionListener inferenceResultsListener, + AmazonBedrockClientCache clientCache + ) { + this.request = Objects.requireNonNull(request); + this.responseHandler = Objects.requireNonNull(responseHandler); + this.logger = Objects.requireNonNull(logger); + this.hasRequestCompletedFunction = Objects.requireNonNull(hasRequestCompletedFunction); + this.inferenceResultsListener = Objects.requireNonNull(inferenceResultsListener); + this.clientCache = Objects.requireNonNull(clientCache); + this.baseModel = request.model(); + } + + @Override + public void run() { + if (hasRequestCompletedFunction.get()) { + // has already been run + return; + } + + var inferenceEntityId = baseModel.getInferenceEntityId(); + + try { + var awsBedrockClient = clientCache.getOrCreateClient(baseModel, request.timeout()); + executeClientRequest(awsBedrockClient); + } catch (Exception e) { + var errorMessage = Strings.format("Failed to send request from inference entity id [%s]", inferenceEntityId); + logger.warn(errorMessage, e); + inferenceResultsListener.onFailure(new ElasticsearchException(errorMessage, e)); + } + } + + protected abstract void executeClientRequest(AmazonBedrockBaseClient awsBedrockClient); +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockInferenceClient.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockInferenceClient.java new file mode 100644 index 0000000000000..c3d458925268c --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockInferenceClient.java @@ -0,0 +1,166 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.amazonbedrock; + +import com.amazonaws.ClientConfiguration; +import com.amazonaws.auth.AWSStaticCredentialsProvider; +import com.amazonaws.auth.BasicAWSCredentials; +import com.amazonaws.services.bedrockruntime.AmazonBedrockRuntimeAsync; +import com.amazonaws.services.bedrockruntime.AmazonBedrockRuntimeAsyncClientBuilder; +import com.amazonaws.services.bedrockruntime.model.AmazonBedrockRuntimeException; +import com.amazonaws.services.bedrockruntime.model.ConverseRequest; +import com.amazonaws.services.bedrockruntime.model.ConverseResult; +import com.amazonaws.services.bedrockruntime.model.InvokeModelRequest; +import com.amazonaws.services.bedrockruntime.model.InvokeModelResult; + +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.SpecialPermission; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.Strings; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.xpack.core.common.socket.SocketAccess; +import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockModel; + +import java.security.AccessController; +import java.security.PrivilegedExceptionAction; +import java.time.Duration; +import java.time.Instant; +import java.util.Objects; + +/** + * Not marking this as "final" so we can subclass it for mocking + */ +public class AmazonBedrockInferenceClient extends AmazonBedrockBaseClient { + + // package-private for testing + static final int CLIENT_CACHE_EXPIRY_MINUTES = 5; + private static final int DEFAULT_CLIENT_TIMEOUT_MS = 10000; + + private final AmazonBedrockRuntimeAsync internalClient; + private volatile Instant expiryTimestamp; + + public static AmazonBedrockBaseClient create(AmazonBedrockModel model, @Nullable TimeValue timeout) { + try { + return new AmazonBedrockInferenceClient(model, timeout); + } catch (Exception e) { + throw new ElasticsearchException("Failed to create Amazon Bedrock Client", e); + } + } + + protected AmazonBedrockInferenceClient(AmazonBedrockModel model, @Nullable TimeValue timeout) { + super(model, timeout); + this.internalClient = createAmazonBedrockClient(model, timeout); + setExpiryTimestamp(); + } + + @Override + public void converse(ConverseRequest converseRequest, ActionListener responseListener) throws ElasticsearchException { + try { + var responseFuture = internalClient.converseAsync(converseRequest); + responseListener.onResponse(responseFuture.get()); + } catch (AmazonBedrockRuntimeException amazonBedrockRuntimeException) { + responseListener.onFailure( + new ElasticsearchException( + Strings.format("AmazonBedrock converse failure: [%s]", amazonBedrockRuntimeException.getMessage()), + amazonBedrockRuntimeException + ) + ); + } catch (ElasticsearchException elasticsearchException) { + // just throw the exception if we have one + responseListener.onFailure(elasticsearchException); + } catch (Exception e) { + responseListener.onFailure(new ElasticsearchException("Amazon Bedrock client converse call failed", e)); + } + } + + @Override + public void invokeModel(InvokeModelRequest invokeModelRequest, ActionListener responseListener) + throws ElasticsearchException { + try { + var responseFuture = internalClient.invokeModelAsync(invokeModelRequest); + responseListener.onResponse(responseFuture.get()); + } catch (AmazonBedrockRuntimeException amazonBedrockRuntimeException) { + responseListener.onFailure( + new ElasticsearchException( + Strings.format("AmazonBedrock invoke model failure: [%s]", amazonBedrockRuntimeException.getMessage()), + amazonBedrockRuntimeException + ) + ); + } catch (ElasticsearchException elasticsearchException) { + // just throw the exception if we have one + responseListener.onFailure(elasticsearchException); + } catch (Exception e) { + responseListener.onFailure(new ElasticsearchException(e)); + } + } + + // allow this to be overridden for test mocks + protected AmazonBedrockRuntimeAsync createAmazonBedrockClient(AmazonBedrockModel model, @Nullable TimeValue timeout) { + var secretSettings = model.getSecretSettings(); + var credentials = new BasicAWSCredentials(secretSettings.accessKey.toString(), secretSettings.secretKey.toString()); + var credentialsProvider = new AWSStaticCredentialsProvider(credentials); + var clientConfig = timeout == null + ? new ClientConfiguration().withConnectionTimeout(DEFAULT_CLIENT_TIMEOUT_MS) + : new ClientConfiguration().withConnectionTimeout((int) timeout.millis()); + + var serviceSettings = model.getServiceSettings(); + + try { + SpecialPermission.check(); + AmazonBedrockRuntimeAsyncClientBuilder builder = AccessController.doPrivileged( + (PrivilegedExceptionAction) () -> AmazonBedrockRuntimeAsyncClientBuilder.standard() + .withCredentials(credentialsProvider) + .withRegion(serviceSettings.region()) + .withClientConfiguration(clientConfig) + ); + + return SocketAccess.doPrivileged(builder::build); + } catch (AmazonBedrockRuntimeException amazonBedrockRuntimeException) { + throw new ElasticsearchException( + Strings.format("failed to create AmazonBedrockRuntime client: [%s]", amazonBedrockRuntimeException.getMessage()), + amazonBedrockRuntimeException + ); + } catch (Exception e) { + throw new ElasticsearchException("failed to create AmazonBedrockRuntime client", e); + } + } + + private void setExpiryTimestamp() { + this.expiryTimestamp = clock.instant().plus(Duration.ofMinutes(CLIENT_CACHE_EXPIRY_MINUTES)); + } + + @Override + public boolean isExpired(Instant currentTimestampMs) { + Objects.requireNonNull(currentTimestampMs); + return currentTimestampMs.isAfter(expiryTimestamp); + } + + public void resetExpiration() { + setExpiryTimestamp(); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + AmazonBedrockInferenceClient that = (AmazonBedrockInferenceClient) o; + return Objects.equals(modelKeysAndRegionHashcode, that.modelKeysAndRegionHashcode); + } + + @Override + public int hashCode() { + return this.modelKeysAndRegionHashcode; + } + + // make this package-private so only the cache can close it + @Override + void close() { + internalClient.shutdown(); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockInferenceClientCache.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockInferenceClientCache.java new file mode 100644 index 0000000000000..e245365c214af --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockInferenceClientCache.java @@ -0,0 +1,137 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.amazonbedrock; + +import com.amazonaws.http.IdleConnectionReaper; + +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockModel; + +import java.io.IOException; +import java.time.Clock; +import java.util.ArrayList; +import java.util.Map; +import java.util.Objects; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.locks.ReentrantReadWriteLock; +import java.util.function.BiFunction; + +public final class AmazonBedrockInferenceClientCache implements AmazonBedrockClientCache { + + private final BiFunction creator; + private final Map clientsCache = new ConcurrentHashMap<>(); + private final ReentrantReadWriteLock cacheLock = new ReentrantReadWriteLock(); + + // not final for testing + private Clock clock; + + public AmazonBedrockInferenceClientCache( + BiFunction creator, + @Nullable Clock clock + ) { + this.creator = Objects.requireNonNull(creator); + this.clock = Objects.requireNonNullElse(clock, Clock.systemUTC()); + } + + public AmazonBedrockBaseClient getOrCreateClient(AmazonBedrockModel model, @Nullable TimeValue timeout) { + var returnClient = internalGetOrCreateClient(model, timeout); + flushExpiredClients(); + return returnClient; + } + + private AmazonBedrockBaseClient internalGetOrCreateClient(AmazonBedrockModel model, @Nullable TimeValue timeout) { + final Integer modelHash = AmazonBedrockInferenceClient.getModelKeysAndRegionHashcode(model, timeout); + cacheLock.readLock().lock(); + try { + return clientsCache.computeIfAbsent(modelHash, hashKey -> { + final AmazonBedrockBaseClient builtClient = creator.apply(model, timeout); + builtClient.setClock(clock); + builtClient.resetExpiration(); + return builtClient; + }); + } finally { + cacheLock.readLock().unlock(); + } + } + + private void flushExpiredClients() { + var currentTimestampMs = clock.instant(); + var expiredClients = new ArrayList>(); + + cacheLock.readLock().lock(); + try { + for (final Map.Entry client : clientsCache.entrySet()) { + if (client.getValue().isExpired(currentTimestampMs)) { + expiredClients.add(client); + } + } + + if (expiredClients.isEmpty()) { + return; + } + + cacheLock.readLock().unlock(); + cacheLock.writeLock().lock(); + try { + for (final Map.Entry client : expiredClients) { + var removed = clientsCache.remove(client.getKey()); + if (removed != null) { + removed.close(); + } + } + } finally { + cacheLock.readLock().lock(); + cacheLock.writeLock().unlock(); + } + } finally { + cacheLock.readLock().unlock(); + } + } + + @Override + public void close() throws IOException { + releaseCachedClients(); + } + + private void releaseCachedClients() { + // as we're closing and flushing all of these - we'll use a write lock + // across the whole operation to ensure this stays in sync + cacheLock.writeLock().lock(); + try { + // ensure all the clients are closed before we clear + for (final AmazonBedrockBaseClient client : clientsCache.values()) { + client.close(); + } + + // clear previously cached clients, they will be build lazily + clientsCache.clear(); + } finally { + cacheLock.writeLock().unlock(); + } + + // shutdown IdleConnectionReaper background thread + // it will be restarted on new client usage + IdleConnectionReaper.shutdown(); + } + + // used for testing + int clientCount() { + cacheLock.readLock().lock(); + try { + return clientsCache.size(); + } finally { + cacheLock.readLock().unlock(); + } + } + + // used for testing + void setClock(Clock newClock) { + this.clock = Objects.requireNonNull(newClock); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockRequestSender.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockRequestSender.java new file mode 100644 index 0000000000000..e23b0274ede26 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockRequestSender.java @@ -0,0 +1,126 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.amazonbedrock; + +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.inference.external.http.sender.AmazonBedrockRequestExecutorService; +import org.elasticsearch.xpack.inference.external.http.sender.AmazonBedrockRequestManager; +import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; +import org.elasticsearch.xpack.inference.external.http.sender.RequestExecutorServiceSettings; +import org.elasticsearch.xpack.inference.external.http.sender.RequestManager; +import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.services.ServiceComponents; + +import java.io.IOException; +import java.util.Objects; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; + +import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_POOL_NAME; + +public class AmazonBedrockRequestSender implements Sender { + + public static class Factory { + private final ServiceComponents serviceComponents; + private final ClusterService clusterService; + + public Factory(ServiceComponents serviceComponents, ClusterService clusterService) { + this.serviceComponents = Objects.requireNonNull(serviceComponents); + this.clusterService = Objects.requireNonNull(clusterService); + } + + public Sender createSender() { + var clientCache = new AmazonBedrockInferenceClientCache(AmazonBedrockInferenceClient::create, null); + return createSender(new AmazonBedrockExecuteOnlyRequestSender(clientCache, serviceComponents.throttlerManager())); + } + + Sender createSender(AmazonBedrockExecuteOnlyRequestSender requestSender) { + var sender = new AmazonBedrockRequestSender( + serviceComponents.threadPool(), + clusterService, + serviceComponents.settings(), + Objects.requireNonNull(requestSender) + ); + // ensure this is started + sender.start(); + return sender; + } + } + + private static final TimeValue START_COMPLETED_WAIT_TIME = TimeValue.timeValueSeconds(5); + + private final ThreadPool threadPool; + private final AmazonBedrockRequestExecutorService executorService; + private final AtomicBoolean started = new AtomicBoolean(false); + private final CountDownLatch startCompleted = new CountDownLatch(1); + + protected AmazonBedrockRequestSender( + ThreadPool threadPool, + ClusterService clusterService, + Settings settings, + AmazonBedrockExecuteOnlyRequestSender requestSender + ) { + this.threadPool = Objects.requireNonNull(threadPool); + executorService = new AmazonBedrockRequestExecutorService( + threadPool, + startCompleted, + new RequestExecutorServiceSettings(settings, clusterService), + requestSender + ); + } + + @Override + public void start() { + if (started.compareAndSet(false, true)) { + // The manager must be started before the executor service. That way we guarantee that the http client + // is ready prior to the service attempting to use the http client to send a request + threadPool.executor(UTILITY_THREAD_POOL_NAME).execute(executorService::start); + waitForStartToComplete(); + } + } + + private void waitForStartToComplete() { + try { + if (startCompleted.await(START_COMPLETED_WAIT_TIME.getSeconds(), TimeUnit.SECONDS) == false) { + throw new IllegalStateException("Amazon Bedrock sender startup did not complete in time"); + } + } catch (InterruptedException e) { + throw new IllegalStateException("Amazon Bedrock sender interrupted while waiting for startup to complete"); + } + } + + @Override + public void send( + RequestManager requestCreator, + InferenceInputs inferenceInputs, + TimeValue timeout, + ActionListener listener + ) { + assert started.get() : "Amazon Bedrock request sender: call start() before sending a request"; + waitForStartToComplete(); + + if (requestCreator instanceof AmazonBedrockRequestManager amazonBedrockRequestManager) { + executorService.execute(amazonBedrockRequestManager, inferenceInputs, timeout, listener); + return; + } + + listener.onFailure(new ElasticsearchException("Amazon Bedrock request sender did not receive a valid request request manager")); + } + + @Override + public void close() throws IOException { + executorService.shutdown(); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AmazonBedrockChatCompletionRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AmazonBedrockChatCompletionRequestManager.java new file mode 100644 index 0000000000000..8642a19b26a7d --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AmazonBedrockChatCompletionRequestManager.java @@ -0,0 +1,65 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.http.sender; + +import org.apache.http.client.protocol.HttpClientContext; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.Strings; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; +import org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockChatCompletionEntityFactory; +import org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockChatCompletionRequest; +import org.elasticsearch.xpack.inference.external.response.amazonbedrock.completion.AmazonBedrockChatCompletionResponseHandler; +import org.elasticsearch.xpack.inference.services.amazonbedrock.completion.AmazonBedrockChatCompletionModel; + +import java.util.List; +import java.util.function.Supplier; + +public class AmazonBedrockChatCompletionRequestManager extends AmazonBedrockRequestManager { + private static final Logger logger = LogManager.getLogger(AmazonBedrockChatCompletionRequestManager.class); + private final AmazonBedrockChatCompletionModel model; + + public AmazonBedrockChatCompletionRequestManager( + AmazonBedrockChatCompletionModel model, + ThreadPool threadPool, + @Nullable TimeValue timeout + ) { + super(model, threadPool, timeout); + this.model = model; + } + + @Override + public void execute( + InferenceInputs inferenceInputs, + RequestSender requestSender, + Supplier hasRequestCompletedFunction, + ActionListener listener + ) { + List docsInput = DocumentsOnlyInput.of(inferenceInputs).getInputs(); + var requestEntity = AmazonBedrockChatCompletionEntityFactory.createEntity(model, docsInput); + var request = new AmazonBedrockChatCompletionRequest(model, requestEntity, timeout); + var responseHandler = new AmazonBedrockChatCompletionResponseHandler(); + + try { + requestSender.send(logger, request, HttpClientContext.create(), hasRequestCompletedFunction, responseHandler, listener); + } catch (Exception e) { + var errorMessage = Strings.format( + "Failed to send [completion] request from inference entity id [%s]", + request.getInferenceEntityId() + ); + logger.warn(errorMessage, e); + listener.onFailure(new ElasticsearchException(errorMessage, e)); + } + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AmazonBedrockEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AmazonBedrockEmbeddingsRequestManager.java new file mode 100644 index 0000000000000..2f94cdf342938 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AmazonBedrockEmbeddingsRequestManager.java @@ -0,0 +1,74 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.http.sender; + +import org.apache.http.client.protocol.HttpClientContext; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.Strings; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.inference.common.Truncator; +import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; +import org.elasticsearch.xpack.inference.external.request.amazonbedrock.embeddings.AmazonBedrockEmbeddingsEntityFactory; +import org.elasticsearch.xpack.inference.external.request.amazonbedrock.embeddings.AmazonBedrockEmbeddingsRequest; +import org.elasticsearch.xpack.inference.external.response.amazonbedrock.embeddings.AmazonBedrockEmbeddingsResponseHandler; +import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsModel; + +import java.util.List; +import java.util.Objects; +import java.util.function.Supplier; + +import static org.elasticsearch.xpack.inference.common.Truncator.truncate; + +public class AmazonBedrockEmbeddingsRequestManager extends AmazonBedrockRequestManager { + private static final Logger logger = LogManager.getLogger(AmazonBedrockEmbeddingsRequestManager.class); + + private final AmazonBedrockEmbeddingsModel embeddingsModel; + private final Truncator truncator; + + public AmazonBedrockEmbeddingsRequestManager( + AmazonBedrockEmbeddingsModel model, + Truncator truncator, + ThreadPool threadPool, + @Nullable TimeValue timeout + ) { + super(model, threadPool, timeout); + this.embeddingsModel = model; + this.truncator = Objects.requireNonNull(truncator); + } + + @Override + public void execute( + InferenceInputs inferenceInputs, + RequestSender requestSender, + Supplier hasRequestCompletedFunction, + ActionListener listener + ) { + List docsInput = DocumentsOnlyInput.of(inferenceInputs).getInputs(); + var serviceSettings = embeddingsModel.getServiceSettings(); + var truncatedInput = truncate(docsInput, serviceSettings.maxInputTokens()); + var requestEntity = AmazonBedrockEmbeddingsEntityFactory.createEntity(embeddingsModel, truncatedInput); + var responseHandler = new AmazonBedrockEmbeddingsResponseHandler(); + var request = new AmazonBedrockEmbeddingsRequest(truncator, truncatedInput, embeddingsModel, requestEntity, timeout); + try { + requestSender.send(logger, request, HttpClientContext.create(), hasRequestCompletedFunction, responseHandler, listener); + } catch (Exception e) { + var errorMessage = Strings.format( + "Failed to send [text_embedding] request from inference entity id [%s]", + request.getInferenceEntityId() + ); + logger.warn(errorMessage, e); + listener.onFailure(new ElasticsearchException(errorMessage, e)); + } + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AmazonBedrockRequestExecutorService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AmazonBedrockRequestExecutorService.java new file mode 100644 index 0000000000000..8b4672d45c250 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AmazonBedrockRequestExecutorService.java @@ -0,0 +1,42 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.http.sender; + +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.inference.external.amazonbedrock.AmazonBedrockExecuteOnlyRequestSender; + +import java.io.IOException; +import java.util.concurrent.CountDownLatch; + +/** + * Allows this to have a public interface for Amazon Bedrock support + */ +public class AmazonBedrockRequestExecutorService extends RequestExecutorService { + + private final AmazonBedrockExecuteOnlyRequestSender requestSender; + + public AmazonBedrockRequestExecutorService( + ThreadPool threadPool, + CountDownLatch startupLatch, + RequestExecutorServiceSettings settings, + AmazonBedrockExecuteOnlyRequestSender requestSender + ) { + super(threadPool, startupLatch, settings, requestSender); + this.requestSender = requestSender; + } + + @Override + public void shutdown() { + super.shutdown(); + try { + requestSender.shutdown(); + } catch (IOException e) { + // swallow the exception + } + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AmazonBedrockRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AmazonBedrockRequestManager.java new file mode 100644 index 0000000000000..f75343b038368 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AmazonBedrockRequestManager.java @@ -0,0 +1,54 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.http.sender; + +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockModel; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +import java.util.Objects; + +public abstract class AmazonBedrockRequestManager implements RequestManager { + + protected final ThreadPool threadPool; + protected final TimeValue timeout; + private final AmazonBedrockModel baseModel; + + protected AmazonBedrockRequestManager(AmazonBedrockModel baseModel, ThreadPool threadPool, @Nullable TimeValue timeout) { + this.baseModel = Objects.requireNonNull(baseModel); + this.threadPool = Objects.requireNonNull(threadPool); + this.timeout = timeout; + } + + @Override + public String inferenceEntityId() { + return baseModel.getInferenceEntityId(); + } + + @Override + public RateLimitSettings rateLimitSettings() { + return baseModel.rateLimitSettings(); + } + + record RateLimitGrouping(int keyHash) { + public static AmazonBedrockRequestManager.RateLimitGrouping of(AmazonBedrockModel model) { + Objects.requireNonNull(model); + + var awsSecretSettings = model.getSecretSettings(); + + return new RateLimitGrouping(Objects.hash(awsSecretSettings.accessKey, awsSecretSettings.secretKey)); + } + } + + @Override + public Object rateLimitGrouping() { + return RateLimitGrouping.of(this.baseModel); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AnthropicCompletionRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AnthropicCompletionRequestManager.java index 7dd1a66db13e7..7c527bbd2ee98 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AnthropicCompletionRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AnthropicCompletionRequestManager.java @@ -10,7 +10,6 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.elasticsearch.action.ActionListener; -import org.elasticsearch.core.Nullable; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.inference.external.anthropic.AnthropicResponseHandler; @@ -43,13 +42,13 @@ private AnthropicCompletionRequestManager(AnthropicChatCompletionModel model, Th @Override public void execute( - @Nullable String query, - List input, + InferenceInputs inferenceInputs, RequestSender requestSender, Supplier hasRequestCompletedFunction, ActionListener listener ) { - AnthropicChatCompletionRequest request = new AnthropicChatCompletionRequest(input, model); + List docsInput = DocumentsOnlyInput.of(inferenceInputs).getInputs(); + AnthropicChatCompletionRequest request = new AnthropicChatCompletionRequest(docsInput, model); execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureAiStudioChatCompletionRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureAiStudioChatCompletionRequestManager.java index e295cf5cc43dd..c5e5a5251f7db 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureAiStudioChatCompletionRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureAiStudioChatCompletionRequestManager.java @@ -37,13 +37,13 @@ public AzureAiStudioChatCompletionRequestManager(AzureAiStudioChatCompletionMode @Override public void execute( - String query, - List input, + InferenceInputs inferenceInputs, RequestSender requestSender, Supplier hasRequestCompletedFunction, ActionListener listener ) { - AzureAiStudioChatCompletionRequest request = new AzureAiStudioChatCompletionRequest(model, input); + List docsInput = DocumentsOnlyInput.of(inferenceInputs).getInputs(); + AzureAiStudioChatCompletionRequest request = new AzureAiStudioChatCompletionRequest(model, docsInput); execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureAiStudioEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureAiStudioEmbeddingsRequestManager.java index f0f87402fb3a5..c610a7f31f7ba 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureAiStudioEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureAiStudioEmbeddingsRequestManager.java @@ -41,13 +41,13 @@ public AzureAiStudioEmbeddingsRequestManager(AzureAiStudioEmbeddingsModel model, @Override public void execute( - String query, - List input, + InferenceInputs inferenceInputs, RequestSender requestSender, Supplier hasRequestCompletedFunction, ActionListener listener ) { - var truncatedInput = truncate(input, model.getServiceSettings().maxInputTokens()); + List docsInput = DocumentsOnlyInput.of(inferenceInputs).getInputs(); + var truncatedInput = truncate(docsInput, model.getServiceSettings().maxInputTokens()); AzureAiStudioEmbeddingsRequest request = new AzureAiStudioEmbeddingsRequest(truncator, truncatedInput, model); execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureOpenAiCompletionRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureOpenAiCompletionRequestManager.java index 5206d6c2c23cc..8c9b848f78e3c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureOpenAiCompletionRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureOpenAiCompletionRequestManager.java @@ -10,7 +10,6 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.elasticsearch.action.ActionListener; -import org.elasticsearch.core.Nullable; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.inference.external.azureopenai.AzureOpenAiResponseHandler; @@ -43,13 +42,13 @@ public AzureOpenAiCompletionRequestManager(AzureOpenAiCompletionModel model, Thr @Override public void execute( - @Nullable String query, - List input, + InferenceInputs inferenceInputs, RequestSender requestSender, Supplier hasRequestCompletedFunction, ActionListener listener ) { - AzureOpenAiCompletionRequest request = new AzureOpenAiCompletionRequest(input, model); + List docsInput = DocumentsOnlyInput.of(inferenceInputs).getInputs(); + AzureOpenAiCompletionRequest request = new AzureOpenAiCompletionRequest(docsInput, model); execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureOpenAiEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureOpenAiEmbeddingsRequestManager.java index e0fcee30e5af3..8d4162858b36f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureOpenAiEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureOpenAiEmbeddingsRequestManager.java @@ -55,13 +55,14 @@ public AzureOpenAiEmbeddingsRequestManager(AzureOpenAiEmbeddingsModel model, Tru @Override public void execute( - String query, - List input, + InferenceInputs inferenceInputs, RequestSender requestSender, Supplier hasRequestCompletedFunction, ActionListener listener ) { - var truncatedInput = truncate(input, model.getServiceSettings().maxInputTokens()); + List docsInput = DocumentsOnlyInput.of(inferenceInputs).getInputs(); + var truncatedInput = truncate(docsInput, model.getServiceSettings().maxInputTokens()); + AzureOpenAiEmbeddingsRequest request = new AzureOpenAiEmbeddingsRequest(truncator, truncatedInput, model); execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereCompletionRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereCompletionRequestManager.java index 8a4b0e45b93fa..423093a14a9f0 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereCompletionRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereCompletionRequestManager.java @@ -46,13 +46,13 @@ private CohereCompletionRequestManager(CohereCompletionModel model, ThreadPool t @Override public void execute( - String query, - List input, + InferenceInputs inferenceInputs, RequestSender requestSender, Supplier hasRequestCompletedFunction, ActionListener listener ) { - CohereCompletionRequest request = new CohereCompletionRequest(input, model); + List docsInput = DocumentsOnlyInput.of(inferenceInputs).getInputs(); + CohereCompletionRequest request = new CohereCompletionRequest(docsInput, model); execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereEmbeddingsRequestManager.java index a51910f1d0a67..402f91a0838dc 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereEmbeddingsRequestManager.java @@ -44,13 +44,13 @@ private CohereEmbeddingsRequestManager(CohereEmbeddingsModel model, ThreadPool t @Override public void execute( - String query, - List input, + InferenceInputs inferenceInputs, RequestSender requestSender, Supplier hasRequestCompletedFunction, ActionListener listener ) { - CohereEmbeddingsRequest request = new CohereEmbeddingsRequest(input, model); + List docsInput = DocumentsOnlyInput.of(inferenceInputs).getInputs(); + CohereEmbeddingsRequest request = new CohereEmbeddingsRequest(docsInput, model); execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereRerankRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereRerankRequestManager.java index 1351eec406569..9d565e7124b03 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereRerankRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereRerankRequestManager.java @@ -19,7 +19,6 @@ import org.elasticsearch.xpack.inference.external.response.cohere.CohereRankedResponseEntity; import org.elasticsearch.xpack.inference.services.cohere.rerank.CohereRerankModel; -import java.util.List; import java.util.Objects; import java.util.function.Supplier; @@ -44,13 +43,13 @@ private CohereRerankRequestManager(CohereRerankModel model, ThreadPool threadPoo @Override public void execute( - String query, - List input, + InferenceInputs inferenceInputs, RequestSender requestSender, Supplier hasRequestCompletedFunction, ActionListener listener ) { - CohereRerankRequest request = new CohereRerankRequest(query, input, model); + var rerankInput = QueryAndDocsInputs.of(inferenceInputs); + CohereRerankRequest request = new CohereRerankRequest(rerankInput.getQuery(), rerankInput.getChunks(), model); execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/DocumentsOnlyInput.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/DocumentsOnlyInput.java index a11be003585fd..a32e2018117f8 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/DocumentsOnlyInput.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/DocumentsOnlyInput.java @@ -12,7 +12,15 @@ public class DocumentsOnlyInput extends InferenceInputs { - List input; + public static DocumentsOnlyInput of(InferenceInputs inferenceInputs) { + if (inferenceInputs instanceof DocumentsOnlyInput == false) { + throw createUnsupportedTypeException(inferenceInputs); + } + + return (DocumentsOnlyInput) inferenceInputs; + } + + private final List input; public DocumentsOnlyInput(List chunks) { super(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleAiStudioCompletionRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleAiStudioCompletionRequestManager.java index 2b191b046477b..426102f7f2376 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleAiStudioCompletionRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleAiStudioCompletionRequestManager.java @@ -42,13 +42,13 @@ public GoogleAiStudioCompletionRequestManager(GoogleAiStudioCompletionModel mode @Override public void execute( - String query, - List input, + InferenceInputs inferenceInputs, RequestSender requestSender, Supplier hasRequestCompletedFunction, ActionListener listener ) { - GoogleAiStudioCompletionRequest request = new GoogleAiStudioCompletionRequest(input, model); + List docsInput = DocumentsOnlyInput.of(inferenceInputs).getInputs(); + GoogleAiStudioCompletionRequest request = new GoogleAiStudioCompletionRequest(docsInput, model); execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleAiStudioEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleAiStudioEmbeddingsRequestManager.java index 6436e0231ab48..c7f87fb1cbf7f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleAiStudioEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleAiStudioEmbeddingsRequestManager.java @@ -48,13 +48,13 @@ public GoogleAiStudioEmbeddingsRequestManager(GoogleAiStudioEmbeddingsModel mode @Override public void execute( - String query, - List input, + InferenceInputs inferenceInputs, RequestSender requestSender, Supplier hasRequestCompletedFunction, ActionListener listener ) { - var truncatedInput = truncate(input, model.getServiceSettings().maxInputTokens()); + List docsInput = DocumentsOnlyInput.of(inferenceInputs).getInputs(); + var truncatedInput = truncate(docsInput, model.getServiceSettings().maxInputTokens()); GoogleAiStudioEmbeddingsRequest request = new GoogleAiStudioEmbeddingsRequest(truncator, truncatedInput, model); execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleVertexAiEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleVertexAiEmbeddingsRequestManager.java index c682da9a1694a..94f44c64b04da 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleVertexAiEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleVertexAiEmbeddingsRequestManager.java @@ -56,13 +56,13 @@ public static RateLimitGrouping of(GoogleVertexAiEmbeddingsModel model) { @Override public void execute( - String query, - List input, + InferenceInputs inferenceInputs, RequestSender requestSender, Supplier hasRequestCompletedFunction, ActionListener listener ) { - var truncatedInput = truncate(input, model.getServiceSettings().maxInputTokens()); + List docsInput = DocumentsOnlyInput.of(inferenceInputs).getInputs(); + var truncatedInput = truncate(docsInput, model.getServiceSettings().maxInputTokens()); var request = new GoogleVertexAiEmbeddingsRequest(truncator, truncatedInput, model); execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleVertexAiRerankRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleVertexAiRerankRequestManager.java index ab49ecc7ab9f9..e74f0049fffb0 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleVertexAiRerankRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleVertexAiRerankRequestManager.java @@ -19,7 +19,6 @@ import org.elasticsearch.xpack.inference.external.response.googlevertexai.GoogleVertexAiRerankResponseEntity; import org.elasticsearch.xpack.inference.services.googlevertexai.rerank.GoogleVertexAiRerankModel; -import java.util.List; import java.util.Objects; import java.util.function.Supplier; @@ -57,13 +56,13 @@ public static RateLimitGrouping of(GoogleVertexAiRerankModel model) { @Override public void execute( - String query, - List input, + InferenceInputs inferenceInputs, RequestSender requestSender, Supplier hasRequestCompletedFunction, ActionListener listener ) { - GoogleVertexAiRerankRequest request = new GoogleVertexAiRerankRequest(query, input, model); + var rerankInput = QueryAndDocsInputs.of(inferenceInputs); + GoogleVertexAiRerankRequest request = new GoogleVertexAiRerankRequest(rerankInput.getQuery(), rerankInput.getChunks(), model); execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/HuggingFaceRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/HuggingFaceRequestManager.java index 6c8fc446d5243..a33eb724551f1 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/HuggingFaceRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/HuggingFaceRequestManager.java @@ -55,13 +55,13 @@ private HuggingFaceRequestManager(HuggingFaceModel model, ResponseHandler respon @Override public void execute( - String query, - List input, + InferenceInputs inferenceInputs, RequestSender requestSender, Supplier hasRequestCompletedFunction, ActionListener listener ) { - var truncatedInput = truncate(input, model.getTokenLimit()); + List docsInput = DocumentsOnlyInput.of(inferenceInputs).getInputs(); + var truncatedInput = truncate(docsInput, model.getTokenLimit()); var request = new HuggingFaceInferenceRequest(truncator, truncatedInput, model); execute(new ExecutableInferenceRequest(requestSender, logger, request, responseHandler, hasRequestCompletedFunction, listener)); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceInputs.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceInputs.java index d7e07e734ce80..dd241857ef0c4 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceInputs.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceInputs.java @@ -7,4 +7,10 @@ package org.elasticsearch.xpack.inference.external.http.sender; -public abstract class InferenceInputs {} +import org.elasticsearch.common.Strings; + +public abstract class InferenceInputs { + public static IllegalArgumentException createUnsupportedTypeException(InferenceInputs inferenceInputs) { + return new IllegalArgumentException(Strings.format("Unsupported inference inputs type: [%s]", inferenceInputs.getClass())); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceRequest.java index 6199a75a41a7d..52be5d8be2b6f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceRequest.java @@ -10,7 +10,6 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.inference.InferenceServiceResults; -import java.util.List; import java.util.function.Supplier; /** @@ -24,14 +23,9 @@ public interface InferenceRequest { RequestManager getRequestManager(); /** - * Returns the query associated with this request. Used for Rerank tasks. + * Returns the inputs associated with the request. */ - String getQuery(); - - /** - * Returns the text input associated with this request. - */ - List getInput(); + InferenceInputs getInferenceInputs(); /** * Returns the listener to notify of the results. diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/MistralEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/MistralEmbeddingsRequestManager.java index 1807712a31ac5..d550749cc2348 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/MistralEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/MistralEmbeddingsRequestManager.java @@ -51,13 +51,13 @@ public MistralEmbeddingsRequestManager(MistralEmbeddingsModel model, Truncator t @Override public void execute( - String query, - List input, + InferenceInputs inferenceInputs, RequestSender requestSender, Supplier hasRequestCompletedFunction, ActionListener listener ) { - var truncatedInput = truncate(input, model.getServiceSettings().maxInputTokens()); + List docsInput = DocumentsOnlyInput.of(inferenceInputs).getInputs(); + var truncatedInput = truncate(docsInput, model.getServiceSettings().maxInputTokens()); MistralEmbeddingsRequest request = new MistralEmbeddingsRequest(truncator, truncatedInput, model); execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiCompletionRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiCompletionRequestManager.java index 7bc09fd76736b..65f25c0baf8dc 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiCompletionRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiCompletionRequestManager.java @@ -10,7 +10,6 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.elasticsearch.action.ActionListener; -import org.elasticsearch.core.Nullable; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; @@ -43,13 +42,13 @@ private OpenAiCompletionRequestManager(OpenAiChatCompletionModel model, ThreadPo @Override public void execute( - @Nullable String query, - List input, + InferenceInputs inferenceInputs, RequestSender requestSender, Supplier hasRequestCompletedFunction, ActionListener listener ) { - OpenAiChatCompletionRequest request = new OpenAiChatCompletionRequest(input, model); + List docsInput = DocumentsOnlyInput.of(inferenceInputs).getInputs(); + OpenAiChatCompletionRequest request = new OpenAiChatCompletionRequest(docsInput, model); execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiEmbeddingsRequestManager.java index 41f91d2b89ee5..5c164f2eb9644 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiEmbeddingsRequestManager.java @@ -55,13 +55,13 @@ private OpenAiEmbeddingsRequestManager(OpenAiEmbeddingsModel model, Truncator tr @Override public void execute( - String query, - List input, + InferenceInputs inferenceInputs, RequestSender requestSender, Supplier hasRequestCompletedFunction, ActionListener listener ) { - var truncatedInput = truncate(input, model.getServiceSettings().maxInputTokens()); + List docsInput = DocumentsOnlyInput.of(inferenceInputs).getInputs(); + var truncatedInput = truncate(docsInput, model.getServiceSettings().maxInputTokens()); OpenAiEmbeddingsRequest request = new OpenAiEmbeddingsRequest(truncator, truncatedInput, model); execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/QueryAndDocsInputs.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/QueryAndDocsInputs.java index 4d24598d67831..0d5f98c180ba9 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/QueryAndDocsInputs.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/QueryAndDocsInputs.java @@ -12,7 +12,15 @@ public class QueryAndDocsInputs extends InferenceInputs { - String query; + public static QueryAndDocsInputs of(InferenceInputs inferenceInputs) { + if (inferenceInputs instanceof QueryAndDocsInputs == false) { + throw createUnsupportedTypeException(inferenceInputs); + } + + return (QueryAndDocsInputs) inferenceInputs; + } + + private final String query; public String getQuery() { return query; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorService.java index 38d47aec68eb6..ad1324d0a315f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorService.java @@ -413,7 +413,7 @@ private TimeValue executeEnqueuedTaskInternal() { assert shouldExecuteImmediately(reserveRes) : "Reserving request tokens required a sleep when it should not have"; task.getRequestManager() - .execute(task.getQuery(), task.getInput(), requestSender, task.getRequestCompletedFunction(), task.getListener()); + .execute(task.getInferenceInputs(), requestSender, task.getRequestCompletedFunction(), task.getListener()); return EXECUTED_A_TASK; } @@ -423,7 +423,7 @@ private static boolean shouldExecuteTask(RejectableTask task) { private static boolean isNoopRequest(InferenceRequest inferenceRequest) { return inferenceRequest.getRequestManager() == null - || inferenceRequest.getInput() == null + || inferenceRequest.getInferenceInputs() == null || inferenceRequest.getListener() == null; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestManager.java index 79ef1b56ad231..853d6fdcb2473 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestManager.java @@ -8,12 +8,10 @@ package org.elasticsearch.xpack.inference.external.http.sender; import org.elasticsearch.action.ActionListener; -import org.elasticsearch.core.Nullable; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; import org.elasticsearch.xpack.inference.external.ratelimit.RateLimitable; -import java.util.List; import java.util.function.Supplier; /** @@ -21,8 +19,7 @@ */ public interface RequestManager extends RateLimitable { void execute( - @Nullable String query, - List input, + InferenceInputs inferenceInputs, RequestSender requestSender, Supplier hasRequestCompletedFunction, ActionListener listener diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestTask.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestTask.java index 7a5f482412289..9ccb93a0858ae 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestTask.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/RequestTask.java @@ -16,7 +16,6 @@ import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.threadpool.ThreadPool; -import java.util.List; import java.util.Objects; import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.Supplier; @@ -27,8 +26,7 @@ class RequestTask implements RejectableTask { private final AtomicBoolean finished = new AtomicBoolean(); private final RequestManager requestCreator; - private final String query; - private final List input; + private final InferenceInputs inferenceInputs; private final ActionListener listener; RequestTask( @@ -40,16 +38,7 @@ class RequestTask implements RejectableTask { ) { this.requestCreator = Objects.requireNonNull(requestCreator); this.listener = getListener(Objects.requireNonNull(listener), timeout, Objects.requireNonNull(threadPool)); - - if (inferenceInputs instanceof QueryAndDocsInputs) { - this.query = ((QueryAndDocsInputs) inferenceInputs).getQuery(); - this.input = ((QueryAndDocsInputs) inferenceInputs).getChunks(); - } else if (inferenceInputs instanceof DocumentsOnlyInput) { - this.query = null; - this.input = ((DocumentsOnlyInput) inferenceInputs).getInputs(); - } else { - throw new IllegalArgumentException("Unsupported inference inputs type: " + inferenceInputs.getClass()); - } + this.inferenceInputs = Objects.requireNonNull(inferenceInputs); } private ActionListener getListener( @@ -91,13 +80,8 @@ public Supplier getRequestCompletedFunction() { } @Override - public List getInput() { - return input; - } - - @Override - public String getQuery() { - return query; + public InferenceInputs getInferenceInputs() { + return inferenceInputs; } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/AmazonBedrockJsonBuilder.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/AmazonBedrockJsonBuilder.java new file mode 100644 index 0000000000000..829e899beba5e --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/AmazonBedrockJsonBuilder.java @@ -0,0 +1,30 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.amazonbedrock; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.xcontent.ToXContent; + +import java.io.IOException; + +import static org.elasticsearch.xcontent.XContentFactory.jsonBuilder; + +public class AmazonBedrockJsonBuilder { + + private final ToXContent jsonWriter; + + public AmazonBedrockJsonBuilder(ToXContent jsonWriter) { + this.jsonWriter = jsonWriter; + } + + public String getStringContent() throws IOException { + try (var builder = jsonBuilder()) { + return Strings.toString(jsonWriter.toXContent(builder, null)); + } + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/AmazonBedrockJsonWriter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/AmazonBedrockJsonWriter.java new file mode 100644 index 0000000000000..83ebcb4563a8c --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/AmazonBedrockJsonWriter.java @@ -0,0 +1,20 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.amazonbedrock; + +import com.fasterxml.jackson.core.JsonGenerator; + +import java.io.IOException; + +/** + * This is needed as the input for the Amazon Bedrock SDK does not like + * the formatting of XContent JSON output + */ +public interface AmazonBedrockJsonWriter { + JsonGenerator writeJson(JsonGenerator generator) throws IOException; +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/AmazonBedrockRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/AmazonBedrockRequest.java new file mode 100644 index 0000000000000..e356212ed07fb --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/AmazonBedrockRequest.java @@ -0,0 +1,85 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.amazonbedrock; + +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xpack.inference.external.amazonbedrock.AmazonBedrockBaseClient; +import org.elasticsearch.xpack.inference.external.request.HttpRequest; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockModel; + +import java.net.URI; + +public abstract class AmazonBedrockRequest implements Request { + + protected final AmazonBedrockModel amazonBedrockModel; + protected final String inferenceId; + protected final TimeValue timeout; + + protected AmazonBedrockRequest(AmazonBedrockModel model, @Nullable TimeValue timeout) { + this.amazonBedrockModel = model; + this.inferenceId = model.getInferenceEntityId(); + this.timeout = timeout; + } + + protected abstract void executeRequest(AmazonBedrockBaseClient client); + + public AmazonBedrockModel model() { + return amazonBedrockModel; + } + + /** + * Amazon Bedrock uses the AWS SDK, and will not create its own Http Request + * But, this is needed for the ExecutableInferenceRequest to get the inferenceEntityId + * @return NoOp request + */ + @Override + public final HttpRequest createHttpRequest() { + return new HttpRequest(new NoOpHttpRequest(), inferenceId); + } + + /** + * Amazon Bedrock uses the AWS SDK, and will not create its own URI + * @return null + */ + @Override + public final URI getURI() { + throw new UnsupportedOperationException(); + } + + /** + * Should be overridden for text embeddings requests + * @return null + */ + @Override + public Request truncate() { + return this; + } + + /** + * Should be overridden for text embeddings requests + * @return boolean[0] + */ + @Override + public boolean[] getTruncationInfo() { + return new boolean[0]; + } + + @Override + public String getInferenceEntityId() { + return amazonBedrockModel.getInferenceEntityId(); + } + + public TimeValue timeout() { + return timeout; + } + + public abstract TaskType taskType(); +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/NoOpHttpRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/NoOpHttpRequest.java new file mode 100644 index 0000000000000..7087bb03bca5e --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/NoOpHttpRequest.java @@ -0,0 +1,20 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.amazonbedrock; + +import org.apache.http.client.methods.HttpRequestBase; + +/** + * Needed for compatibility with RequestSender + */ +public class NoOpHttpRequest extends HttpRequestBase { + @Override + public String getMethod() { + return "NOOP"; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockAI21LabsCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockAI21LabsCompletionRequestEntity.java new file mode 100644 index 0000000000000..6e2f2f6702005 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockAI21LabsCompletionRequestEntity.java @@ -0,0 +1,63 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion; + +import com.amazonaws.services.bedrockruntime.model.ConverseRequest; +import com.amazonaws.services.bedrockruntime.model.InferenceConfiguration; + +import org.elasticsearch.core.Nullable; + +import java.util.List; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseUtils.getConverseMessageList; + +public record AmazonBedrockAI21LabsCompletionRequestEntity( + List messages, + @Nullable Double temperature, + @Nullable Double topP, + @Nullable Integer maxTokenCount +) implements AmazonBedrockConverseRequestEntity { + + public AmazonBedrockAI21LabsCompletionRequestEntity { + Objects.requireNonNull(messages); + } + + @Override + public ConverseRequest addMessages(ConverseRequest request) { + return request.withMessages(getConverseMessageList(messages)); + } + + @Override + public ConverseRequest addInferenceConfig(ConverseRequest request) { + if (temperature == null && topP == null && maxTokenCount == null) { + return request; + } + + InferenceConfiguration inferenceConfig = new InferenceConfiguration(); + + if (temperature != null) { + inferenceConfig = inferenceConfig.withTemperature(temperature.floatValue()); + } + + if (topP != null) { + inferenceConfig = inferenceConfig.withTopP(topP.floatValue()); + } + + if (maxTokenCount != null) { + inferenceConfig = inferenceConfig.withMaxTokens(maxTokenCount); + } + + return request.withInferenceConfig(inferenceConfig); + } + + @Override + public ConverseRequest addAdditionalModelFields(ConverseRequest request) { + return request; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockAnthropicCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockAnthropicCompletionRequestEntity.java new file mode 100644 index 0000000000000..a8b0032af09c5 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockAnthropicCompletionRequestEntity.java @@ -0,0 +1,70 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion; + +import com.amazonaws.services.bedrockruntime.model.ConverseRequest; +import com.amazonaws.services.bedrockruntime.model.InferenceConfiguration; + +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.Strings; + +import java.util.List; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseUtils.getConverseMessageList; + +public record AmazonBedrockAnthropicCompletionRequestEntity( + List messages, + @Nullable Double temperature, + @Nullable Double topP, + @Nullable Double topK, + @Nullable Integer maxTokenCount +) implements AmazonBedrockConverseRequestEntity { + + public AmazonBedrockAnthropicCompletionRequestEntity { + Objects.requireNonNull(messages); + } + + @Override + public ConverseRequest addMessages(ConverseRequest request) { + return request.withMessages(getConverseMessageList(messages)); + } + + @Override + public ConverseRequest addInferenceConfig(ConverseRequest request) { + if (temperature == null && topP == null && maxTokenCount == null) { + return request; + } + + InferenceConfiguration inferenceConfig = new InferenceConfiguration(); + + if (temperature != null) { + inferenceConfig = inferenceConfig.withTemperature(temperature.floatValue()); + } + + if (topP != null) { + inferenceConfig = inferenceConfig.withTopP(topP.floatValue()); + } + + if (maxTokenCount != null) { + inferenceConfig = inferenceConfig.withMaxTokens(maxTokenCount); + } + + return request.withInferenceConfig(inferenceConfig); + } + + @Override + public ConverseRequest addAdditionalModelFields(ConverseRequest request) { + if (topK == null) { + return request; + } + + String topKField = Strings.format("{\"top_k\":%f}", topK.floatValue()); + return request.withAdditionalModelResponseFieldPaths(topKField); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockChatCompletionEntityFactory.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockChatCompletionEntityFactory.java new file mode 100644 index 0000000000000..f86d2229d42ad --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockChatCompletionEntityFactory.java @@ -0,0 +1,78 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion; + +import org.elasticsearch.xpack.inference.services.amazonbedrock.completion.AmazonBedrockChatCompletionModel; + +import java.util.List; +import java.util.Objects; + +public final class AmazonBedrockChatCompletionEntityFactory { + public static AmazonBedrockConverseRequestEntity createEntity(AmazonBedrockChatCompletionModel model, List messages) { + Objects.requireNonNull(model); + Objects.requireNonNull(messages); + var serviceSettings = model.getServiceSettings(); + var taskSettings = model.getTaskSettings(); + switch (serviceSettings.provider()) { + case AI21LABS -> { + return new AmazonBedrockAI21LabsCompletionRequestEntity( + messages, + taskSettings.temperature(), + taskSettings.topP(), + taskSettings.maxNewTokens() + ); + } + case AMAZONTITAN -> { + return new AmazonBedrockTitanCompletionRequestEntity( + messages, + taskSettings.temperature(), + taskSettings.topP(), + taskSettings.maxNewTokens() + ); + } + case ANTHROPIC -> { + return new AmazonBedrockAnthropicCompletionRequestEntity( + messages, + taskSettings.temperature(), + taskSettings.topP(), + taskSettings.topK(), + taskSettings.maxNewTokens() + ); + } + case COHERE -> { + return new AmazonBedrockCohereCompletionRequestEntity( + messages, + taskSettings.temperature(), + taskSettings.topP(), + taskSettings.topK(), + taskSettings.maxNewTokens() + ); + } + case META -> { + return new AmazonBedrockMetaCompletionRequestEntity( + messages, + taskSettings.temperature(), + taskSettings.topP(), + taskSettings.maxNewTokens() + ); + } + case MISTRAL -> { + return new AmazonBedrockMistralCompletionRequestEntity( + messages, + taskSettings.temperature(), + taskSettings.topP(), + taskSettings.topK(), + taskSettings.maxNewTokens() + ); + } + default -> { + return null; + } + } + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockChatCompletionRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockChatCompletionRequest.java new file mode 100644 index 0000000000000..f02f05f2d3b17 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockChatCompletionRequest.java @@ -0,0 +1,69 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion; + +import com.amazonaws.services.bedrockruntime.model.ConverseRequest; + +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xpack.core.common.socket.SocketAccess; +import org.elasticsearch.xpack.inference.external.amazonbedrock.AmazonBedrockBaseClient; +import org.elasticsearch.xpack.inference.external.request.amazonbedrock.AmazonBedrockRequest; +import org.elasticsearch.xpack.inference.external.response.amazonbedrock.completion.AmazonBedrockChatCompletionResponseListener; +import org.elasticsearch.xpack.inference.services.amazonbedrock.completion.AmazonBedrockChatCompletionModel; + +import java.io.IOException; +import java.util.Objects; + +public class AmazonBedrockChatCompletionRequest extends AmazonBedrockRequest { + public static final String USER_ROLE = "user"; + private final AmazonBedrockConverseRequestEntity requestEntity; + private AmazonBedrockChatCompletionResponseListener listener; + + public AmazonBedrockChatCompletionRequest( + AmazonBedrockChatCompletionModel model, + AmazonBedrockConverseRequestEntity requestEntity, + @Nullable TimeValue timeout + ) { + super(model, timeout); + this.requestEntity = Objects.requireNonNull(requestEntity); + } + + @Override + protected void executeRequest(AmazonBedrockBaseClient client) { + var converseRequest = getConverseRequest(); + + try { + SocketAccess.doPrivileged(() -> client.converse(converseRequest, listener)); + } catch (IOException e) { + listener.onFailure(new RuntimeException(e)); + } + } + + @Override + public TaskType taskType() { + return TaskType.COMPLETION; + } + + private ConverseRequest getConverseRequest() { + var converseRequest = new ConverseRequest().withModelId(amazonBedrockModel.model()); + converseRequest = requestEntity.addMessages(converseRequest); + converseRequest = requestEntity.addInferenceConfig(converseRequest); + converseRequest = requestEntity.addAdditionalModelFields(converseRequest); + return converseRequest; + } + + public void executeChatCompletionRequest( + AmazonBedrockBaseClient awsBedrockClient, + AmazonBedrockChatCompletionResponseListener chatCompletionResponseListener + ) { + this.listener = chatCompletionResponseListener; + this.executeRequest(awsBedrockClient); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockCohereCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockCohereCompletionRequestEntity.java new file mode 100644 index 0000000000000..17a264ef820ff --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockCohereCompletionRequestEntity.java @@ -0,0 +1,70 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion; + +import com.amazonaws.services.bedrockruntime.model.ConverseRequest; +import com.amazonaws.services.bedrockruntime.model.InferenceConfiguration; + +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.Strings; + +import java.util.List; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseUtils.getConverseMessageList; + +public record AmazonBedrockCohereCompletionRequestEntity( + List messages, + @Nullable Double temperature, + @Nullable Double topP, + @Nullable Double topK, + @Nullable Integer maxTokenCount +) implements AmazonBedrockConverseRequestEntity { + + public AmazonBedrockCohereCompletionRequestEntity { + Objects.requireNonNull(messages); + } + + @Override + public ConverseRequest addMessages(ConverseRequest request) { + return request.withMessages(getConverseMessageList(messages)); + } + + @Override + public ConverseRequest addInferenceConfig(ConverseRequest request) { + if (temperature == null && topP == null && maxTokenCount == null) { + return request; + } + + InferenceConfiguration inferenceConfig = new InferenceConfiguration(); + + if (temperature != null) { + inferenceConfig = inferenceConfig.withTemperature(temperature.floatValue()); + } + + if (topP != null) { + inferenceConfig = inferenceConfig.withTopP(topP.floatValue()); + } + + if (maxTokenCount != null) { + inferenceConfig = inferenceConfig.withMaxTokens(maxTokenCount); + } + + return request.withInferenceConfig(inferenceConfig); + } + + @Override + public ConverseRequest addAdditionalModelFields(ConverseRequest request) { + if (topK == null) { + return request; + } + + String topKField = Strings.format("{\"top_k\":%f}", topK.floatValue()); + return request.withAdditionalModelResponseFieldPaths(topKField); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockConverseRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockConverseRequestEntity.java new file mode 100644 index 0000000000000..fbd55e76e509b --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockConverseRequestEntity.java @@ -0,0 +1,18 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion; + +import com.amazonaws.services.bedrockruntime.model.ConverseRequest; + +public interface AmazonBedrockConverseRequestEntity { + ConverseRequest addMessages(ConverseRequest request); + + ConverseRequest addInferenceConfig(ConverseRequest request); + + ConverseRequest addAdditionalModelFields(ConverseRequest request); +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockConverseUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockConverseUtils.java new file mode 100644 index 0000000000000..2cfb56a94b319 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockConverseUtils.java @@ -0,0 +1,29 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion; + +import com.amazonaws.services.bedrockruntime.model.ContentBlock; +import com.amazonaws.services.bedrockruntime.model.Message; + +import java.util.ArrayList; +import java.util.List; + +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockChatCompletionRequest.USER_ROLE; + +public final class AmazonBedrockConverseUtils { + + public static List getConverseMessageList(List messages) { + List messageList = new ArrayList<>(); + for (String message : messages) { + var messageContent = new ContentBlock().withText(message); + var returnMessage = (new Message()).withRole(USER_ROLE).withContent(messageContent); + messageList.add(returnMessage); + } + return messageList; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockMetaCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockMetaCompletionRequestEntity.java new file mode 100644 index 0000000000000..cdabdd4cbebff --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockMetaCompletionRequestEntity.java @@ -0,0 +1,63 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion; + +import com.amazonaws.services.bedrockruntime.model.ConverseRequest; +import com.amazonaws.services.bedrockruntime.model.InferenceConfiguration; + +import org.elasticsearch.core.Nullable; + +import java.util.List; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseUtils.getConverseMessageList; + +public record AmazonBedrockMetaCompletionRequestEntity( + List messages, + @Nullable Double temperature, + @Nullable Double topP, + @Nullable Integer maxTokenCount +) implements AmazonBedrockConverseRequestEntity { + + public AmazonBedrockMetaCompletionRequestEntity { + Objects.requireNonNull(messages); + } + + @Override + public ConverseRequest addMessages(ConverseRequest request) { + return request.withMessages(getConverseMessageList(messages)); + } + + @Override + public ConverseRequest addInferenceConfig(ConverseRequest request) { + if (temperature == null && topP == null && maxTokenCount == null) { + return request; + } + + InferenceConfiguration inferenceConfig = new InferenceConfiguration(); + + if (temperature != null) { + inferenceConfig = inferenceConfig.withTemperature(temperature.floatValue()); + } + + if (topP != null) { + inferenceConfig = inferenceConfig.withTopP(topP.floatValue()); + } + + if (maxTokenCount != null) { + inferenceConfig = inferenceConfig.withMaxTokens(maxTokenCount); + } + + return request.withInferenceConfig(inferenceConfig); + } + + @Override + public ConverseRequest addAdditionalModelFields(ConverseRequest request) { + return request; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockMistralCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockMistralCompletionRequestEntity.java new file mode 100644 index 0000000000000..c68eaa1b81f54 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockMistralCompletionRequestEntity.java @@ -0,0 +1,70 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion; + +import com.amazonaws.services.bedrockruntime.model.ConverseRequest; +import com.amazonaws.services.bedrockruntime.model.InferenceConfiguration; + +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.Strings; + +import java.util.List; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseUtils.getConverseMessageList; + +public record AmazonBedrockMistralCompletionRequestEntity( + List messages, + @Nullable Double temperature, + @Nullable Double topP, + @Nullable Double topK, + @Nullable Integer maxTokenCount +) implements AmazonBedrockConverseRequestEntity { + + public AmazonBedrockMistralCompletionRequestEntity { + Objects.requireNonNull(messages); + } + + @Override + public ConverseRequest addMessages(ConverseRequest request) { + return request.withMessages(getConverseMessageList(messages)); + } + + @Override + public ConverseRequest addInferenceConfig(ConverseRequest request) { + if (temperature == null && topP == null && maxTokenCount == null) { + return request; + } + + InferenceConfiguration inferenceConfig = new InferenceConfiguration(); + + if (temperature != null) { + inferenceConfig = inferenceConfig.withTemperature(temperature.floatValue()); + } + + if (topP != null) { + inferenceConfig = inferenceConfig.withTopP(topP.floatValue()); + } + + if (maxTokenCount != null) { + inferenceConfig = inferenceConfig.withMaxTokens(maxTokenCount); + } + + return request.withInferenceConfig(inferenceConfig); + } + + @Override + public ConverseRequest addAdditionalModelFields(ConverseRequest request) { + if (topK == null) { + return request; + } + + String topKField = Strings.format("{\"top_k\":%f}", topK.floatValue()); + return request.withAdditionalModelResponseFieldPaths(topKField); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockTitanCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockTitanCompletionRequestEntity.java new file mode 100644 index 0000000000000..d56035b80e9ef --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockTitanCompletionRequestEntity.java @@ -0,0 +1,63 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion; + +import com.amazonaws.services.bedrockruntime.model.ConverseRequest; +import com.amazonaws.services.bedrockruntime.model.InferenceConfiguration; + +import org.elasticsearch.core.Nullable; + +import java.util.List; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseUtils.getConverseMessageList; + +public record AmazonBedrockTitanCompletionRequestEntity( + List messages, + @Nullable Double temperature, + @Nullable Double topP, + @Nullable Integer maxTokenCount +) implements AmazonBedrockConverseRequestEntity { + + public AmazonBedrockTitanCompletionRequestEntity { + Objects.requireNonNull(messages); + } + + @Override + public ConverseRequest addMessages(ConverseRequest request) { + return request.withMessages(getConverseMessageList(messages)); + } + + @Override + public ConverseRequest addInferenceConfig(ConverseRequest request) { + if (temperature == null && topP == null && maxTokenCount == null) { + return request; + } + + InferenceConfiguration inferenceConfig = new InferenceConfiguration(); + + if (temperature != null) { + inferenceConfig = inferenceConfig.withTemperature(temperature.floatValue()); + } + + if (topP != null) { + inferenceConfig = inferenceConfig.withTopP(topP.floatValue()); + } + + if (maxTokenCount != null) { + inferenceConfig = inferenceConfig.withMaxTokens(maxTokenCount); + } + + return request.withInferenceConfig(inferenceConfig); + } + + @Override + public ConverseRequest addAdditionalModelFields(ConverseRequest request) { + return request; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/embeddings/AmazonBedrockCohereEmbeddingsRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/embeddings/AmazonBedrockCohereEmbeddingsRequestEntity.java new file mode 100644 index 0000000000000..edca5bc1bdf9c --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/embeddings/AmazonBedrockCohereEmbeddingsRequestEntity.java @@ -0,0 +1,35 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.amazonbedrock.embeddings; + +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; + +public record AmazonBedrockCohereEmbeddingsRequestEntity(List input) implements ToXContentObject { + + private static final String TEXTS_FIELD = "texts"; + private static final String INPUT_TYPE_FIELD = "input_type"; + private static final String INPUT_TYPE_SEARCH_DOCUMENT = "search_document"; + + public AmazonBedrockCohereEmbeddingsRequestEntity { + Objects.requireNonNull(input); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(TEXTS_FIELD, input); + builder.field(INPUT_TYPE_FIELD, INPUT_TYPE_SEARCH_DOCUMENT); + builder.endObject(); + return builder; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/embeddings/AmazonBedrockEmbeddingsEntityFactory.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/embeddings/AmazonBedrockEmbeddingsEntityFactory.java new file mode 100644 index 0000000000000..a31b033507264 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/embeddings/AmazonBedrockEmbeddingsEntityFactory.java @@ -0,0 +1,45 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.amazonbedrock.embeddings; + +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.xcontent.ToXContent; +import org.elasticsearch.xpack.inference.common.Truncator; +import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsModel; + +import java.util.Objects; + +public final class AmazonBedrockEmbeddingsEntityFactory { + public static ToXContent createEntity(AmazonBedrockEmbeddingsModel model, Truncator.TruncationResult truncationResult) { + Objects.requireNonNull(model); + Objects.requireNonNull(truncationResult); + + var serviceSettings = model.getServiceSettings(); + + var truncatedInput = truncationResult.input(); + if (truncatedInput == null || truncatedInput.isEmpty()) { + throw new ElasticsearchException("[input] cannot be null or empty"); + } + + switch (serviceSettings.provider()) { + case AMAZONTITAN -> { + if (truncatedInput.size() > 1) { + throw new ElasticsearchException("[input] cannot contain more than one string"); + } + return new AmazonBedrockTitanEmbeddingsRequestEntity(truncatedInput.get(0)); + } + case COHERE -> { + return new AmazonBedrockCohereEmbeddingsRequestEntity(truncatedInput); + } + default -> { + return null; + } + } + + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/embeddings/AmazonBedrockEmbeddingsRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/embeddings/AmazonBedrockEmbeddingsRequest.java new file mode 100644 index 0000000000000..96d3b3a3cc057 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/embeddings/AmazonBedrockEmbeddingsRequest.java @@ -0,0 +1,99 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.amazonbedrock.embeddings; + +import com.amazonaws.services.bedrockruntime.model.InvokeModelRequest; +import com.amazonaws.services.bedrockruntime.model.InvokeModelResult; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xcontent.ToXContent; +import org.elasticsearch.xpack.core.common.socket.SocketAccess; +import org.elasticsearch.xpack.inference.common.Truncator; +import org.elasticsearch.xpack.inference.external.amazonbedrock.AmazonBedrockBaseClient; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.external.request.amazonbedrock.AmazonBedrockJsonBuilder; +import org.elasticsearch.xpack.inference.external.request.amazonbedrock.AmazonBedrockRequest; +import org.elasticsearch.xpack.inference.external.response.amazonbedrock.embeddings.AmazonBedrockEmbeddingsResponseListener; +import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockProvider; +import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsModel; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.Objects; + +public class AmazonBedrockEmbeddingsRequest extends AmazonBedrockRequest { + private final AmazonBedrockEmbeddingsModel embeddingsModel; + private final ToXContent requestEntity; + private final Truncator truncator; + private final Truncator.TruncationResult truncationResult; + private final AmazonBedrockProvider provider; + private ActionListener listener = null; + + public AmazonBedrockEmbeddingsRequest( + Truncator truncator, + Truncator.TruncationResult input, + AmazonBedrockEmbeddingsModel model, + ToXContent requestEntity, + @Nullable TimeValue timeout + ) { + super(model, timeout); + this.truncator = Objects.requireNonNull(truncator); + this.truncationResult = Objects.requireNonNull(input); + this.requestEntity = Objects.requireNonNull(requestEntity); + this.embeddingsModel = model; + this.provider = model.provider(); + } + + public AmazonBedrockProvider provider() { + return provider; + } + + @Override + protected void executeRequest(AmazonBedrockBaseClient client) { + try { + var jsonBuilder = new AmazonBedrockJsonBuilder(requestEntity); + var bodyAsString = jsonBuilder.getStringContent(); + + var charset = StandardCharsets.UTF_8; + var bodyBuffer = charset.encode(bodyAsString); + + var invokeModelRequest = new InvokeModelRequest().withModelId(embeddingsModel.model()).withBody(bodyBuffer); + + SocketAccess.doPrivileged(() -> client.invokeModel(invokeModelRequest, listener)); + } catch (IOException e) { + listener.onFailure(new RuntimeException(e)); + } + } + + @Override + public Request truncate() { + var truncatedInput = truncator.truncate(truncationResult.input()); + return new AmazonBedrockEmbeddingsRequest(truncator, truncatedInput, embeddingsModel, requestEntity, timeout); + } + + @Override + public boolean[] getTruncationInfo() { + return truncationResult.truncated().clone(); + } + + @Override + public TaskType taskType() { + return TaskType.TEXT_EMBEDDING; + } + + public void executeEmbeddingsRequest( + AmazonBedrockBaseClient awsBedrockClient, + AmazonBedrockEmbeddingsResponseListener embeddingsResponseListener + ) { + this.listener = embeddingsResponseListener; + this.executeRequest(awsBedrockClient); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/embeddings/AmazonBedrockTitanEmbeddingsRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/embeddings/AmazonBedrockTitanEmbeddingsRequestEntity.java new file mode 100644 index 0000000000000..f55edd0442913 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/embeddings/AmazonBedrockTitanEmbeddingsRequestEntity.java @@ -0,0 +1,31 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.amazonbedrock.embeddings; + +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.Objects; + +public record AmazonBedrockTitanEmbeddingsRequestEntity(String inputText) implements ToXContentObject { + + private static final String INPUT_TEXT_FIELD = "inputText"; + + public AmazonBedrockTitanEmbeddingsRequestEntity { + Objects.requireNonNull(inputText); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.field(INPUT_TEXT_FIELD, inputText); + builder.endObject(); + return builder; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/amazonbedrock/AmazonBedrockResponse.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/amazonbedrock/AmazonBedrockResponse.java new file mode 100644 index 0000000000000..54b05137acda3 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/amazonbedrock/AmazonBedrockResponse.java @@ -0,0 +1,15 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.response.amazonbedrock; + +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.xpack.inference.external.request.amazonbedrock.AmazonBedrockRequest; + +public abstract class AmazonBedrockResponse { + public abstract InferenceServiceResults accept(AmazonBedrockRequest request); +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/amazonbedrock/AmazonBedrockResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/amazonbedrock/AmazonBedrockResponseHandler.java new file mode 100644 index 0000000000000..9dc15ea667c1d --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/amazonbedrock/AmazonBedrockResponseHandler.java @@ -0,0 +1,23 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.response.amazonbedrock; + +import org.apache.logging.log4j.Logger; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; +import org.elasticsearch.xpack.inference.external.http.retry.RetryException; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.logging.ThrottlerManager; + +public abstract class AmazonBedrockResponseHandler implements ResponseHandler { + @Override + public final void validateResponse(ThrottlerManager throttlerManager, Logger logger, Request request, HttpResult result) + throws RetryException { + // do nothing as the AWS SDK will take care of validation for us + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/amazonbedrock/AmazonBedrockResponseListener.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/amazonbedrock/AmazonBedrockResponseListener.java new file mode 100644 index 0000000000000..ce4d6d1dea655 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/amazonbedrock/AmazonBedrockResponseListener.java @@ -0,0 +1,30 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.response.amazonbedrock; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.xpack.inference.external.request.amazonbedrock.AmazonBedrockRequest; + +import java.util.Objects; + +public class AmazonBedrockResponseListener { + protected final AmazonBedrockRequest request; + protected final ActionListener inferenceResultsListener; + protected final AmazonBedrockResponseHandler responseHandler; + + public AmazonBedrockResponseListener( + AmazonBedrockRequest request, + AmazonBedrockResponseHandler responseHandler, + ActionListener inferenceResultsListener + ) { + this.request = Objects.requireNonNull(request); + this.responseHandler = Objects.requireNonNull(responseHandler); + this.inferenceResultsListener = Objects.requireNonNull(inferenceResultsListener); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/amazonbedrock/completion/AmazonBedrockChatCompletionResponse.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/amazonbedrock/completion/AmazonBedrockChatCompletionResponse.java new file mode 100644 index 0000000000000..5b3872e2c416a --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/amazonbedrock/completion/AmazonBedrockChatCompletionResponse.java @@ -0,0 +1,49 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.response.amazonbedrock.completion; + +import com.amazonaws.services.bedrockruntime.model.ConverseResult; + +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; +import org.elasticsearch.xpack.inference.external.request.amazonbedrock.AmazonBedrockRequest; +import org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockChatCompletionRequest; +import org.elasticsearch.xpack.inference.external.response.amazonbedrock.AmazonBedrockResponse; + +import java.util.ArrayList; + +public class AmazonBedrockChatCompletionResponse extends AmazonBedrockResponse { + + private final ConverseResult result; + + public AmazonBedrockChatCompletionResponse(ConverseResult responseResult) { + this.result = responseResult; + } + + @Override + public InferenceServiceResults accept(AmazonBedrockRequest request) { + if (request instanceof AmazonBedrockChatCompletionRequest asChatCompletionRequest) { + return fromResponse(result); + } + + throw new ElasticsearchException("unexpected request type [" + request.getClass() + "]"); + } + + public static ChatCompletionResults fromResponse(ConverseResult response) { + var responseMessage = response.getOutput().getMessage(); + + var messageContents = responseMessage.getContent(); + var resultTexts = new ArrayList(); + for (var messageContent : messageContents) { + resultTexts.add(new ChatCompletionResults.Result(messageContent.getText())); + } + + return new ChatCompletionResults(resultTexts); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/amazonbedrock/completion/AmazonBedrockChatCompletionResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/amazonbedrock/completion/AmazonBedrockChatCompletionResponseHandler.java new file mode 100644 index 0000000000000..a24f54c50eef3 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/amazonbedrock/completion/AmazonBedrockChatCompletionResponseHandler.java @@ -0,0 +1,39 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.response.amazonbedrock.completion; + +import com.amazonaws.services.bedrockruntime.model.ConverseResult; + +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.http.retry.RetryException; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.external.request.amazonbedrock.AmazonBedrockRequest; +import org.elasticsearch.xpack.inference.external.response.amazonbedrock.AmazonBedrockResponseHandler; + +public class AmazonBedrockChatCompletionResponseHandler extends AmazonBedrockResponseHandler { + + private ConverseResult responseResult; + + public AmazonBedrockChatCompletionResponseHandler() {} + + @Override + public InferenceServiceResults parseResult(Request request, HttpResult result) throws RetryException { + var response = new AmazonBedrockChatCompletionResponse(responseResult); + return response.accept((AmazonBedrockRequest) request); + } + + @Override + public String getRequestType() { + return "Amazon Bedrock Chat Completion"; + } + + public void acceptChatCompletionResponseObject(ConverseResult response) { + this.responseResult = response; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/amazonbedrock/completion/AmazonBedrockChatCompletionResponseListener.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/amazonbedrock/completion/AmazonBedrockChatCompletionResponseListener.java new file mode 100644 index 0000000000000..be03ba84571eb --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/amazonbedrock/completion/AmazonBedrockChatCompletionResponseListener.java @@ -0,0 +1,40 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.response.amazonbedrock.completion; + +import com.amazonaws.services.bedrockruntime.model.ConverseResult; + +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockChatCompletionRequest; +import org.elasticsearch.xpack.inference.external.response.amazonbedrock.AmazonBedrockResponseHandler; +import org.elasticsearch.xpack.inference.external.response.amazonbedrock.AmazonBedrockResponseListener; + +public class AmazonBedrockChatCompletionResponseListener extends AmazonBedrockResponseListener implements ActionListener { + + public AmazonBedrockChatCompletionResponseListener( + AmazonBedrockChatCompletionRequest request, + AmazonBedrockResponseHandler responseHandler, + ActionListener inferenceResultsListener + ) { + super(request, responseHandler, inferenceResultsListener); + } + + @Override + public void onResponse(ConverseResult result) { + ((AmazonBedrockChatCompletionResponseHandler) responseHandler).acceptChatCompletionResponseObject(result); + inferenceResultsListener.onResponse(responseHandler.parseResult(request, null)); + } + + @Override + public void onFailure(Exception e) { + throw new ElasticsearchException(e); + } + +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/amazonbedrock/embeddings/AmazonBedrockEmbeddingsResponse.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/amazonbedrock/embeddings/AmazonBedrockEmbeddingsResponse.java new file mode 100644 index 0000000000000..83fa790acbe68 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/amazonbedrock/embeddings/AmazonBedrockEmbeddingsResponse.java @@ -0,0 +1,132 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.response.amazonbedrock.embeddings; + +import com.amazonaws.services.bedrockruntime.model.InvokeModelResult; + +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults; +import org.elasticsearch.xpack.inference.external.request.amazonbedrock.AmazonBedrockRequest; +import org.elasticsearch.xpack.inference.external.request.amazonbedrock.embeddings.AmazonBedrockEmbeddingsRequest; +import org.elasticsearch.xpack.inference.external.response.XContentUtils; +import org.elasticsearch.xpack.inference.external.response.amazonbedrock.AmazonBedrockResponse; +import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockProvider; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.List; + +import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.elasticsearch.common.xcontent.XContentParserUtils.parseList; +import static org.elasticsearch.xpack.inference.external.response.XContentUtils.positionParserAtTokenAfterField; + +public class AmazonBedrockEmbeddingsResponse extends AmazonBedrockResponse { + private static final String FAILED_TO_FIND_FIELD_TEMPLATE = "Failed to find required field [%s] in Amazon Bedrock embeddings response"; + private final InvokeModelResult result; + + public AmazonBedrockEmbeddingsResponse(InvokeModelResult invokeModelResult) { + this.result = invokeModelResult; + } + + @Override + public InferenceServiceResults accept(AmazonBedrockRequest request) { + if (request instanceof AmazonBedrockEmbeddingsRequest asEmbeddingsRequest) { + return fromResponse(result, asEmbeddingsRequest.provider()); + } + + throw new ElasticsearchException("unexpected request type [" + request.getClass() + "]"); + } + + public static InferenceTextEmbeddingFloatResults fromResponse(InvokeModelResult response, AmazonBedrockProvider provider) { + var charset = StandardCharsets.UTF_8; + var bodyText = String.valueOf(charset.decode(response.getBody())); + + var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE); + + try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, bodyText)) { + // move to the first token + jsonParser.nextToken(); + + XContentParser.Token token = jsonParser.currentToken(); + ensureExpectedToken(XContentParser.Token.START_OBJECT, token, jsonParser); + + var embeddingList = parseEmbeddings(jsonParser, provider); + + return new InferenceTextEmbeddingFloatResults(embeddingList); + } catch (IOException e) { + throw new ElasticsearchException(e); + } + } + + private static List parseEmbeddings( + XContentParser jsonParser, + AmazonBedrockProvider provider + ) throws IOException { + switch (provider) { + case AMAZONTITAN -> { + return parseTitanEmbeddings(jsonParser); + } + case COHERE -> { + return parseCohereEmbeddings(jsonParser); + } + default -> throw new IOException("Unsupported provider [" + provider + "]"); + } + } + + private static List parseTitanEmbeddings(XContentParser parser) + throws IOException { + /* + Titan response: + { + "embedding": [float, float, ...], + "inputTextTokenCount": int + } + */ + positionParserAtTokenAfterField(parser, "embedding", FAILED_TO_FIND_FIELD_TEMPLATE); + List embeddingValuesList = parseList(parser, XContentUtils::parseFloat); + var embeddingValues = InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding.of(embeddingValuesList); + return List.of(embeddingValues); + } + + private static List parseCohereEmbeddings(XContentParser parser) + throws IOException { + /* + Cohere response: + { + "embeddings": [ + [< array of 1024 floats >], + ... + ], + "id": string, + "response_type" : "embeddings_floats", + "texts": [string] + } + */ + positionParserAtTokenAfterField(parser, "embeddings", FAILED_TO_FIND_FIELD_TEMPLATE); + + List embeddingList = parseList( + parser, + AmazonBedrockEmbeddingsResponse::parseCohereEmbeddingsListItem + ); + + return embeddingList; + } + + private static InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding parseCohereEmbeddingsListItem(XContentParser parser) + throws IOException { + List embeddingValuesList = parseList(parser, XContentUtils::parseFloat); + return InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding.of(embeddingValuesList); + } + +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/amazonbedrock/embeddings/AmazonBedrockEmbeddingsResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/amazonbedrock/embeddings/AmazonBedrockEmbeddingsResponseHandler.java new file mode 100644 index 0000000000000..a3fb68ee23486 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/amazonbedrock/embeddings/AmazonBedrockEmbeddingsResponseHandler.java @@ -0,0 +1,37 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.response.amazonbedrock.embeddings; + +import com.amazonaws.services.bedrockruntime.model.InvokeModelResult; + +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.http.retry.RetryException; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.external.request.amazonbedrock.AmazonBedrockRequest; +import org.elasticsearch.xpack.inference.external.response.amazonbedrock.AmazonBedrockResponseHandler; + +public class AmazonBedrockEmbeddingsResponseHandler extends AmazonBedrockResponseHandler { + + private InvokeModelResult invokeModelResult; + + @Override + public InferenceServiceResults parseResult(Request request, HttpResult result) throws RetryException { + var responseParser = new AmazonBedrockEmbeddingsResponse(invokeModelResult); + return responseParser.accept((AmazonBedrockRequest) request); + } + + @Override + public String getRequestType() { + return "Amazon Bedrock Embeddings"; + } + + public void acceptEmbeddingsResult(InvokeModelResult result) { + this.invokeModelResult = result; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/amazonbedrock/embeddings/AmazonBedrockEmbeddingsResponseListener.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/amazonbedrock/embeddings/AmazonBedrockEmbeddingsResponseListener.java new file mode 100644 index 0000000000000..36519ae31ff60 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/amazonbedrock/embeddings/AmazonBedrockEmbeddingsResponseListener.java @@ -0,0 +1,38 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.response.amazonbedrock.embeddings; + +import com.amazonaws.services.bedrockruntime.model.InvokeModelResult; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.xpack.inference.external.request.amazonbedrock.embeddings.AmazonBedrockEmbeddingsRequest; +import org.elasticsearch.xpack.inference.external.response.amazonbedrock.AmazonBedrockResponseHandler; +import org.elasticsearch.xpack.inference.external.response.amazonbedrock.AmazonBedrockResponseListener; + +public class AmazonBedrockEmbeddingsResponseListener extends AmazonBedrockResponseListener implements ActionListener { + + public AmazonBedrockEmbeddingsResponseListener( + AmazonBedrockEmbeddingsRequest request, + AmazonBedrockResponseHandler responseHandler, + ActionListener inferenceResultsListener + ) { + super(request, responseHandler, inferenceResultsListener); + } + + @Override + public void onResponse(InvokeModelResult result) { + ((AmazonBedrockEmbeddingsResponseHandler) responseHandler).acceptEmbeddingsResult(result); + inferenceResultsListener.onResponse(responseHandler.parseResult(request, null)); + } + + @Override + public void onFailure(Exception e) { + inferenceResultsListener.onFailure(e); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereRankedResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereRankedResponseEntity.java index 7f71933676ee0..c5bb536833e89 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereRankedResponseEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/cohere/CohereRankedResponseEntity.java @@ -161,5 +161,5 @@ private static RankedDocsResults.RankedDoc parseRankedDocObject(XContentParser p private CohereRankedResponseEntity() {} - static String FAILED_TO_FIND_FIELD_TEMPLATE = "Failed to find required field [%s] in Cohere embeddings response"; + static String FAILED_TO_FIND_FIELD_TEMPLATE = "Failed to find required field [%s] in Cohere rerank response"; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java index 966cc029232b1..9f810b829bea9 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java @@ -23,6 +23,8 @@ import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults; import org.elasticsearch.xpack.core.inference.results.TextEmbedding; +import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsFeatureFlag; +import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings; import org.elasticsearch.xpack.inference.services.settings.ApiKeySecrets; import java.net.URI; @@ -37,6 +39,9 @@ import java.util.stream.Collectors; import static org.elasticsearch.core.Strings.format; +import static org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings.ENABLED; +import static org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings.MAX_NUMBER_OF_ALLOCATIONS; +import static org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings.MIN_NUMBER_OF_ALLOCATIONS; import static org.elasticsearch.xpack.inference.services.ServiceFields.SIMILARITY; public final class ServiceUtils { @@ -126,6 +131,20 @@ public static Object removeAsOneOfTypes( return null; } + public static AdaptiveAllocationsSettings removeAsAdaptiveAllocationsSettings(Map sourceMap, String key) { + if (AdaptiveAllocationsFeatureFlag.isEnabled() == false) { + return null; + } + Map settingsMap = ServiceUtils.removeFromMap(sourceMap, key); + return settingsMap == null + ? null + : new AdaptiveAllocationsSettings( + ServiceUtils.removeAsType(settingsMap, ENABLED.getPreferredName(), Boolean.class), + ServiceUtils.removeAsType(settingsMap, MIN_NUMBER_OF_ALLOCATIONS.getPreferredName(), Integer.class), + ServiceUtils.removeAsType(settingsMap, MAX_NUMBER_OF_ALLOCATIONS.getPreferredName(), Integer.class) + ); + } + @SuppressWarnings("unchecked") public static Map removeFromMap(Map sourceMap, String fieldName) { return (Map) sourceMap.remove(fieldName); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockConstants.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockConstants.java new file mode 100644 index 0000000000000..1755dac2ac13f --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockConstants.java @@ -0,0 +1,27 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.amazonbedrock; + +public class AmazonBedrockConstants { + public static final String ACCESS_KEY_FIELD = "access_key"; + public static final String SECRET_KEY_FIELD = "secret_key"; + public static final String REGION_FIELD = "region"; + public static final String MODEL_FIELD = "model"; + public static final String PROVIDER_FIELD = "provider"; + + public static final String TEMPERATURE_FIELD = "temperature"; + public static final String TOP_P_FIELD = "top_p"; + public static final String TOP_K_FIELD = "top_k"; + public static final String MAX_NEW_TOKENS_FIELD = "max_new_tokens"; + + public static final Double MIN_TEMPERATURE_TOP_P_TOP_K_VALUE = 0.0; + public static final Double MAX_TEMPERATURE_TOP_P_TOP_K_VALUE = 1.0; + + public static final int DEFAULT_MAX_CHUNK_SIZE = 2048; + +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockModel.java new file mode 100644 index 0000000000000..13ca8bd7bd749 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockModel.java @@ -0,0 +1,88 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.amazonbedrock; + +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.inference.TaskSettings; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.external.action.amazonbedrock.AmazonBedrockActionVisitor; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +import java.util.Map; + +public abstract class AmazonBedrockModel extends Model { + + protected String region; + protected String model; + protected AmazonBedrockProvider provider; + protected RateLimitSettings rateLimitSettings; + + protected AmazonBedrockModel(ModelConfigurations modelConfigurations, ModelSecrets secrets) { + super(modelConfigurations, secrets); + setPropertiesFromServiceSettings((AmazonBedrockServiceSettings) modelConfigurations.getServiceSettings()); + } + + protected AmazonBedrockModel(Model model, TaskSettings taskSettings) { + super(model, taskSettings); + + if (model instanceof AmazonBedrockModel bedrockModel) { + setPropertiesFromServiceSettings(bedrockModel.getServiceSettings()); + } + } + + protected AmazonBedrockModel(Model model, ServiceSettings serviceSettings) { + super(model, serviceSettings); + if (serviceSettings instanceof AmazonBedrockServiceSettings bedrockServiceSettings) { + setPropertiesFromServiceSettings(bedrockServiceSettings); + } + } + + protected AmazonBedrockModel(ModelConfigurations modelConfigurations) { + super(modelConfigurations); + setPropertiesFromServiceSettings((AmazonBedrockServiceSettings) modelConfigurations.getServiceSettings()); + } + + public String region() { + return region; + } + + public String model() { + return model; + } + + public AmazonBedrockProvider provider() { + return provider; + } + + public RateLimitSettings rateLimitSettings() { + return rateLimitSettings; + } + + private void setPropertiesFromServiceSettings(AmazonBedrockServiceSettings serviceSettings) { + this.region = serviceSettings.region(); + this.model = serviceSettings.model(); + this.provider = serviceSettings.provider(); + this.rateLimitSettings = serviceSettings.rateLimitSettings(); + } + + public abstract ExecutableAction accept(AmazonBedrockActionVisitor creator, Map taskSettings); + + @Override + public AmazonBedrockServiceSettings getServiceSettings() { + return (AmazonBedrockServiceSettings) super.getServiceSettings(); + } + + @Override + public AmazonBedrockSecretSettings getSecretSettings() { + return (AmazonBedrockSecretSettings) super.getSecretSettings(); + } + +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockProvider.java new file mode 100644 index 0000000000000..340a5a65f0969 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockProvider.java @@ -0,0 +1,30 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.amazonbedrock; + +import java.util.Locale; + +public enum AmazonBedrockProvider { + AMAZONTITAN, + ANTHROPIC, + AI21LABS, + COHERE, + META, + MISTRAL; + + public static String NAME = "amazon_bedrock_provider"; + + public static AmazonBedrockProvider fromString(String name) { + return valueOf(name.trim().toUpperCase(Locale.ROOT)); + } + + @Override + public String toString() { + return name().toLowerCase(Locale.ROOT); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockProviderCapabilities.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockProviderCapabilities.java new file mode 100644 index 0000000000000..28b10ef294bda --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockProviderCapabilities.java @@ -0,0 +1,102 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.amazonbedrock; + +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.inference.TaskType; + +import java.util.List; +import java.util.Map; + +import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.DEFAULT_MAX_CHUNK_SIZE; + +public final class AmazonBedrockProviderCapabilities { + private static final List embeddingProviders = List.of( + AmazonBedrockProvider.AMAZONTITAN, + AmazonBedrockProvider.COHERE + ); + + private static final List chatCompletionProviders = List.of( + AmazonBedrockProvider.AMAZONTITAN, + AmazonBedrockProvider.ANTHROPIC, + AmazonBedrockProvider.AI21LABS, + AmazonBedrockProvider.COHERE, + AmazonBedrockProvider.META, + AmazonBedrockProvider.MISTRAL + ); + + private static final List chatCompletionProvidersWithTopK = List.of( + AmazonBedrockProvider.ANTHROPIC, + AmazonBedrockProvider.COHERE, + AmazonBedrockProvider.MISTRAL + ); + + private static final Map embeddingsDefaultSimilarityMeasure = Map.of( + AmazonBedrockProvider.AMAZONTITAN, + SimilarityMeasure.COSINE, + AmazonBedrockProvider.COHERE, + SimilarityMeasure.DOT_PRODUCT + ); + + private static final Map embeddingsDefaultChunkSize = Map.of( + AmazonBedrockProvider.AMAZONTITAN, + 8192, + AmazonBedrockProvider.COHERE, + 2048 + ); + + private static final Map embeddingsMaxBatchSize = Map.of( + AmazonBedrockProvider.AMAZONTITAN, + 1, + AmazonBedrockProvider.COHERE, + 96 + ); + + public static boolean providerAllowsTaskType(AmazonBedrockProvider provider, TaskType taskType) { + switch (taskType) { + case COMPLETION -> { + return chatCompletionProviders.contains(provider); + } + case TEXT_EMBEDDING -> { + return embeddingProviders.contains(provider); + } + default -> { + return false; + } + } + } + + public static boolean chatCompletionProviderHasTopKParameter(AmazonBedrockProvider provider) { + return chatCompletionProvidersWithTopK.contains(provider); + } + + public static SimilarityMeasure getProviderDefaultSimilarityMeasure(AmazonBedrockProvider provider) { + if (embeddingsDefaultSimilarityMeasure.containsKey(provider)) { + return embeddingsDefaultSimilarityMeasure.get(provider); + } + + return SimilarityMeasure.COSINE; + } + + public static int getEmbeddingsProviderDefaultChunkSize(AmazonBedrockProvider provider) { + if (embeddingsDefaultChunkSize.containsKey(provider)) { + return embeddingsDefaultChunkSize.get(provider); + } + + return DEFAULT_MAX_CHUNK_SIZE; + } + + public static int getEmbeddingsMaxBatchSize(AmazonBedrockProvider provider) { + if (embeddingsMaxBatchSize.containsKey(provider)) { + return embeddingsMaxBatchSize.get(provider); + } + + return 1; + } + +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockSecretSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockSecretSettings.java new file mode 100644 index 0000000000000..9e6328ce1c358 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockSecretSettings.java @@ -0,0 +1,110 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.amazonbedrock; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.SecretSettings; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.TransportVersions.ML_INFERENCE_AMAZON_BEDROCK_ADDED; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredSecureString; +import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.ACCESS_KEY_FIELD; +import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.SECRET_KEY_FIELD; + +public class AmazonBedrockSecretSettings implements SecretSettings { + public static final String NAME = "amazon_bedrock_secret_settings"; + + public final SecureString accessKey; + public final SecureString secretKey; + + public static AmazonBedrockSecretSettings fromMap(@Nullable Map map) { + if (map == null) { + return null; + } + + ValidationException validationException = new ValidationException(); + SecureString secureAccessKey = extractRequiredSecureString( + map, + ACCESS_KEY_FIELD, + ModelSecrets.SECRET_SETTINGS, + validationException + ); + SecureString secureSecretKey = extractRequiredSecureString( + map, + SECRET_KEY_FIELD, + ModelSecrets.SECRET_SETTINGS, + validationException + ); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return new AmazonBedrockSecretSettings(secureAccessKey, secureSecretKey); + } + + public AmazonBedrockSecretSettings(SecureString accessKey, SecureString secretKey) { + this.accessKey = Objects.requireNonNull(accessKey); + this.secretKey = Objects.requireNonNull(secretKey); + } + + public AmazonBedrockSecretSettings(StreamInput in) throws IOException { + this.accessKey = in.readSecureString(); + this.secretKey = in.readSecureString(); + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return ML_INFERENCE_AMAZON_BEDROCK_ADDED; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeSecureString(accessKey); + out.writeSecureString(secretKey); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + + builder.field(ACCESS_KEY_FIELD, accessKey.toString()); + builder.field(SECRET_KEY_FIELD, secretKey.toString()); + + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object object) { + if (this == object) return true; + if (object == null || getClass() != object.getClass()) return false; + AmazonBedrockSecretSettings that = (AmazonBedrockSecretSettings) object; + return Objects.equals(accessKey, that.accessKey) && Objects.equals(secretKey, that.secretKey); + } + + @Override + public int hashCode() { + return Objects.hash(accessKey, secretKey); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java new file mode 100644 index 0000000000000..459ca367058f8 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java @@ -0,0 +1,328 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.amazonbedrock; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.TransportVersion; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.common.Strings; +import org.elasticsearch.core.IOUtils; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkingOptions; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker; +import org.elasticsearch.xpack.inference.external.action.amazonbedrock.AmazonBedrockActionCreator; +import org.elasticsearch.xpack.inference.external.amazonbedrock.AmazonBedrockRequestSender; +import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; +import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.SenderService; +import org.elasticsearch.xpack.inference.services.ServiceComponents; +import org.elasticsearch.xpack.inference.services.ServiceUtils; +import org.elasticsearch.xpack.inference.services.amazonbedrock.completion.AmazonBedrockChatCompletionModel; +import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsServiceSettings; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import static org.elasticsearch.TransportVersions.ML_INFERENCE_AMAZON_BEDROCK_ADDED; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; +import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.TOP_K_FIELD; +import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockProviderCapabilities.chatCompletionProviderHasTopKParameter; +import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockProviderCapabilities.getEmbeddingsMaxBatchSize; +import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockProviderCapabilities.getProviderDefaultSimilarityMeasure; +import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockProviderCapabilities.providerAllowsTaskType; + +public class AmazonBedrockService extends SenderService { + public static final String NAME = "amazonbedrock"; + + private final Sender amazonBedrockSender; + + public AmazonBedrockService( + HttpRequestSender.Factory httpSenderFactory, + AmazonBedrockRequestSender.Factory amazonBedrockFactory, + ServiceComponents serviceComponents + ) { + super(httpSenderFactory, serviceComponents); + this.amazonBedrockSender = amazonBedrockFactory.createSender(); + } + + @Override + protected void doInfer( + Model model, + List input, + Map taskSettings, + InputType inputType, + TimeValue timeout, + ActionListener listener + ) { + var actionCreator = new AmazonBedrockActionCreator(amazonBedrockSender, this.getServiceComponents(), timeout); + if (model instanceof AmazonBedrockModel baseAmazonBedrockModel) { + var action = baseAmazonBedrockModel.accept(actionCreator, taskSettings); + action.execute(new DocumentsOnlyInput(input), timeout, listener); + } else { + listener.onFailure(createInvalidModelException(model)); + } + } + + @Override + protected void doInfer( + Model model, + String query, + List input, + Map taskSettings, + InputType inputType, + TimeValue timeout, + ActionListener listener + ) { + throw new UnsupportedOperationException("Amazon Bedrock service does not support inference with query input"); + } + + @Override + protected void doChunkedInfer( + Model model, + String query, + List input, + Map taskSettings, + InputType inputType, + ChunkingOptions chunkingOptions, + TimeValue timeout, + ActionListener> listener + ) { + var actionCreator = new AmazonBedrockActionCreator(amazonBedrockSender, this.getServiceComponents(), timeout); + if (model instanceof AmazonBedrockModel baseAmazonBedrockModel) { + var maxBatchSize = getEmbeddingsMaxBatchSize(baseAmazonBedrockModel.provider()); + var batchedRequests = new EmbeddingRequestChunker(input, maxBatchSize, EmbeddingRequestChunker.EmbeddingType.FLOAT) + .batchRequestsWithListeners(listener); + for (var request : batchedRequests) { + var action = baseAmazonBedrockModel.accept(actionCreator, taskSettings); + action.execute(new DocumentsOnlyInput(request.batch().inputs()), timeout, request.listener()); + } + } else { + listener.onFailure(createInvalidModelException(model)); + } + } + + @Override + public String name() { + return NAME; + } + + @Override + public void parseRequestConfig( + String modelId, + TaskType taskType, + Map config, + Set platformArchitectures, + ActionListener parsedModelListener + ) { + try { + Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); + Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); + + AmazonBedrockModel model = createModel( + modelId, + taskType, + serviceSettingsMap, + taskSettingsMap, + serviceSettingsMap, + TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME), + ConfigurationParseContext.REQUEST + ); + + throwIfNotEmptyMap(config, NAME); + throwIfNotEmptyMap(serviceSettingsMap, NAME); + throwIfNotEmptyMap(taskSettingsMap, NAME); + + parsedModelListener.onResponse(model); + } catch (Exception e) { + parsedModelListener.onFailure(e); + } + } + + @Override + public Model parsePersistedConfigWithSecrets( + String modelId, + TaskType taskType, + Map config, + Map secrets + ) { + Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); + Map taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS); + Map secretSettingsMap = removeFromMapOrDefaultEmpty(secrets, ModelSecrets.SECRET_SETTINGS); + + return createModel( + modelId, + taskType, + serviceSettingsMap, + taskSettingsMap, + secretSettingsMap, + parsePersistedConfigErrorMsg(modelId, NAME), + ConfigurationParseContext.PERSISTENT + ); + } + + @Override + public Model parsePersistedConfig(String modelId, TaskType taskType, Map config) { + Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); + Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); + + return createModel( + modelId, + taskType, + serviceSettingsMap, + taskSettingsMap, + null, + parsePersistedConfigErrorMsg(modelId, NAME), + ConfigurationParseContext.PERSISTENT + ); + } + + private static AmazonBedrockModel createModel( + String inferenceEntityId, + TaskType taskType, + Map serviceSettings, + Map taskSettings, + @Nullable Map secretSettings, + String failureMessage, + ConfigurationParseContext context + ) { + switch (taskType) { + case TEXT_EMBEDDING -> { + var model = new AmazonBedrockEmbeddingsModel( + inferenceEntityId, + taskType, + NAME, + serviceSettings, + taskSettings, + secretSettings, + context + ); + checkProviderForTask(TaskType.TEXT_EMBEDDING, model.provider()); + return model; + } + case COMPLETION -> { + var model = new AmazonBedrockChatCompletionModel( + inferenceEntityId, + taskType, + NAME, + serviceSettings, + taskSettings, + secretSettings, + context + ); + checkProviderForTask(TaskType.COMPLETION, model.provider()); + checkChatCompletionProviderForTopKParameter(model); + return model; + } + default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST); + } + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return ML_INFERENCE_AMAZON_BEDROCK_ADDED; + } + + /** + * For text embedding models get the embedding size and + * update the service settings. + * + * @param model The new model + * @param listener The listener + */ + @Override + public void checkModelConfig(Model model, ActionListener listener) { + if (model instanceof AmazonBedrockEmbeddingsModel embeddingsModel) { + ServiceUtils.getEmbeddingSize( + model, + this, + listener.delegateFailureAndWrap((l, size) -> l.onResponse(updateModelWithEmbeddingDetails(embeddingsModel, size))) + ); + } else { + listener.onResponse(model); + } + } + + private AmazonBedrockEmbeddingsModel updateModelWithEmbeddingDetails(AmazonBedrockEmbeddingsModel model, int embeddingSize) { + AmazonBedrockEmbeddingsServiceSettings serviceSettings = model.getServiceSettings(); + if (serviceSettings.dimensionsSetByUser() + && serviceSettings.dimensions() != null + && serviceSettings.dimensions() != embeddingSize) { + throw new ElasticsearchStatusException( + Strings.format( + "The retrieved embeddings size [%s] does not match the size specified in the settings [%s]. " + + "Please recreate the [%s] configuration with the correct dimensions", + embeddingSize, + serviceSettings.dimensions(), + model.getConfigurations().getInferenceEntityId() + ), + RestStatus.BAD_REQUEST + ); + } + + var similarityFromModel = serviceSettings.similarity(); + var similarityToUse = similarityFromModel == null ? getProviderDefaultSimilarityMeasure(model.provider()) : similarityFromModel; + + AmazonBedrockEmbeddingsServiceSettings settingsToUse = new AmazonBedrockEmbeddingsServiceSettings( + serviceSettings.region(), + serviceSettings.model(), + serviceSettings.provider(), + embeddingSize, + serviceSettings.dimensionsSetByUser(), + serviceSettings.maxInputTokens(), + similarityToUse, + serviceSettings.rateLimitSettings() + ); + + return new AmazonBedrockEmbeddingsModel(model, settingsToUse); + } + + private static void checkProviderForTask(TaskType taskType, AmazonBedrockProvider provider) { + if (providerAllowsTaskType(provider, taskType) == false) { + throw new ElasticsearchStatusException( + Strings.format("The [%s] task type for provider [%s] is not available", taskType, provider), + RestStatus.BAD_REQUEST + ); + } + } + + private static void checkChatCompletionProviderForTopKParameter(AmazonBedrockChatCompletionModel model) { + var taskSettings = model.getTaskSettings(); + if (taskSettings.topK() != null) { + if (chatCompletionProviderHasTopKParameter(model.provider()) == false) { + throw new ElasticsearchStatusException( + Strings.format("The [%s] task parameter is not available for provider [%s]", TOP_K_FIELD, model.provider()), + RestStatus.BAD_REQUEST + ); + } + } + } + + @Override + public void close() throws IOException { + super.close(); + IOUtils.closeWhileHandlingException(amazonBedrockSender); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceSettings.java new file mode 100644 index 0000000000000..13c7c0a8c5938 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceSettings.java @@ -0,0 +1,141 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.amazonbedrock; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +import java.io.IOException; +import java.util.EnumSet; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.TransportVersions.ML_INFERENCE_AMAZON_BEDROCK_ADDED; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredEnum; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString; +import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.MODEL_FIELD; +import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.PROVIDER_FIELD; +import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.REGION_FIELD; + +public abstract class AmazonBedrockServiceSettings extends FilteredXContentObject implements ServiceSettings { + + protected static final String AMAZON_BEDROCK_BASE_NAME = "amazon_bedrock"; + + protected final String region; + protected final String model; + protected final AmazonBedrockProvider provider; + protected final RateLimitSettings rateLimitSettings; + + // the default requests per minute are defined as per-model in the "Runtime quotas" on AWS + // see: https://docs.aws.amazon.com/bedrock/latest/userguide/quotas.html + // setting this to 240 requests per minute (4 requests / sec) is a sane default for us as it should be enough for + // decent throughput without exceeding the minimal for _most_ items. The user should consult + // the table above if using a model that might have a lesser limit (e.g. Anthropic Claude 3.5) + protected static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(240); + + protected static AmazonBedrockServiceSettings.BaseAmazonBedrockCommonSettings fromMap( + Map map, + ValidationException validationException, + ConfigurationParseContext context + ) { + String model = extractRequiredString(map, MODEL_FIELD, ModelConfigurations.SERVICE_SETTINGS, validationException); + String region = extractRequiredString(map, REGION_FIELD, ModelConfigurations.SERVICE_SETTINGS, validationException); + AmazonBedrockProvider provider = extractRequiredEnum( + map, + PROVIDER_FIELD, + ModelConfigurations.SERVICE_SETTINGS, + AmazonBedrockProvider::fromString, + EnumSet.allOf(AmazonBedrockProvider.class), + validationException + ); + RateLimitSettings rateLimitSettings = RateLimitSettings.of( + map, + DEFAULT_RATE_LIMIT_SETTINGS, + validationException, + AMAZON_BEDROCK_BASE_NAME, + context + ); + + return new BaseAmazonBedrockCommonSettings(region, model, provider, rateLimitSettings); + } + + protected record BaseAmazonBedrockCommonSettings( + String region, + String model, + AmazonBedrockProvider provider, + @Nullable RateLimitSettings rateLimitSettings + ) {} + + protected AmazonBedrockServiceSettings(StreamInput in) throws IOException { + this.region = in.readString(); + this.model = in.readString(); + this.provider = in.readEnum(AmazonBedrockProvider.class); + this.rateLimitSettings = new RateLimitSettings(in); + } + + protected AmazonBedrockServiceSettings( + String region, + String model, + AmazonBedrockProvider provider, + @Nullable RateLimitSettings rateLimitSettings + ) { + this.region = Objects.requireNonNull(region); + this.model = Objects.requireNonNull(model); + this.provider = Objects.requireNonNull(provider); + this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS); + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return ML_INFERENCE_AMAZON_BEDROCK_ADDED; + } + + public String region() { + return region; + } + + public String model() { + return model; + } + + public AmazonBedrockProvider provider() { + return provider; + } + + public RateLimitSettings rateLimitSettings() { + return rateLimitSettings; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(region); + out.writeString(model); + out.writeEnum(provider); + rateLimitSettings.writeTo(out); + } + + public void addBaseXContent(XContentBuilder builder, Params params) throws IOException { + toXContentFragmentOfExposedFields(builder, params); + } + + protected void addXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException { + builder.field(REGION_FIELD, region); + builder.field(MODEL_FIELD, model); + builder.field(PROVIDER_FIELD, provider.name()); + rateLimitSettings.toXContent(builder, params); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/completion/AmazonBedrockChatCompletionModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/completion/AmazonBedrockChatCompletionModel.java new file mode 100644 index 0000000000000..27dc607d671aa --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/completion/AmazonBedrockChatCompletionModel.java @@ -0,0 +1,83 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.amazonbedrock.completion; + +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.TaskSettings; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.external.action.amazonbedrock.AmazonBedrockActionVisitor; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockModel; +import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockSecretSettings; + +import java.util.Map; + +public class AmazonBedrockChatCompletionModel extends AmazonBedrockModel { + + public static AmazonBedrockChatCompletionModel of(AmazonBedrockChatCompletionModel completionModel, Map taskSettings) { + if (taskSettings == null || taskSettings.isEmpty()) { + return completionModel; + } + + var requestTaskSettings = AmazonBedrockChatCompletionRequestTaskSettings.fromMap(taskSettings); + var taskSettingsToUse = AmazonBedrockChatCompletionTaskSettings.of(completionModel.getTaskSettings(), requestTaskSettings); + return new AmazonBedrockChatCompletionModel(completionModel, taskSettingsToUse); + } + + public AmazonBedrockChatCompletionModel( + String inferenceEntityId, + TaskType taskType, + String name, + Map serviceSettings, + Map taskSettings, + Map secretSettings, + ConfigurationParseContext context + ) { + this( + inferenceEntityId, + taskType, + name, + AmazonBedrockChatCompletionServiceSettings.fromMap(serviceSettings, context), + AmazonBedrockChatCompletionTaskSettings.fromMap(taskSettings), + AmazonBedrockSecretSettings.fromMap(secretSettings) + ); + } + + public AmazonBedrockChatCompletionModel( + String inferenceEntityId, + TaskType taskType, + String service, + AmazonBedrockChatCompletionServiceSettings serviceSettings, + AmazonBedrockChatCompletionTaskSettings taskSettings, + AmazonBedrockSecretSettings secrets + ) { + super(new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings), new ModelSecrets(secrets)); + } + + public AmazonBedrockChatCompletionModel(Model model, TaskSettings taskSettings) { + super(model, taskSettings); + } + + @Override + public ExecutableAction accept(AmazonBedrockActionVisitor creator, Map taskSettings) { + return creator.create(this, taskSettings); + } + + @Override + public AmazonBedrockChatCompletionServiceSettings getServiceSettings() { + return (AmazonBedrockChatCompletionServiceSettings) super.getServiceSettings(); + } + + @Override + public AmazonBedrockChatCompletionTaskSettings getTaskSettings() { + return (AmazonBedrockChatCompletionTaskSettings) super.getTaskSettings(); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/completion/AmazonBedrockChatCompletionRequestTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/completion/AmazonBedrockChatCompletionRequestTaskSettings.java new file mode 100644 index 0000000000000..5985dcd56c5d2 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/completion/AmazonBedrockChatCompletionRequestTaskSettings.java @@ -0,0 +1,90 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.amazonbedrock.completion; + +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ModelConfigurations; + +import java.util.Map; + +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalDoubleInRange; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger; +import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.MAX_NEW_TOKENS_FIELD; +import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.MAX_TEMPERATURE_TOP_P_TOP_K_VALUE; +import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.MIN_TEMPERATURE_TOP_P_TOP_K_VALUE; +import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.TEMPERATURE_FIELD; +import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.TOP_K_FIELD; +import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.TOP_P_FIELD; + +public record AmazonBedrockChatCompletionRequestTaskSettings( + @Nullable Double temperature, + @Nullable Double topP, + @Nullable Double topK, + @Nullable Integer maxNewTokens +) { + + public static final AmazonBedrockChatCompletionRequestTaskSettings EMPTY_SETTINGS = new AmazonBedrockChatCompletionRequestTaskSettings( + null, + null, + null, + null + ); + + /** + * Extracts the task settings from a map. All settings are considered optional and the absence of a setting + * does not throw an error. + * + * @param map the settings received from a request + * @return a {@link AmazonBedrockChatCompletionRequestTaskSettings} + */ + public static AmazonBedrockChatCompletionRequestTaskSettings fromMap(Map map) { + if (map.isEmpty()) { + return AmazonBedrockChatCompletionRequestTaskSettings.EMPTY_SETTINGS; + } + + ValidationException validationException = new ValidationException(); + + var temperature = extractOptionalDoubleInRange( + map, + TEMPERATURE_FIELD, + MIN_TEMPERATURE_TOP_P_TOP_K_VALUE, + MAX_TEMPERATURE_TOP_P_TOP_K_VALUE, + ModelConfigurations.TASK_SETTINGS, + validationException + ); + var topP = extractOptionalDoubleInRange( + map, + TOP_P_FIELD, + MIN_TEMPERATURE_TOP_P_TOP_K_VALUE, + MAX_TEMPERATURE_TOP_P_TOP_K_VALUE, + ModelConfigurations.TASK_SETTINGS, + validationException + ); + var topK = extractOptionalDoubleInRange( + map, + TOP_K_FIELD, + MIN_TEMPERATURE_TOP_P_TOP_K_VALUE, + MAX_TEMPERATURE_TOP_P_TOP_K_VALUE, + ModelConfigurations.TASK_SETTINGS, + validationException + ); + Integer maxNewTokens = extractOptionalPositiveInteger( + map, + MAX_NEW_TOKENS_FIELD, + ModelConfigurations.TASK_SETTINGS, + validationException + ); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return new AmazonBedrockChatCompletionRequestTaskSettings(temperature, topP, topK, maxNewTokens); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/completion/AmazonBedrockChatCompletionServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/completion/AmazonBedrockChatCompletionServiceSettings.java new file mode 100644 index 0000000000000..fc3d09c6eea7a --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/completion/AmazonBedrockChatCompletionServiceSettings.java @@ -0,0 +1,93 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.amazonbedrock.completion; + +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockProvider; +import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockServiceSettings; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +import java.io.IOException; +import java.util.Map; +import java.util.Objects; + +public class AmazonBedrockChatCompletionServiceSettings extends AmazonBedrockServiceSettings { + public static final String NAME = "amazon_bedrock_chat_completion_service_settings"; + + public static AmazonBedrockChatCompletionServiceSettings fromMap( + Map serviceSettings, + ConfigurationParseContext context + ) { + ValidationException validationException = new ValidationException(); + + var baseSettings = AmazonBedrockServiceSettings.fromMap(serviceSettings, validationException, context); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return new AmazonBedrockChatCompletionServiceSettings( + baseSettings.region(), + baseSettings.model(), + baseSettings.provider(), + baseSettings.rateLimitSettings() + ); + } + + public AmazonBedrockChatCompletionServiceSettings( + String region, + String model, + AmazonBedrockProvider provider, + RateLimitSettings rateLimitSettings + ) { + super(region, model, provider, rateLimitSettings); + } + + public AmazonBedrockChatCompletionServiceSettings(StreamInput in) throws IOException { + super(in); + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + super.addBaseXContent(builder, params); + builder.endObject(); + return builder; + } + + @Override + protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException { + super.addXContentFragmentOfExposedFields(builder, params); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + AmazonBedrockChatCompletionServiceSettings that = (AmazonBedrockChatCompletionServiceSettings) o; + + return Objects.equals(region, that.region) + && Objects.equals(provider, that.provider) + && Objects.equals(model, that.model) + && Objects.equals(rateLimitSettings, that.rateLimitSettings); + } + + @Override + public int hashCode() { + return Objects.hash(region, model, provider, rateLimitSettings); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/completion/AmazonBedrockChatCompletionTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/completion/AmazonBedrockChatCompletionTaskSettings.java new file mode 100644 index 0000000000000..e689e68794e1f --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/completion/AmazonBedrockChatCompletionTaskSettings.java @@ -0,0 +1,190 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.amazonbedrock.completion; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.TaskSettings; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.TransportVersions.ML_INFERENCE_AMAZON_BEDROCK_ADDED; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalDoubleInRange; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger; +import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.MAX_NEW_TOKENS_FIELD; +import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.MAX_TEMPERATURE_TOP_P_TOP_K_VALUE; +import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.MIN_TEMPERATURE_TOP_P_TOP_K_VALUE; +import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.TEMPERATURE_FIELD; +import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.TOP_K_FIELD; +import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.TOP_P_FIELD; + +public class AmazonBedrockChatCompletionTaskSettings implements TaskSettings { + public static final String NAME = "amazon_bedrock_chat_completion_task_settings"; + + public static final AmazonBedrockChatCompletionRequestTaskSettings EMPTY_SETTINGS = new AmazonBedrockChatCompletionRequestTaskSettings( + null, + null, + null, + null + ); + + public static AmazonBedrockChatCompletionTaskSettings fromMap(Map settings) { + ValidationException validationException = new ValidationException(); + + Double temperature = extractOptionalDoubleInRange( + settings, + TEMPERATURE_FIELD, + MIN_TEMPERATURE_TOP_P_TOP_K_VALUE, + MAX_TEMPERATURE_TOP_P_TOP_K_VALUE, + ModelConfigurations.TASK_SETTINGS, + validationException + ); + Double topP = extractOptionalDoubleInRange( + settings, + TOP_P_FIELD, + MIN_TEMPERATURE_TOP_P_TOP_K_VALUE, + MAX_TEMPERATURE_TOP_P_TOP_K_VALUE, + ModelConfigurations.TASK_SETTINGS, + validationException + ); + Double topK = extractOptionalDoubleInRange( + settings, + TOP_K_FIELD, + MIN_TEMPERATURE_TOP_P_TOP_K_VALUE, + MAX_TEMPERATURE_TOP_P_TOP_K_VALUE, + ModelConfigurations.TASK_SETTINGS, + validationException + ); + Integer maxNewTokens = extractOptionalPositiveInteger( + settings, + MAX_NEW_TOKENS_FIELD, + ModelConfigurations.TASK_SETTINGS, + validationException + ); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return new AmazonBedrockChatCompletionTaskSettings(temperature, topP, topK, maxNewTokens); + } + + public static AmazonBedrockChatCompletionTaskSettings of( + AmazonBedrockChatCompletionTaskSettings originalSettings, + AmazonBedrockChatCompletionRequestTaskSettings requestSettings + ) { + var temperature = requestSettings.temperature() == null ? originalSettings.temperature() : requestSettings.temperature(); + var topP = requestSettings.topP() == null ? originalSettings.topP() : requestSettings.topP(); + var topK = requestSettings.topK() == null ? originalSettings.topK() : requestSettings.topK(); + var maxNewTokens = requestSettings.maxNewTokens() == null ? originalSettings.maxNewTokens() : requestSettings.maxNewTokens(); + + return new AmazonBedrockChatCompletionTaskSettings(temperature, topP, topK, maxNewTokens); + } + + private final Double temperature; + private final Double topP; + private final Double topK; + private final Integer maxNewTokens; + + public AmazonBedrockChatCompletionTaskSettings( + @Nullable Double temperature, + @Nullable Double topP, + @Nullable Double topK, + @Nullable Integer maxNewTokens + ) { + this.temperature = temperature; + this.topP = topP; + this.topK = topK; + this.maxNewTokens = maxNewTokens; + } + + public AmazonBedrockChatCompletionTaskSettings(StreamInput in) throws IOException { + this.temperature = in.readOptionalDouble(); + this.topP = in.readOptionalDouble(); + this.topK = in.readOptionalDouble(); + this.maxNewTokens = in.readOptionalVInt(); + } + + public Double temperature() { + return temperature; + } + + public Double topP() { + return topP; + } + + public Double topK() { + return topK; + } + + public Integer maxNewTokens() { + return maxNewTokens; + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return ML_INFERENCE_AMAZON_BEDROCK_ADDED; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeOptionalDouble(temperature); + out.writeOptionalDouble(topP); + out.writeOptionalDouble(topK); + out.writeOptionalVInt(maxNewTokens); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + { + if (temperature != null) { + builder.field(TEMPERATURE_FIELD, temperature); + } + if (topP != null) { + builder.field(TOP_P_FIELD, topP); + } + if (topK != null) { + builder.field(TOP_K_FIELD, topK); + } + if (maxNewTokens != null) { + builder.field(MAX_NEW_TOKENS_FIELD, maxNewTokens); + } + } + builder.endObject(); + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + AmazonBedrockChatCompletionTaskSettings that = (AmazonBedrockChatCompletionTaskSettings) o; + return Objects.equals(temperature, that.temperature) + && Objects.equals(topP, that.topP) + && Objects.equals(topK, that.topK) + && Objects.equals(maxNewTokens, that.maxNewTokens); + } + + @Override + public int hashCode() { + return Objects.hash(temperature, topP, topK, maxNewTokens); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/embeddings/AmazonBedrockEmbeddingsModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/embeddings/AmazonBedrockEmbeddingsModel.java new file mode 100644 index 0000000000000..0e3a954a03279 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/embeddings/AmazonBedrockEmbeddingsModel.java @@ -0,0 +1,85 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings; + +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.inference.EmptyTaskSettings; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.inference.TaskSettings; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.external.action.amazonbedrock.AmazonBedrockActionVisitor; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockModel; +import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockSecretSettings; + +import java.util.Map; + +public class AmazonBedrockEmbeddingsModel extends AmazonBedrockModel { + + public static AmazonBedrockEmbeddingsModel of(AmazonBedrockEmbeddingsModel embeddingsModel, Map taskSettings) { + if (taskSettings != null && taskSettings.isEmpty() == false) { + // no task settings allowed + var validationException = new ValidationException(); + validationException.addValidationError("Amazon Bedrock embeddings model cannot have task settings"); + throw validationException; + } + + return embeddingsModel; + } + + public AmazonBedrockEmbeddingsModel( + String inferenceEntityId, + TaskType taskType, + String service, + Map serviceSettings, + Map taskSettings, + Map secretSettings, + ConfigurationParseContext context + ) { + this( + inferenceEntityId, + taskType, + service, + AmazonBedrockEmbeddingsServiceSettings.fromMap(serviceSettings, context), + new EmptyTaskSettings(), + AmazonBedrockSecretSettings.fromMap(secretSettings) + ); + } + + public AmazonBedrockEmbeddingsModel( + String inferenceEntityId, + TaskType taskType, + String service, + AmazonBedrockEmbeddingsServiceSettings serviceSettings, + TaskSettings taskSettings, + AmazonBedrockSecretSettings secrets + ) { + super( + new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, new EmptyTaskSettings()), + new ModelSecrets(secrets) + ); + } + + public AmazonBedrockEmbeddingsModel(Model model, ServiceSettings serviceSettings) { + super(model, serviceSettings); + } + + @Override + public ExecutableAction accept(AmazonBedrockActionVisitor creator, Map taskSettings) { + return creator.create(this, taskSettings); + } + + @Override + public AmazonBedrockEmbeddingsServiceSettings getServiceSettings() { + return (AmazonBedrockEmbeddingsServiceSettings) super.getServiceSettings(); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/embeddings/AmazonBedrockEmbeddingsServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/embeddings/AmazonBedrockEmbeddingsServiceSettings.java new file mode 100644 index 0000000000000..4bf037558c618 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/embeddings/AmazonBedrockEmbeddingsServiceSettings.java @@ -0,0 +1,220 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings; + +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.ServiceUtils; +import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockProvider; +import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockServiceSettings; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +import java.io.IOException; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS; +import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS; +import static org.elasticsearch.xpack.inference.services.ServiceFields.SIMILARITY; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalBoolean; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractSimilarity; + +public class AmazonBedrockEmbeddingsServiceSettings extends AmazonBedrockServiceSettings { + public static final String NAME = "amazon_bedrock_embeddings_service_settings"; + static final String DIMENSIONS_SET_BY_USER = "dimensions_set_by_user"; + + private final Integer dimensions; + private final Boolean dimensionsSetByUser; + private final Integer maxInputTokens; + private final SimilarityMeasure similarity; + + public static AmazonBedrockEmbeddingsServiceSettings fromMap(Map map, ConfigurationParseContext context) { + ValidationException validationException = new ValidationException(); + + var settings = embeddingSettingsFromMap(map, validationException, context); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return settings; + } + + private static AmazonBedrockEmbeddingsServiceSettings embeddingSettingsFromMap( + Map map, + ValidationException validationException, + ConfigurationParseContext context + ) { + var baseSettings = AmazonBedrockServiceSettings.fromMap(map, validationException, context); + SimilarityMeasure similarity = extractSimilarity(map, ModelConfigurations.SERVICE_SETTINGS, validationException); + + Integer maxTokens = extractOptionalPositiveInteger( + map, + MAX_INPUT_TOKENS, + ModelConfigurations.SERVICE_SETTINGS, + validationException + ); + Integer dims = extractOptionalPositiveInteger(map, DIMENSIONS, ModelConfigurations.SERVICE_SETTINGS, validationException); + + Boolean dimensionsSetByUser = extractOptionalBoolean(map, DIMENSIONS_SET_BY_USER, validationException); + + switch (context) { + case REQUEST -> { + if (dimensionsSetByUser != null) { + validationException.addValidationError( + ServiceUtils.invalidSettingError(DIMENSIONS_SET_BY_USER, ModelConfigurations.SERVICE_SETTINGS) + ); + } + + if (dims != null) { + validationException.addValidationError( + ServiceUtils.invalidSettingError(DIMENSIONS, ModelConfigurations.SERVICE_SETTINGS) + ); + } + dimensionsSetByUser = false; + } + case PERSISTENT -> { + if (dimensionsSetByUser == null) { + validationException.addValidationError( + ServiceUtils.missingSettingErrorMsg(DIMENSIONS_SET_BY_USER, ModelConfigurations.SERVICE_SETTINGS) + ); + } + } + } + return new AmazonBedrockEmbeddingsServiceSettings( + baseSettings.region(), + baseSettings.model(), + baseSettings.provider(), + dims, + dimensionsSetByUser, + maxTokens, + similarity, + baseSettings.rateLimitSettings() + ); + } + + public AmazonBedrockEmbeddingsServiceSettings(StreamInput in) throws IOException { + super(in); + dimensions = in.readOptionalVInt(); + dimensionsSetByUser = in.readBoolean(); + maxInputTokens = in.readOptionalVInt(); + similarity = in.readOptionalEnum(SimilarityMeasure.class); + } + + public AmazonBedrockEmbeddingsServiceSettings( + String region, + String model, + AmazonBedrockProvider provider, + @Nullable Integer dimensions, + Boolean dimensionsSetByUser, + @Nullable Integer maxInputTokens, + @Nullable SimilarityMeasure similarity, + RateLimitSettings rateLimitSettings + ) { + super(region, model, provider, rateLimitSettings); + this.dimensions = dimensions; + this.dimensionsSetByUser = dimensionsSetByUser; + this.maxInputTokens = maxInputTokens; + this.similarity = similarity; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeOptionalVInt(dimensions); + out.writeBoolean(dimensionsSetByUser); + out.writeOptionalVInt(maxInputTokens); + out.writeOptionalEnum(similarity); + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + + super.addBaseXContent(builder, params); + builder.field(DIMENSIONS_SET_BY_USER, dimensionsSetByUser); + + builder.endObject(); + return builder; + } + + @Override + protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException { + super.addXContentFragmentOfExposedFields(builder, params); + + if (dimensions != null) { + builder.field(DIMENSIONS, dimensions); + } + if (maxInputTokens != null) { + builder.field(MAX_INPUT_TOKENS, maxInputTokens); + } + if (similarity != null) { + builder.field(SIMILARITY, similarity); + } + + return builder; + } + + @Override + public SimilarityMeasure similarity() { + return similarity; + } + + @Override + public Integer dimensions() { + return dimensions; + } + + public boolean dimensionsSetByUser() { + return this.dimensionsSetByUser; + } + + public Integer maxInputTokens() { + return maxInputTokens; + } + + @Override + public DenseVectorFieldMapper.ElementType elementType() { + return DenseVectorFieldMapper.ElementType.FLOAT; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + AmazonBedrockEmbeddingsServiceSettings that = (AmazonBedrockEmbeddingsServiceSettings) o; + + return Objects.equals(region, that.region) + && Objects.equals(provider, that.provider) + && Objects.equals(model, that.model) + && Objects.equals(dimensions, that.dimensions) + && Objects.equals(dimensionsSetByUser, that.dimensionsSetByUser) + && Objects.equals(maxInputTokens, that.maxInputTokens) + && Objects.equals(similarity, that.similarity) + && Objects.equals(rateLimitSettings, that.rateLimitSettings); + } + + @Override + public int hashCode() { + return Objects.hash(region, model, provider, dimensions, dimensionsSetByUser, maxInputTokens, similarity, rateLimitSettings); + } + +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/embeddings/AzureAiStudioEmbeddingsServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/embeddings/AzureAiStudioEmbeddingsServiceSettings.java index 1a39cd67a70f3..d4a1fd938625e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/embeddings/AzureAiStudioEmbeddingsServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/embeddings/AzureAiStudioEmbeddingsServiceSettings.java @@ -33,8 +33,8 @@ import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS; import static org.elasticsearch.xpack.inference.services.ServiceFields.SIMILARITY; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalBoolean; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractSimilarity; -import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeAsType; public class AzureAiStudioEmbeddingsServiceSettings extends AzureAiStudioServiceSettings { @@ -59,10 +59,15 @@ private static AzureAiStudioEmbeddingCommonFields embeddingSettingsFromMap( ConfigurationParseContext context ) { var baseSettings = AzureAiStudioServiceSettings.fromMap(map, validationException, context); - SimilarityMeasure similarity = extractSimilarity(map, ModelConfigurations.SERVICE_SETTINGS, validationException); - Integer dims = removeAsType(map, DIMENSIONS, Integer.class); - Integer maxTokens = removeAsType(map, MAX_INPUT_TOKENS, Integer.class); + SimilarityMeasure similarity = extractSimilarity(map, ModelConfigurations.SERVICE_SETTINGS, validationException); + Integer dims = extractOptionalPositiveInteger(map, DIMENSIONS, ModelConfigurations.SERVICE_SETTINGS, validationException); + Integer maxTokens = extractOptionalPositiveInteger( + map, + MAX_INPUT_TOKENS, + ModelConfigurations.SERVICE_SETTINGS, + validationException + ); Boolean dimensionsSetByUser = extractOptionalBoolean(map, DIMENSIONS_SET_BY_USER, validationException); switch (context) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java index 3facb78864831..3c75243770f97 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java @@ -24,10 +24,6 @@ import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults; -import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults; -import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults; -import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults; import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker; import org.elasticsearch.xpack.inference.external.action.azureopenai.AzureOpenAiActionCreator; import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; @@ -44,7 +40,6 @@ import java.util.Map; import java.util.Set; -import static org.elasticsearch.xpack.core.inference.results.ResultUtils.createInvalidChunkedResultException; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; @@ -246,19 +241,6 @@ protected void doChunkedInfer( } } - private static List translateToChunkedResults( - List inputs, - InferenceServiceResults inferenceResults - ) { - if (inferenceResults instanceof InferenceTextEmbeddingFloatResults textEmbeddingResults) { - return InferenceChunkedTextEmbeddingFloatResults.listOf(inputs, textEmbeddingResults); - } else if (inferenceResults instanceof ErrorInferenceResults error) { - return List.of(new ErrorChunkedInferenceResults(error.getException())); - } else { - throw createInvalidChunkedResultException(InferenceTextEmbeddingFloatResults.NAME, inferenceResults.getWriteableName()); - } - } - /** * For text embedding models get the embedding size and * update the service settings. diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsServiceSettings.java index 1c426815a83c0..a9e40569d4e7a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsServiceSettings.java @@ -33,9 +33,9 @@ import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS; import static org.elasticsearch.xpack.inference.services.ServiceFields.SIMILARITY; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalBoolean; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractSimilarity; -import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeAsType; import static org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiServiceFields.API_VERSION; import static org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiServiceFields.DEPLOYMENT_ID; import static org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiServiceFields.RESOURCE_NAME; @@ -88,8 +88,13 @@ private static CommonFields fromMap( String resourceName = extractRequiredString(map, RESOURCE_NAME, ModelConfigurations.SERVICE_SETTINGS, validationException); String deploymentId = extractRequiredString(map, DEPLOYMENT_ID, ModelConfigurations.SERVICE_SETTINGS, validationException); String apiVersion = extractRequiredString(map, API_VERSION, ModelConfigurations.SERVICE_SETTINGS, validationException); - Integer dims = removeAsType(map, DIMENSIONS, Integer.class); - Integer maxTokens = removeAsType(map, MAX_INPUT_TOKENS, Integer.class); + Integer dims = extractOptionalPositiveInteger(map, DIMENSIONS, ModelConfigurations.SERVICE_SETTINGS, validationException); + Integer maxTokens = extractOptionalPositiveInteger( + map, + MAX_INPUT_TOKENS, + ModelConfigurations.SERVICE_SETTINGS, + validationException + ); SimilarityMeasure similarity = extractSimilarity(map, ModelConfigurations.SERVICE_SETTINGS, validationException); RateLimitSettings rateLimitSettings = RateLimitSettings.of( map, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/CustomElandInternalServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/CustomElandInternalServiceSettings.java index 6c81cc9948b70..b74dbe482acc6 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/CustomElandInternalServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/CustomElandInternalServiceSettings.java @@ -9,11 +9,14 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; +import org.elasticsearch.action.ActionRequestValidationException; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings; +import org.elasticsearch.xpack.inference.services.ServiceUtils; import java.io.IOException; import java.util.Map; @@ -25,8 +28,13 @@ public class CustomElandInternalServiceSettings extends ElasticsearchInternalSer public static final String NAME = "custom_eland_model_internal_service_settings"; - public CustomElandInternalServiceSettings(int numAllocations, int numThreads, String modelId) { - super(numAllocations, numThreads, modelId); + public CustomElandInternalServiceSettings( + int numAllocations, + int numThreads, + String modelId, + AdaptiveAllocationsSettings adaptiveAllocationsSettings + ) { + super(numAllocations, numThreads, modelId, adaptiveAllocationsSettings); } /** @@ -50,6 +58,16 @@ public static CustomElandInternalServiceSettings fromMap(Map map validationException ); Integer numThreads = extractRequiredPositiveInteger(map, NUM_THREADS, ModelConfigurations.SERVICE_SETTINGS, validationException); + AdaptiveAllocationsSettings adaptiveAllocationsSettings = ServiceUtils.removeAsAdaptiveAllocationsSettings( + map, + ADAPTIVE_ALLOCATIONS + ); + if (adaptiveAllocationsSettings != null) { + ActionRequestValidationException exception = adaptiveAllocationsSettings.validate(); + if (exception != null) { + validationException.addValidationErrors(exception.validationErrors()); + } + } String modelId = extractRequiredString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException); if (validationException.validationErrors().isEmpty() == false) { @@ -59,12 +77,18 @@ public static CustomElandInternalServiceSettings fromMap(Map map var builder = new Builder() { @Override public CustomElandInternalServiceSettings build() { - return new CustomElandInternalServiceSettings(getNumAllocations(), getNumThreads(), getModelId()); + return new CustomElandInternalServiceSettings( + getNumAllocations(), + getNumThreads(), + getModelId(), + getAdaptiveAllocationsSettings() + ); } }; builder.setNumAllocations(numAllocations); builder.setNumThreads(numThreads); builder.setModelId(modelId); + builder.setAdaptiveAllocationsSettings(adaptiveAllocationsSettings); return builder.build(); } @@ -74,7 +98,14 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws } public CustomElandInternalServiceSettings(StreamInput in) throws IOException { - super(in.readVInt(), in.readVInt(), in.readString()); + super( + in.readVInt(), + in.readVInt(), + in.readString(), + in.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_ADAPTIVE_ALLOCATIONS) + ? in.readOptionalWriteable(AdaptiveAllocationsSettings::new) + : null + ); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/CustomElandInternalTextEmbeddingServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/CustomElandInternalTextEmbeddingServiceSettings.java index 5ef9ce1a0507f..8413d06045601 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/CustomElandInternalTextEmbeddingServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/CustomElandInternalTextEmbeddingServiceSettings.java @@ -18,6 +18,7 @@ import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import java.io.IOException; @@ -107,19 +108,38 @@ private static CommonFields commonFieldsFromMap(Map map, Validat private final SimilarityMeasure similarityMeasure; private final DenseVectorFieldMapper.ElementType elementType; - public CustomElandInternalTextEmbeddingServiceSettings(int numAllocations, int numThreads, String modelId) { - this(numAllocations, numThreads, modelId, null, SimilarityMeasure.COSINE, DenseVectorFieldMapper.ElementType.FLOAT); + public CustomElandInternalTextEmbeddingServiceSettings( + int numAllocations, + int numThreads, + String modelId, + AdaptiveAllocationsSettings adaptiveAllocationsSettings + ) { + this( + numAllocations, + numThreads, + modelId, + adaptiveAllocationsSettings, + null, + SimilarityMeasure.COSINE, + DenseVectorFieldMapper.ElementType.FLOAT + ); } public CustomElandInternalTextEmbeddingServiceSettings( int numAllocations, int numThreads, String modelId, + AdaptiveAllocationsSettings adaptiveAllocationsSettings, Integer dimensions, SimilarityMeasure similarityMeasure, DenseVectorFieldMapper.ElementType elementType ) { - internalServiceSettings = new ElasticsearchInternalServiceSettings(numAllocations, numThreads, modelId); + internalServiceSettings = new ElasticsearchInternalServiceSettings( + numAllocations, + numThreads, + modelId, + adaptiveAllocationsSettings + ); this.dimensions = dimensions; this.similarityMeasure = Objects.requireNonNull(similarityMeasure); this.elementType = Objects.requireNonNull(elementType); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/CustomElandModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/CustomElandModel.java index 5a82e73299b85..703fca8c74c31 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/CustomElandModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/CustomElandModel.java @@ -37,6 +37,7 @@ public StartTrainedModelDeploymentAction.Request getStartTrainedModelDeploymentA var startRequest = new StartTrainedModelDeploymentAction.Request(internalServiceSettings.getModelId(), this.getInferenceEntityId()); startRequest.setNumberOfAllocations(internalServiceSettings.getNumAllocations()); startRequest.setThreadsPerAllocation(internalServiceSettings.getNumThreads()); + startRequest.setAdaptiveAllocationsSettings(internalServiceSettings.getAdaptiveAllocationsSettings()); startRequest.setWaitForState(STARTED); return startRequest; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java index d5401f61823db..9dc88be16ddbb 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java @@ -271,6 +271,7 @@ private static CustomElandEmbeddingModel updateModelWithEmbeddingDetails(CustomE model.getServiceSettings().getElasticsearchInternalServiceSettings().getNumAllocations(), model.getServiceSettings().getElasticsearchInternalServiceSettings().getNumThreads(), model.getServiceSettings().getElasticsearchInternalServiceSettings().getModelId(), + model.getServiceSettings().getElasticsearchInternalServiceSettings().getAdaptiveAllocationsSettings(), embeddingSize, model.getServiceSettings().similarity(), model.getServiceSettings().elementType() diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceSettings.java index 45d616074dded..ff4ef4ff0358f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceSettings.java @@ -9,9 +9,12 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; +import org.elasticsearch.action.ActionRequestValidationException; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings; +import org.elasticsearch.xpack.inference.services.ServiceUtils; import org.elasticsearch.xpack.inference.services.settings.InternalServiceSettings; import java.io.IOException; @@ -34,23 +37,46 @@ public static ElasticsearchInternalServiceSettings fromMap(Map m validationException ); Integer numThreads = extractRequiredPositiveInteger(map, NUM_THREADS, ModelConfigurations.SERVICE_SETTINGS, validationException); + AdaptiveAllocationsSettings adaptiveAllocationsSettings = ServiceUtils.removeAsAdaptiveAllocationsSettings( + map, + ADAPTIVE_ALLOCATIONS + ); + if (adaptiveAllocationsSettings != null) { + ActionRequestValidationException exception = adaptiveAllocationsSettings.validate(); + if (exception != null) { + validationException.addValidationErrors(exception.validationErrors()); + } + } String modelId = extractRequiredString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException); - // if an error occurred while parsing, we'll set these to an invalid value so we don't accidentally get a + // if an error occurred while parsing, we'll set these to an invalid value, so we don't accidentally get a // null pointer when doing unboxing return new ElasticsearchInternalServiceSettings( Objects.requireNonNullElse(numAllocations, FAILED_INT_PARSE_VALUE), Objects.requireNonNullElse(numThreads, FAILED_INT_PARSE_VALUE), - modelId + modelId, + adaptiveAllocationsSettings ); } - public ElasticsearchInternalServiceSettings(int numAllocations, int numThreads, String modelVariant) { - super(numAllocations, numThreads, modelVariant); + public ElasticsearchInternalServiceSettings( + int numAllocations, + int numThreads, + String modelVariant, + AdaptiveAllocationsSettings adaptiveAllocationsSettings + ) { + super(numAllocations, numThreads, modelVariant, adaptiveAllocationsSettings); } public ElasticsearchInternalServiceSettings(StreamInput in) throws IOException { - super(in.readVInt(), in.readVInt(), in.readString()); + super( + in.readVInt(), + in.readVInt(), + in.readString(), + in.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_ADAPTIVE_ALLOCATIONS) + ? in.readOptionalWriteable(AdaptiveAllocationsSettings::new) + : null + ); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/MultilingualE5SmallInternalServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/MultilingualE5SmallInternalServiceSettings.java index 602f3a5c6c4e8..e4aa9616fb332 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/MultilingualE5SmallInternalServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/MultilingualE5SmallInternalServiceSettings.java @@ -7,6 +7,8 @@ package org.elasticsearch.xpack.inference.services.elasticsearch; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.action.ActionRequestValidationException; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; @@ -14,6 +16,7 @@ import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings; import org.elasticsearch.xpack.inference.services.ServiceUtils; import org.elasticsearch.xpack.inference.services.settings.InternalServiceSettings; @@ -30,12 +33,24 @@ public class MultilingualE5SmallInternalServiceSettings extends ElasticsearchInt static final int DIMENSIONS = 384; static final SimilarityMeasure SIMILARITY = SimilarityMeasure.COSINE; - public MultilingualE5SmallInternalServiceSettings(int numAllocations, int numThreads, String modelId) { - super(numAllocations, numThreads, modelId); + public MultilingualE5SmallInternalServiceSettings( + int numAllocations, + int numThreads, + String modelId, + AdaptiveAllocationsSettings adaptiveAllocationsSettings + ) { + super(numAllocations, numThreads, modelId, adaptiveAllocationsSettings); } public MultilingualE5SmallInternalServiceSettings(StreamInput in) throws IOException { - super(in.readVInt(), in.readVInt(), in.readString()); + super( + in.readVInt(), + in.readVInt(), + in.readString(), + in.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_ADAPTIVE_ALLOCATIONS) + ? in.readOptionalWriteable(AdaptiveAllocationsSettings::new) + : null + ); } /** @@ -66,7 +81,16 @@ private static RequestFields extractRequestFields(Map map, Valid validationException ); Integer numThreads = extractRequiredPositiveInteger(map, NUM_THREADS, ModelConfigurations.SERVICE_SETTINGS, validationException); - + AdaptiveAllocationsSettings adaptiveAllocationsSettings = ServiceUtils.removeAsAdaptiveAllocationsSettings( + map, + ADAPTIVE_ALLOCATIONS + ); + if (adaptiveAllocationsSettings != null) { + ActionRequestValidationException exception = adaptiveAllocationsSettings.validate(); + if (exception != null) { + validationException.addValidationErrors(exception.validationErrors()); + } + } String modelId = ServiceUtils.removeAsType(map, MODEL_ID, String.class); if (modelId != null) { if (ElasticsearchInternalService.MULTILINGUAL_E5_SMALL_VALID_IDS.contains(modelId) == false) { @@ -79,23 +103,34 @@ private static RequestFields extractRequestFields(Map map, Valid } } - return new RequestFields(numAllocations, numThreads, modelId); + return new RequestFields(numAllocations, numThreads, modelId, adaptiveAllocationsSettings); } private static MultilingualE5SmallInternalServiceSettings.Builder createBuilder(RequestFields requestFields) { var builder = new InternalServiceSettings.Builder() { @Override public MultilingualE5SmallInternalServiceSettings build() { - return new MultilingualE5SmallInternalServiceSettings(getNumAllocations(), getNumThreads(), getModelId()); + return new MultilingualE5SmallInternalServiceSettings( + getNumAllocations(), + getNumThreads(), + getModelId(), + getAdaptiveAllocationsSettings() + ); } }; builder.setNumAllocations(requestFields.numAllocations); builder.setNumThreads(requestFields.numThreads); builder.setModelId(requestFields.modelId); + builder.setAdaptiveAllocationsSettings(requestFields.adaptiveAllocationsSettings); return builder; } - private record RequestFields(@Nullable Integer numAllocations, @Nullable Integer numThreads, @Nullable String modelId) {} + private record RequestFields( + @Nullable Integer numAllocations, + @Nullable Integer numThreads, + @Nullable String modelId, + @Nullable AdaptiveAllocationsSettings adaptiveAllocationsSettings + ) {} @Override public boolean isFragment() { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/MultilingualE5SmallModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/MultilingualE5SmallModel.java index 60d68eb2fcee7..f22118d00cc29 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/MultilingualE5SmallModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/MultilingualE5SmallModel.java @@ -47,6 +47,7 @@ public StartTrainedModelDeploymentAction.Request getStartTrainedModelDeploymentA ); startRequest.setNumberOfAllocations(this.getServiceSettings().getNumAllocations()); startRequest.setThreadsPerAllocation(this.getServiceSettings().getNumThreads()); + startRequest.setAdaptiveAllocationsSettings(this.getServiceSettings().getAdaptiveAllocationsSettings()); startRequest.setWaitForState(STARTED); return startRequest; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalService.java index 11c97f8b8e37e..54434a7563dab 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalService.java @@ -216,6 +216,7 @@ private static StartTrainedModelDeploymentAction.Request startDeploymentRequest( ); startRequest.setNumberOfAllocations(serviceSettings.getNumAllocations()); startRequest.setThreadsPerAllocation(serviceSettings.getNumThreads()); + startRequest.setAdaptiveAllocationsSettings(serviceSettings.getAdaptiveAllocationsSettings()); startRequest.setWaitForState(STARTED); return startRequest; } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalServiceSettings.java index 603c218d4dd21..fcbf7394ccb33 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalServiceSettings.java @@ -9,10 +9,13 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; +import org.elasticsearch.action.ActionRequestValidationException; import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings; +import org.elasticsearch.xpack.inference.services.ServiceUtils; import org.elasticsearch.xpack.inference.services.settings.InternalServiceSettings; import java.io.IOException; @@ -45,6 +48,16 @@ public static ElserInternalServiceSettings.Builder fromMap(Map m validationException ); Integer numThreads = extractRequiredPositiveInteger(map, NUM_THREADS, ModelConfigurations.SERVICE_SETTINGS, validationException); + AdaptiveAllocationsSettings adaptiveAllocationsSettings = ServiceUtils.removeAsAdaptiveAllocationsSettings( + map, + ADAPTIVE_ALLOCATIONS + ); + if (adaptiveAllocationsSettings != null) { + ActionRequestValidationException exception = adaptiveAllocationsSettings.validate(); + if (exception != null) { + validationException.addValidationErrors(exception.validationErrors()); + } + } String modelId = extractOptionalString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException); if (modelId != null && ElserInternalService.VALID_ELSER_MODEL_IDS.contains(modelId) == false) { @@ -58,17 +71,28 @@ public static ElserInternalServiceSettings.Builder fromMap(Map m var builder = new InternalServiceSettings.Builder() { @Override public ElserInternalServiceSettings build() { - return new ElserInternalServiceSettings(getNumAllocations(), getNumThreads(), getModelId()); + return new ElserInternalServiceSettings( + getNumAllocations(), + getNumThreads(), + getModelId(), + getAdaptiveAllocationsSettings() + ); } }; builder.setNumAllocations(numAllocations); builder.setNumThreads(numThreads); + builder.setAdaptiveAllocationsSettings(adaptiveAllocationsSettings); builder.setModelId(modelId); return builder; } - public ElserInternalServiceSettings(int numAllocations, int numThreads, String modelId) { - super(numAllocations, numThreads, modelId); + public ElserInternalServiceSettings( + int numAllocations, + int numThreads, + String modelId, + AdaptiveAllocationsSettings adaptiveAllocationsSettings + ) { + super(numAllocations, numThreads, modelId, adaptiveAllocationsSettings); Objects.requireNonNull(modelId); } @@ -76,7 +100,10 @@ public ElserInternalServiceSettings(StreamInput in) throws IOException { super( in.readVInt(), in.readVInt(), - in.getTransportVersion().onOrAfter(TransportVersions.V_8_11_X) ? in.readString() : ElserInternalService.ELSER_V2_MODEL + in.getTransportVersion().onOrAfter(TransportVersions.V_8_11_X) ? in.readString() : ElserInternalService.ELSER_V2_MODEL, + in.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_ADAPTIVE_ALLOCATIONS) + ? in.readOptionalWriteable(AdaptiveAllocationsSettings::new) + : null ); } @@ -97,11 +124,14 @@ public void writeTo(StreamOutput out) throws IOException { if (out.getTransportVersion().onOrAfter(TransportVersions.V_8_11_X)) { out.writeString(getModelId()); } + if (out.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_ADAPTIVE_ALLOCATIONS)) { + out.writeOptionalWriteable(getAdaptiveAllocationsSettings()); + } } @Override public int hashCode() { - return Objects.hash(NAME, getNumAllocations(), getNumThreads(), getModelId()); + return Objects.hash(NAME, getNumAllocations(), getNumThreads(), getModelId(), getAdaptiveAllocationsSettings()); } @Override @@ -111,6 +141,7 @@ public boolean equals(Object o) { ElserInternalServiceSettings that = (ElserInternalServiceSettings) o; return getNumAllocations() == that.getNumAllocations() && getNumThreads() == that.getNumThreads() - && Objects.equals(getModelId(), that.getModelId()); + && Objects.equals(getModelId(), that.getModelId()) + && Objects.equals(getAdaptiveAllocationsSettings(), that.getAdaptiveAllocationsSettings()); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingsServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingsServiceSettings.java index 62d06a4e0029c..2e4d546e1dc4c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingsServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingsServiceSettings.java @@ -33,7 +33,6 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractSimilarity; -import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeAsType; import static org.elasticsearch.xpack.inference.services.mistral.MistralConstants.MODEL_FIELD; public class MistralEmbeddingsServiceSettings extends FilteredXContentObject implements ServiceSettings { @@ -67,7 +66,7 @@ public static MistralEmbeddingsServiceSettings fromMap(Map map, MistralService.NAME, context ); - Integer dims = removeAsType(map, DIMENSIONS, Integer.class); + Integer dims = extractOptionalPositiveInteger(map, DIMENSIONS, ModelConfigurations.SERVICE_SETTINGS, validationException); if (validationException.validationErrors().isEmpty() == false) { throw validationException; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsServiceSettings.java index 080251bf1ba3a..d474e935fbda7 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsServiceSettings.java @@ -36,6 +36,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceFields.URL; import static org.elasticsearch.xpack.inference.services.ServiceUtils.convertToUri; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createOptionalUri; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalString; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString; import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractSimilarity; @@ -99,8 +100,13 @@ private static CommonFields fromMap( String url = extractOptionalString(map, URL, ModelConfigurations.SERVICE_SETTINGS, validationException); String organizationId = extractOptionalString(map, ORGANIZATION, ModelConfigurations.SERVICE_SETTINGS, validationException); SimilarityMeasure similarity = extractSimilarity(map, ModelConfigurations.SERVICE_SETTINGS, validationException); - Integer maxInputTokens = removeAsType(map, MAX_INPUT_TOKENS, Integer.class); - Integer dims = removeAsType(map, DIMENSIONS, Integer.class); + Integer maxInputTokens = extractOptionalPositiveInteger( + map, + MAX_INPUT_TOKENS, + ModelConfigurations.SERVICE_SETTINGS, + validationException + ); + Integer dims = extractOptionalPositiveInteger(map, DIMENSIONS, ModelConfigurations.SERVICE_SETTINGS, validationException); URI uri = convertToUri(url, URL, ModelConfigurations.SERVICE_SETTINGS, validationException); String modelId = extractRequiredString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException); RateLimitSettings rateLimitSettings = RateLimitSettings.of( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/settings/InternalServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/settings/InternalServiceSettings.java index 00bb48ae2302a..2cbe2f930c84d 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/settings/InternalServiceSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/settings/InternalServiceSettings.java @@ -7,10 +7,12 @@ package org.elasticsearch.xpack.inference.services.settings; +import org.elasticsearch.TransportVersions; import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.inference.ServiceSettings; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings; import java.io.IOException; import java.util.Objects; @@ -20,15 +22,23 @@ public abstract class InternalServiceSettings implements ServiceSettings { public static final String NUM_ALLOCATIONS = "num_allocations"; public static final String NUM_THREADS = "num_threads"; public static final String MODEL_ID = "model_id"; + public static final String ADAPTIVE_ALLOCATIONS = "adaptive_allocations"; private final int numAllocations; private final int numThreads; private final String modelId; - - public InternalServiceSettings(int numAllocations, int numThreads, String modelId) { + private final AdaptiveAllocationsSettings adaptiveAllocationsSettings; + + public InternalServiceSettings( + int numAllocations, + int numThreads, + String modelId, + AdaptiveAllocationsSettings adaptiveAllocationsSettings + ) { this.numAllocations = numAllocations; this.numThreads = numThreads; this.modelId = modelId; + this.adaptiveAllocationsSettings = adaptiveAllocationsSettings; } public int getNumAllocations() { @@ -43,16 +53,23 @@ public String getModelId() { return modelId; } + public AdaptiveAllocationsSettings getAdaptiveAllocationsSettings() { + return adaptiveAllocationsSettings; + } + public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; InternalServiceSettings that = (InternalServiceSettings) o; - return numAllocations == that.numAllocations && numThreads == that.numThreads && Objects.equals(modelId, that.modelId); + return numAllocations == that.numAllocations + && numThreads == that.numThreads + && Objects.equals(modelId, that.modelId) + && Objects.equals(adaptiveAllocationsSettings, that.adaptiveAllocationsSettings); } @Override public int hashCode() { - return Objects.hash(numAllocations, numThreads, modelId); + return Objects.hash(numAllocations, numThreads, modelId, adaptiveAllocationsSettings); } @Override @@ -67,6 +84,7 @@ public void addXContentFragment(XContentBuilder builder, Params params) throws I builder.field(NUM_ALLOCATIONS, getNumAllocations()); builder.field(NUM_THREADS, getNumThreads()); builder.field(MODEL_ID, getModelId()); + builder.field(ADAPTIVE_ALLOCATIONS, getAdaptiveAllocationsSettings()); } @Override @@ -84,12 +102,16 @@ public void writeTo(StreamOutput out) throws IOException { out.writeVInt(getNumAllocations()); out.writeVInt(getNumThreads()); out.writeString(getModelId()); + if (out.getTransportVersion().onOrAfter(TransportVersions.INFERENCE_ADAPTIVE_ALLOCATIONS)) { + out.writeOptionalWriteable(getAdaptiveAllocationsSettings()); + } } public abstract static class Builder { private int numAllocations; private int numThreads; private String modelId; + private AdaptiveAllocationsSettings adaptiveAllocationsSettings; public abstract InternalServiceSettings build(); @@ -105,6 +127,10 @@ public void setModelId(String modelId) { this.modelId = modelId; } + public void setAdaptiveAllocationsSettings(AdaptiveAllocationsSettings adaptiveAllocationsSettings) { + this.adaptiveAllocationsSettings = adaptiveAllocationsSettings; + } + public String getModelId() { return modelId; } @@ -116,5 +142,9 @@ public int getNumAllocations() { public int getNumThreads() { return numThreads; } + + public AdaptiveAllocationsSettings getAdaptiveAllocationsSettings() { + return adaptiveAllocationsSettings; + } } } diff --git a/x-pack/plugin/inference/src/main/plugin-metadata/plugin-security.policy b/x-pack/plugin/inference/src/main/plugin-metadata/plugin-security.policy index f21a46521a7f7..a39fcf53be7f3 100644 --- a/x-pack/plugin/inference/src/main/plugin-metadata/plugin-security.policy +++ b/x-pack/plugin/inference/src/main/plugin-metadata/plugin-security.policy @@ -8,12 +8,18 @@ grant { // required by: com.google.api.client.json.JsonParser#parseValue + // also required by AWS SDK for client configuration permission java.lang.RuntimePermission "accessDeclaredMembers"; + permission java.lang.RuntimePermission "getClassLoader"; + // required by: com.google.api.client.json.GenericJson# + // also by AWS SDK for Jackson's ObjectMapper permission java.lang.reflect.ReflectPermission "suppressAccessChecks"; + // required to add google certs to the gcs client trustore permission java.lang.RuntimePermission "setFactory"; // gcs client opens socket connections for to access repository - permission java.net.SocketPermission "*", "connect"; + // also, AWS Bedrock client opens socket connections and needs resolve for to access to resources + permission java.net.SocketPermission "*", "connect,resolve"; }; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/amazonbedrock/AmazonBedrockActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/amazonbedrock/AmazonBedrockActionCreatorTests.java new file mode 100644 index 0000000000000..87d3a82b4aae6 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/amazonbedrock/AmazonBedrockActionCreatorTests.java @@ -0,0 +1,175 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.action.amazonbedrock; + +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; +import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults; +import org.elasticsearch.xpack.inference.external.amazonbedrock.AmazonBedrockMockRequestSender; +import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; +import org.elasticsearch.xpack.inference.services.ServiceComponentsTests; +import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockProvider; +import org.elasticsearch.xpack.inference.services.amazonbedrock.completion.AmazonBedrockChatCompletionModelTests; +import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsModelTests; +import org.junit.After; +import org.junit.Before; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.concurrent.TimeUnit; + +import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; +import static org.elasticsearch.xpack.inference.results.ChatCompletionResultsTests.buildExpectationCompletion; +import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat; +import static org.hamcrest.Matchers.is; + +public class AmazonBedrockActionCreatorTests extends ESTestCase { + private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); + private ThreadPool threadPool; + + @Before + public void init() throws Exception { + threadPool = createThreadPool(inferenceUtilityPool()); + } + + @After + public void shutdown() throws IOException { + terminate(threadPool); + } + + public void testEmbeddingsRequestAction() throws IOException { + var serviceComponents = ServiceComponentsTests.createWithEmptySettings(threadPool); + var mockedFloatResults = List.of(new InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding(new float[] { 0.0123F, -0.0123F })); + var mockedResult = new InferenceTextEmbeddingFloatResults(mockedFloatResults); + try (var sender = new AmazonBedrockMockRequestSender()) { + sender.enqueue(mockedResult); + var creator = new AmazonBedrockActionCreator(sender, serviceComponents, TIMEOUT); + var model = AmazonBedrockEmbeddingsModelTests.createModel( + "test_id", + "test_region", + "test_model", + AmazonBedrockProvider.AMAZONTITAN, + null, + false, + null, + null, + null, + "accesskey", + "secretkey" + ); + var action = creator.create(model, Map.of()); + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + var result = listener.actionGet(TIMEOUT); + + assertThat(result.asMap(), is(buildExpectationFloat(List.of(new float[] { 0.0123F, -0.0123F })))); + + assertThat(sender.sendCount(), is(1)); + var sentInputs = sender.getInputs(); + assertThat(sentInputs.size(), is(1)); + assertThat(sentInputs.get(0), is("abc")); + } + } + + public void testEmbeddingsRequestAction_HandlesException() throws IOException { + var serviceComponents = ServiceComponentsTests.createWithEmptySettings(threadPool); + var mockedResult = new ElasticsearchException("mock exception"); + try (var sender = new AmazonBedrockMockRequestSender()) { + sender.enqueue(mockedResult); + var creator = new AmazonBedrockActionCreator(sender, serviceComponents, TIMEOUT); + var model = AmazonBedrockEmbeddingsModelTests.createModel( + "test_id", + "test_region", + "test_model", + AmazonBedrockProvider.AMAZONTITAN, + "accesskey", + "secretkey" + ); + var action = creator.create(model, Map.of()); + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); + + assertThat(sender.sendCount(), is(1)); + assertThat(sender.getInputs().size(), is(1)); + assertThat(thrownException.getMessage(), is("mock exception")); + } + } + + public void testCompletionRequestAction() throws IOException { + var serviceComponents = ServiceComponentsTests.createWithEmptySettings(threadPool); + var mockedChatCompletionResults = List.of(new ChatCompletionResults.Result("test input string")); + var mockedResult = new ChatCompletionResults(mockedChatCompletionResults); + try (var sender = new AmazonBedrockMockRequestSender()) { + sender.enqueue(mockedResult); + var creator = new AmazonBedrockActionCreator(sender, serviceComponents, TIMEOUT); + var model = AmazonBedrockChatCompletionModelTests.createModel( + "test_id", + "test_region", + "test_model", + AmazonBedrockProvider.AMAZONTITAN, + null, + null, + null, + null, + null, + "accesskey", + "secretkey" + ); + var action = creator.create(model, Map.of()); + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + var result = listener.actionGet(TIMEOUT); + + assertThat(result.asMap(), is(buildExpectationCompletion(List.of("test input string")))); + + assertThat(sender.sendCount(), is(1)); + var sentInputs = sender.getInputs(); + assertThat(sentInputs.size(), is(1)); + assertThat(sentInputs.get(0), is("abc")); + } + } + + public void testChatCompletionRequestAction_HandlesException() throws IOException { + var serviceComponents = ServiceComponentsTests.createWithEmptySettings(threadPool); + var mockedResult = new ElasticsearchException("mock exception"); + try (var sender = new AmazonBedrockMockRequestSender()) { + sender.enqueue(mockedResult); + var creator = new AmazonBedrockActionCreator(sender, serviceComponents, TIMEOUT); + var model = AmazonBedrockChatCompletionModelTests.createModel( + "test_id", + "test_region", + "test_model", + AmazonBedrockProvider.AMAZONTITAN, + null, + null, + null, + null, + null, + "accesskey", + "secretkey" + ); + var action = creator.create(model, Map.of()); + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); + + assertThat(sender.sendCount(), is(1)); + assertThat(sender.getInputs().size(), is(1)); + assertThat(thrownException.getMessage(), is("mock exception")); + } + } + +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockExecutorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockExecutorTests.java new file mode 100644 index 0000000000000..9326d39cb657c --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockExecutorTests.java @@ -0,0 +1,172 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.amazonbedrock; + +import com.amazonaws.services.bedrockruntime.model.ContentBlock; +import com.amazonaws.services.bedrockruntime.model.ConverseOutput; +import com.amazonaws.services.bedrockruntime.model.ConverseResult; +import com.amazonaws.services.bedrockruntime.model.InvokeModelResult; +import com.amazonaws.services.bedrockruntime.model.Message; + +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockChatCompletionRequest; +import org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockTitanCompletionRequestEntity; +import org.elasticsearch.xpack.inference.external.request.amazonbedrock.embeddings.AmazonBedrockEmbeddingsRequest; +import org.elasticsearch.xpack.inference.external.request.amazonbedrock.embeddings.AmazonBedrockTitanEmbeddingsRequestEntity; +import org.elasticsearch.xpack.inference.external.response.amazonbedrock.completion.AmazonBedrockChatCompletionResponseHandler; +import org.elasticsearch.xpack.inference.external.response.amazonbedrock.embeddings.AmazonBedrockEmbeddingsResponseHandler; +import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockProvider; +import org.elasticsearch.xpack.inference.services.amazonbedrock.completion.AmazonBedrockChatCompletionModelTests; +import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsModelTests; + +import java.nio.CharBuffer; +import java.nio.charset.CharacterCodingException; +import java.nio.charset.Charset; +import java.util.List; + +import static org.elasticsearch.xpack.inference.common.TruncatorTests.createTruncator; +import static org.elasticsearch.xpack.inference.results.ChatCompletionResultsTests.buildExpectationCompletion; +import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.is; + +public class AmazonBedrockExecutorTests extends ESTestCase { + public void testExecute_EmbeddingsRequest_ForAmazonTitan() throws CharacterCodingException { + var model = AmazonBedrockEmbeddingsModelTests.createModel( + "id", + "region", + "model", + AmazonBedrockProvider.AMAZONTITAN, + "accesskey", + "secretkey" + ); + var truncator = createTruncator(); + var truncatedInput = truncator.truncate(List.of("abc")); + var requestEntity = new AmazonBedrockTitanEmbeddingsRequestEntity("abc"); + var request = new AmazonBedrockEmbeddingsRequest(truncator, truncatedInput, model, requestEntity, null); + var responseHandler = new AmazonBedrockEmbeddingsResponseHandler(); + + var clientCache = new AmazonBedrockMockClientCache(null, getTestInvokeResult(TEST_AMAZON_TITAN_EMBEDDINGS_RESULT), null); + var listener = new PlainActionFuture(); + + var executor = new AmazonBedrockEmbeddingsExecutor(request, responseHandler, logger, () -> false, listener, clientCache); + executor.run(); + var result = listener.actionGet(new TimeValue(30000)); + assertNotNull(result); + assertThat(result.asMap(), is(buildExpectationFloat(List.of(new float[] { 0.123F, 0.456F, 0.678F, 0.789F })))); + } + + public void testExecute_EmbeddingsRequest_ForCohere() throws CharacterCodingException { + var model = AmazonBedrockEmbeddingsModelTests.createModel( + "id", + "region", + "model", + AmazonBedrockProvider.COHERE, + "accesskey", + "secretkey" + ); + var requestEntity = new AmazonBedrockTitanEmbeddingsRequestEntity("abc"); + var truncator = createTruncator(); + var truncatedInput = truncator.truncate(List.of("abc")); + var request = new AmazonBedrockEmbeddingsRequest(truncator, truncatedInput, model, requestEntity, null); + var responseHandler = new AmazonBedrockEmbeddingsResponseHandler(); + + var clientCache = new AmazonBedrockMockClientCache(null, getTestInvokeResult(TEST_COHERE_EMBEDDINGS_RESULT), null); + var listener = new PlainActionFuture(); + + var executor = new AmazonBedrockEmbeddingsExecutor(request, responseHandler, logger, () -> false, listener, clientCache); + executor.run(); + var result = listener.actionGet(new TimeValue(30000)); + assertNotNull(result); + assertThat(result.asMap(), is(buildExpectationFloat(List.of(new float[] { 0.123F, 0.456F, 0.678F, 0.789F })))); + } + + public void testExecute_ChatCompletionRequest() throws CharacterCodingException { + var model = AmazonBedrockChatCompletionModelTests.createModel( + "id", + "region", + "model", + AmazonBedrockProvider.AMAZONTITAN, + "accesskey", + "secretkey" + ); + + var requestEntity = new AmazonBedrockTitanCompletionRequestEntity(List.of("abc"), null, null, 512); + var request = new AmazonBedrockChatCompletionRequest(model, requestEntity, null); + var responseHandler = new AmazonBedrockChatCompletionResponseHandler(); + + var clientCache = new AmazonBedrockMockClientCache(getTestConverseResult("converse result"), null, null); + var listener = new PlainActionFuture(); + + var executor = new AmazonBedrockChatCompletionExecutor(request, responseHandler, logger, () -> false, listener, clientCache); + executor.run(); + var result = listener.actionGet(new TimeValue(30000)); + assertNotNull(result); + assertThat(result.asMap(), is(buildExpectationCompletion(List.of("converse result")))); + } + + public void testExecute_FailsProperly_WithElasticsearchException() { + var model = AmazonBedrockChatCompletionModelTests.createModel( + "id", + "region", + "model", + AmazonBedrockProvider.AMAZONTITAN, + "accesskey", + "secretkey" + ); + + var requestEntity = new AmazonBedrockTitanCompletionRequestEntity(List.of("abc"), null, null, 512); + var request = new AmazonBedrockChatCompletionRequest(model, requestEntity, null); + var responseHandler = new AmazonBedrockChatCompletionResponseHandler(); + + var clientCache = new AmazonBedrockMockClientCache(null, null, new ElasticsearchException("test exception")); + var listener = new PlainActionFuture(); + + var executor = new AmazonBedrockChatCompletionExecutor(request, responseHandler, logger, () -> false, listener, clientCache); + executor.run(); + + var exceptionThrown = assertThrows(ElasticsearchException.class, () -> listener.actionGet(new TimeValue(30000))); + assertThat(exceptionThrown.getMessage(), containsString("Failed to send request from inference entity id [id]")); + assertThat(exceptionThrown.getCause().getMessage(), containsString("test exception")); + } + + public static ConverseResult getTestConverseResult(String resultText) { + var message = new Message().withContent(new ContentBlock().withText(resultText)); + var converseOutput = new ConverseOutput().withMessage(message); + return new ConverseResult().withOutput(converseOutput); + } + + public static InvokeModelResult getTestInvokeResult(String resultJson) throws CharacterCodingException { + var result = new InvokeModelResult(); + result.setContentType("application/json"); + var encoder = Charset.forName("UTF-8").newEncoder(); + result.setBody(encoder.encode(CharBuffer.wrap(resultJson))); + return result; + } + + public static final String TEST_AMAZON_TITAN_EMBEDDINGS_RESULT = """ + { + "embedding": [0.123, 0.456, 0.678, 0.789], + "inputTextTokenCount": int + }"""; + + public static final String TEST_COHERE_EMBEDDINGS_RESULT = """ + { + "embeddings": [ + [0.123, 0.456, 0.678, 0.789] + ], + "id": string, + "response_type" : "embeddings_floats", + "texts": [string] + } + """; +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockInferenceClientCacheTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockInferenceClientCacheTests.java new file mode 100644 index 0000000000000..873b2e22497c6 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockInferenceClientCacheTests.java @@ -0,0 +1,108 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.amazonbedrock; + +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockProvider; +import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsModelTests; + +import java.io.IOException; +import java.time.Clock; +import java.time.Duration; +import java.time.Instant; +import java.time.ZoneId; + +import static org.elasticsearch.xpack.inference.external.amazonbedrock.AmazonBedrockInferenceClient.CLIENT_CACHE_EXPIRY_MINUTES; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.not; +import static org.hamcrest.Matchers.sameInstance; + +public class AmazonBedrockInferenceClientCacheTests extends ESTestCase { + public void testCache_ReturnsSameObject() throws IOException { + AmazonBedrockInferenceClientCache cacheInstance; + try (var cache = new AmazonBedrockInferenceClientCache(AmazonBedrockMockInferenceClient::create, null)) { + cacheInstance = cache; + var model = AmazonBedrockEmbeddingsModelTests.createModel( + "inferenceId", + "testregion", + "model", + AmazonBedrockProvider.AMAZONTITAN, + "access_key", + "secret_key" + ); + + var client = cache.getOrCreateClient(model, null); + + var secondModel = AmazonBedrockEmbeddingsModelTests.createModel( + "inferenceId_two", + "testregion", + "a_different_model", + AmazonBedrockProvider.COHERE, + "access_key", + "secret_key" + ); + + var secondClient = cache.getOrCreateClient(secondModel, null); + assertThat(client, sameInstance(secondClient)); + + assertThat(cache.clientCount(), is(1)); + + var thirdClient = cache.getOrCreateClient(model, null); + assertThat(client, sameInstance(thirdClient)); + + assertThat(cache.clientCount(), is(1)); + } + assertThat(cacheInstance.clientCount(), is(0)); + } + + public void testCache_ItEvictsExpiredClients() throws IOException { + var clock = Clock.fixed(Instant.now(), ZoneId.systemDefault()); + AmazonBedrockInferenceClientCache cacheInstance; + try (var cache = new AmazonBedrockInferenceClientCache(AmazonBedrockMockInferenceClient::create, clock)) { + cacheInstance = cache; + + var model = AmazonBedrockEmbeddingsModelTests.createModel( + "inferenceId", + "testregion", + "model", + AmazonBedrockProvider.AMAZONTITAN, + "access_key", + "secret_key" + ); + + var client = cache.getOrCreateClient(model, null); + + var secondModel = AmazonBedrockEmbeddingsModelTests.createModel( + "inferenceId_two", + "some_other_region", + "a_different_model", + AmazonBedrockProvider.COHERE, + "other_access_key", + "other_secret_key" + ); + + assertThat(cache.clientCount(), is(1)); + + var secondClient = cache.getOrCreateClient(secondModel, null); + assertThat(client, not(sameInstance(secondClient))); + + assertThat(cache.clientCount(), is(2)); + + // set clock to after expiry + cache.setClock(Clock.fixed(clock.instant().plus(Duration.ofMinutes(CLIENT_CACHE_EXPIRY_MINUTES + 1)), ZoneId.systemDefault())); + + // get another client, this will ensure flushExpiredClients is called + var regetSecondClient = cache.getOrCreateClient(secondModel, null); + assertThat(secondClient, sameInstance(regetSecondClient)); + + var regetFirstClient = cache.getOrCreateClient(model, null); + assertThat(client, not(sameInstance(regetFirstClient))); + } + assertThat(cacheInstance.clientCount(), is(0)); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockMockClientCache.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockMockClientCache.java new file mode 100644 index 0000000000000..912967a9012d7 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockMockClientCache.java @@ -0,0 +1,62 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.amazonbedrock; + +import com.amazonaws.services.bedrockruntime.model.ConverseResult; +import com.amazonaws.services.bedrockruntime.model.InvokeModelResult; + +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockModel; + +import java.io.IOException; + +public class AmazonBedrockMockClientCache implements AmazonBedrockClientCache { + private ConverseResult converseResult = null; + private InvokeModelResult invokeModelResult = null; + private ElasticsearchException exceptionToThrow = null; + + public AmazonBedrockMockClientCache() {} + + public AmazonBedrockMockClientCache( + @Nullable ConverseResult converseResult, + @Nullable InvokeModelResult invokeModelResult, + @Nullable ElasticsearchException exceptionToThrow + ) { + this.converseResult = converseResult; + this.invokeModelResult = invokeModelResult; + this.exceptionToThrow = exceptionToThrow; + } + + @Override + public AmazonBedrockBaseClient getOrCreateClient(AmazonBedrockModel model, TimeValue timeout) { + var client = (AmazonBedrockMockInferenceClient) AmazonBedrockMockInferenceClient.create(model, timeout); + client.setConverseResult(converseResult); + client.setInvokeModelResult(invokeModelResult); + client.setExceptionToThrow(exceptionToThrow); + return client; + } + + @Override + public void close() throws IOException { + // nothing to do + } + + public void setConverseResult(ConverseResult converseResult) { + this.converseResult = converseResult; + } + + public void setInvokeModelResult(InvokeModelResult invokeModelResult) { + this.invokeModelResult = invokeModelResult; + } + + public void setExceptionToThrow(ElasticsearchException exceptionToThrow) { + this.exceptionToThrow = exceptionToThrow; + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockMockExecuteRequestSender.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockMockExecuteRequestSender.java new file mode 100644 index 0000000000000..b0df8a40e2551 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockMockExecuteRequestSender.java @@ -0,0 +1,80 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.amazonbedrock; + +import com.amazonaws.services.bedrockruntime.model.ConverseResult; +import com.amazonaws.services.bedrockruntime.model.InvokeModelResult; + +import org.apache.logging.log4j.Logger; +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.xpack.inference.external.request.amazonbedrock.AmazonBedrockRequest; +import org.elasticsearch.xpack.inference.external.response.amazonbedrock.AmazonBedrockResponseHandler; +import org.elasticsearch.xpack.inference.logging.ThrottlerManager; + +import java.util.List; +import java.util.Queue; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.function.Supplier; + +public class AmazonBedrockMockExecuteRequestSender extends AmazonBedrockExecuteOnlyRequestSender { + + private Queue results = new ConcurrentLinkedQueue<>(); + private Queue> inputs = new ConcurrentLinkedQueue<>(); + private int sendCounter = 0; + + public AmazonBedrockMockExecuteRequestSender(AmazonBedrockClientCache clientCache, ThrottlerManager throttlerManager) { + super(clientCache, throttlerManager); + } + + public void enqueue(Object result) { + results.add(result); + } + + public int sendCount() { + return sendCounter; + } + + public List getInputs() { + return inputs.remove(); + } + + @Override + protected AmazonBedrockExecutor createExecutor( + AmazonBedrockRequest awsRequest, + AmazonBedrockResponseHandler awsResponse, + Logger logger, + Supplier hasRequestTimedOutFunction, + ActionListener listener + ) { + setCacheResult(); + return super.createExecutor(awsRequest, awsResponse, logger, hasRequestTimedOutFunction, listener); + } + + private void setCacheResult() { + var mockCache = (AmazonBedrockMockClientCache) this.clientCache; + var result = results.remove(); + if (result instanceof ConverseResult converseResult) { + mockCache.setConverseResult(converseResult); + return; + } + + if (result instanceof InvokeModelResult invokeModelResult) { + mockCache.setInvokeModelResult(invokeModelResult); + return; + } + + if (result instanceof ElasticsearchException exception) { + mockCache.setExceptionToThrow(exception); + return; + } + + throw new RuntimeException("Unknown result type: " + result.getClass()); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockMockInferenceClient.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockMockInferenceClient.java new file mode 100644 index 0000000000000..dcbf8dfcbff01 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockMockInferenceClient.java @@ -0,0 +1,133 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.amazonbedrock; + +import com.amazonaws.services.bedrockruntime.AmazonBedrockRuntimeAsync; +import com.amazonaws.services.bedrockruntime.model.ConverseResult; +import com.amazonaws.services.bedrockruntime.model.InvokeModelResult; + +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockModel; + +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; + +public class AmazonBedrockMockInferenceClient extends AmazonBedrockInferenceClient { + private ConverseResult converseResult = null; + private InvokeModelResult invokeModelResult = null; + private ElasticsearchException exceptionToThrow = null; + + private Future converseResultFuture = new MockConverseResultFuture(); + private Future invokeModelResultFuture = new MockInvokeResultFuture(); + + public static AmazonBedrockBaseClient create(AmazonBedrockModel model, @Nullable TimeValue timeout) { + return new AmazonBedrockMockInferenceClient(model, timeout); + } + + protected AmazonBedrockMockInferenceClient(AmazonBedrockModel model, @Nullable TimeValue timeout) { + super(model, timeout); + } + + public void setExceptionToThrow(ElasticsearchException exceptionToThrow) { + this.exceptionToThrow = exceptionToThrow; + } + + public void setConverseResult(ConverseResult result) { + this.converseResult = result; + } + + public void setInvokeModelResult(InvokeModelResult result) { + this.invokeModelResult = result; + } + + @Override + protected AmazonBedrockRuntimeAsync createAmazonBedrockClient(AmazonBedrockModel model, @Nullable TimeValue timeout) { + var runtimeClient = mock(AmazonBedrockRuntimeAsync.class); + doAnswer(invocation -> invokeModelResultFuture).when(runtimeClient).invokeModelAsync(any()); + doAnswer(invocation -> converseResultFuture).when(runtimeClient).converseAsync(any()); + + return runtimeClient; + } + + @Override + void close() {} + + private class MockConverseResultFuture implements Future { + @Override + public boolean cancel(boolean mayInterruptIfRunning) { + return false; + } + + @Override + public boolean isCancelled() { + return false; + } + + @Override + public boolean isDone() { + return false; + } + + @Override + public ConverseResult get() throws InterruptedException, ExecutionException { + if (exceptionToThrow != null) { + throw exceptionToThrow; + } + return converseResult; + } + + @Override + public ConverseResult get(long timeout, TimeUnit unit) throws InterruptedException, ExecutionException, TimeoutException { + if (exceptionToThrow != null) { + throw exceptionToThrow; + } + return converseResult; + } + } + + private class MockInvokeResultFuture implements Future { + @Override + public boolean cancel(boolean mayInterruptIfRunning) { + return false; + } + + @Override + public boolean isCancelled() { + return false; + } + + @Override + public boolean isDone() { + return false; + } + + @Override + public InvokeModelResult get() throws InterruptedException, ExecutionException { + if (exceptionToThrow != null) { + throw exceptionToThrow; + } + return invokeModelResult; + } + + @Override + public InvokeModelResult get(long timeout, TimeUnit unit) throws InterruptedException, ExecutionException, TimeoutException { + if (exceptionToThrow != null) { + throw exceptionToThrow; + } + return invokeModelResult; + } + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockMockRequestSender.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockMockRequestSender.java new file mode 100644 index 0000000000000..e68beaf4c1eb5 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockMockRequestSender.java @@ -0,0 +1,91 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.amazonbedrock; + +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; +import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; +import org.elasticsearch.xpack.inference.external.http.sender.RequestManager; +import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.services.ServiceComponents; + +import java.io.IOException; +import java.util.List; +import java.util.Queue; +import java.util.concurrent.ConcurrentLinkedQueue; + +public class AmazonBedrockMockRequestSender implements Sender { + + public static class Factory extends AmazonBedrockRequestSender.Factory { + private final Sender sender; + + public Factory(ServiceComponents serviceComponents, ClusterService clusterService) { + super(serviceComponents, clusterService); + this.sender = new AmazonBedrockMockRequestSender(); + } + + public Sender createSender() { + return sender; + } + } + + private Queue results = new ConcurrentLinkedQueue<>(); + private Queue> inputs = new ConcurrentLinkedQueue<>(); + private int sendCounter = 0; + + public void enqueue(Object result) { + results.add(result); + } + + public int sendCount() { + return sendCounter; + } + + public List getInputs() { + return inputs.remove(); + } + + @Override + public void start() { + // do nothing + } + + @Override + public void send( + RequestManager requestCreator, + InferenceInputs inferenceInputs, + TimeValue timeout, + ActionListener listener + ) { + sendCounter++; + var docsInput = (DocumentsOnlyInput) inferenceInputs; + inputs.add(docsInput.getInputs()); + + if (results.isEmpty()) { + listener.onFailure(new ElasticsearchException("No results found")); + } else { + var resultObject = results.remove(); + if (resultObject instanceof InferenceServiceResults inferenceResult) { + listener.onResponse(inferenceResult); + } else if (resultObject instanceof Exception e) { + listener.onFailure(e); + } else { + throw new RuntimeException("Unknown result type: " + resultObject.getClass()); + } + } + } + + @Override + public void close() throws IOException { + // do nothing + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockRequestSenderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockRequestSenderTests.java new file mode 100644 index 0000000000000..7fa8a09d5bf12 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockRequestSenderTests.java @@ -0,0 +1,127 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.amazonbedrock; + +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.inference.external.http.sender.AmazonBedrockChatCompletionRequestManager; +import org.elasticsearch.xpack.inference.external.http.sender.AmazonBedrockEmbeddingsRequestManager; +import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; +import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.logging.ThrottlerManager; +import org.elasticsearch.xpack.inference.services.ServiceComponentsTests; +import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockProvider; +import org.elasticsearch.xpack.inference.services.amazonbedrock.completion.AmazonBedrockChatCompletionModelTests; +import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsModelTests; +import org.junit.After; +import org.junit.Before; + +import java.io.IOException; +import java.util.List; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; + +import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; +import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; +import static org.elasticsearch.xpack.inference.external.amazonbedrock.AmazonBedrockExecutorTests.TEST_AMAZON_TITAN_EMBEDDINGS_RESULT; +import static org.elasticsearch.xpack.inference.results.ChatCompletionResultsTests.buildExpectationCompletion; +import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat; +import static org.hamcrest.Matchers.is; +import static org.mockito.Mockito.mock; + +public class AmazonBedrockRequestSenderTests extends ESTestCase { + private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); + private ThreadPool threadPool; + private final AtomicReference threadRef = new AtomicReference<>(); + + @Before + public void init() throws Exception { + threadPool = createThreadPool(inferenceUtilityPool()); + threadRef.set(null); + } + + @After + public void shutdown() throws IOException, InterruptedException { + if (threadRef.get() != null) { + threadRef.get().join(TIMEOUT.millis()); + } + + terminate(threadPool); + } + + public void testCreateSender_SendsEmbeddingsRequestAndReceivesResponse() throws Exception { + var senderFactory = createSenderFactory(threadPool, Settings.EMPTY); + var requestSender = new AmazonBedrockMockExecuteRequestSender(new AmazonBedrockMockClientCache(), mock(ThrottlerManager.class)); + requestSender.enqueue(AmazonBedrockExecutorTests.getTestInvokeResult(TEST_AMAZON_TITAN_EMBEDDINGS_RESULT)); + try (var sender = createSender(senderFactory, requestSender)) { + sender.start(); + + var model = AmazonBedrockEmbeddingsModelTests.createModel( + "test_id", + "test_region", + "test_model", + AmazonBedrockProvider.AMAZONTITAN, + "accesskey", + "secretkey" + ); + + PlainActionFuture listener = new PlainActionFuture<>(); + var serviceComponents = ServiceComponentsTests.createWithEmptySettings(threadPool); + var requestManager = new AmazonBedrockEmbeddingsRequestManager( + model, + serviceComponents.truncator(), + threadPool, + new TimeValue(30, TimeUnit.SECONDS) + ); + sender.send(requestManager, new DocumentsOnlyInput(List.of("abc")), null, listener); + + var result = listener.actionGet(TIMEOUT); + assertThat(result.asMap(), is(buildExpectationFloat(List.of(new float[] { 0.123F, 0.456F, 0.678F, 0.789F })))); + } + } + + public void testCreateSender_SendsCompletionRequestAndReceivesResponse() throws Exception { + var senderFactory = createSenderFactory(threadPool, Settings.EMPTY); + var requestSender = new AmazonBedrockMockExecuteRequestSender(new AmazonBedrockMockClientCache(), mock(ThrottlerManager.class)); + requestSender.enqueue(AmazonBedrockExecutorTests.getTestConverseResult("test response text")); + try (var sender = createSender(senderFactory, requestSender)) { + sender.start(); + + var model = AmazonBedrockChatCompletionModelTests.createModel( + "test_id", + "test_region", + "test_model", + AmazonBedrockProvider.AMAZONTITAN, + "accesskey", + "secretkey" + ); + + PlainActionFuture listener = new PlainActionFuture<>(); + var requestManager = new AmazonBedrockChatCompletionRequestManager(model, threadPool, new TimeValue(30, TimeUnit.SECONDS)); + sender.send(requestManager, new DocumentsOnlyInput(List.of("abc")), null, listener); + + var result = listener.actionGet(TIMEOUT); + assertThat(result.asMap(), is(buildExpectationCompletion(List.of("test response text")))); + } + } + + public static AmazonBedrockRequestSender.Factory createSenderFactory(ThreadPool threadPool, Settings settings) { + return new AmazonBedrockRequestSender.Factory( + ServiceComponentsTests.createWithSettings(threadPool, settings), + mockClusterServiceEmpty() + ); + } + + public static Sender createSender(AmazonBedrockRequestSender.Factory factory, AmazonBedrockExecuteOnlyRequestSender requestSender) { + return factory.createSender(requestSender); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/BaseRequestManagerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/BaseRequestManagerTests.java index 03838896b879d..bf120be621ad3 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/BaseRequestManagerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/BaseRequestManagerTests.java @@ -14,7 +14,6 @@ import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; -import java.util.List; import java.util.concurrent.TimeUnit; import java.util.function.Supplier; @@ -30,8 +29,7 @@ public void testRateLimitGrouping_DifferentObjectReferences_HaveSameGroup() { var manager1 = new BaseRequestManager(mock(ThreadPool.class), "id", val1, new RateLimitSettings(1)) { @Override public void execute( - String query, - List input, + InferenceInputs inferenceInputs, RequestSender requestSender, Supplier hasRequestCompletedFunction, ActionListener listener @@ -43,8 +41,7 @@ public void execute( var manager2 = new BaseRequestManager(mock(ThreadPool.class), "id", val2, new RateLimitSettings(1)) { @Override public void execute( - String query, - List input, + InferenceInputs inferenceInputs, RequestSender requestSender, Supplier hasRequestCompletedFunction, ActionListener listener @@ -62,8 +59,7 @@ public void testRateLimitGrouping_DifferentSettings_HaveDifferentGroup() { var manager1 = new BaseRequestManager(mock(ThreadPool.class), "id", val1, new RateLimitSettings(1)) { @Override public void execute( - String query, - List input, + InferenceInputs inferenceInputs, RequestSender requestSender, Supplier hasRequestCompletedFunction, ActionListener listener @@ -75,8 +71,7 @@ public void execute( var manager2 = new BaseRequestManager(mock(ThreadPool.class), "id", val1, new RateLimitSettings(2)) { @Override public void execute( - String query, - List input, + InferenceInputs inferenceInputs, RequestSender requestSender, Supplier hasRequestCompletedFunction, ActionListener listener @@ -94,8 +89,7 @@ public void testRateLimitGrouping_DifferentSettingsTimeUnit_HaveDifferentGroup() var manager1 = new BaseRequestManager(mock(ThreadPool.class), "id", val1, new RateLimitSettings(1, TimeUnit.MILLISECONDS)) { @Override public void execute( - String query, - List input, + InferenceInputs inferenceInputs, RequestSender requestSender, Supplier hasRequestCompletedFunction, ActionListener listener @@ -107,8 +101,7 @@ public void execute( var manager2 = new BaseRequestManager(mock(ThreadPool.class), "id", val1, new RateLimitSettings(1, TimeUnit.DAYS)) { @Override public void execute( - String query, - List input, + InferenceInputs inferenceInputs, RequestSender requestSender, Supplier hasRequestCompletedFunction, ActionListener listener diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java index 2b8b5f178b3de..79f6aa8164b75 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/HttpRequestSenderTests.java @@ -106,7 +106,7 @@ public void testCreateSender_SendsRequestAndReceivesResponse() throws Exception PlainActionFuture listener = new PlainActionFuture<>(); sender.send( - OpenAiEmbeddingsExecutableRequestCreatorTests.makeCreator(getUrl(webServer), null, "key", "model", null, threadPool), + OpenAiEmbeddingsRequestManagerTests.makeCreator(getUrl(webServer), null, "key", "model", null, threadPool), new DocumentsOnlyInput(List.of("abc")), null, listener diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiEmbeddingsExecutableRequestCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiEmbeddingsRequestManagerTests.java similarity index 95% rename from x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiEmbeddingsExecutableRequestCreatorTests.java rename to x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiEmbeddingsRequestManagerTests.java index 37fce8d3f3a7b..eb7f7c4a0035d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiEmbeddingsExecutableRequestCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiEmbeddingsRequestManagerTests.java @@ -13,7 +13,7 @@ import static org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsModelTests.createModel; -public class OpenAiEmbeddingsExecutableRequestCreatorTests { +public class OpenAiEmbeddingsRequestManagerTests { public static OpenAiEmbeddingsRequestManager makeCreator( String url, @Nullable String org, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorServiceTests.java index 9a45e10007643..762a3a74184a4 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestExecutorServiceTests.java @@ -131,7 +131,7 @@ public void testIsTerminated_AfterStopFromSeparateThread() { PlainActionFuture listener = new PlainActionFuture<>(); service.execute( - OpenAiEmbeddingsExecutableRequestCreatorTests.makeCreator("url", null, "key", "id", null, threadPool), + OpenAiEmbeddingsRequestManagerTests.makeCreator("url", null, "key", "id", null, threadPool), new DocumentsOnlyInput(List.of()), null, listener @@ -208,7 +208,7 @@ public void testTaskThrowsError_CallsOnFailure() { PlainActionFuture listener = new PlainActionFuture<>(); service.execute( - OpenAiEmbeddingsExecutableRequestCreatorTests.makeCreator("url", null, "key", "id", null, threadPool), + OpenAiEmbeddingsRequestManagerTests.makeCreator("url", null, "key", "id", null, threadPool), new DocumentsOnlyInput(List.of()), null, listener diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestManagerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestManagerTests.java index 291de740aca34..8b7c01ae133cf 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestManagerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestManagerTests.java @@ -17,7 +17,6 @@ import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyList; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; @@ -44,7 +43,7 @@ public static RequestManager createMock(RequestSender requestSender, String infe doAnswer(invocation -> { @SuppressWarnings("unchecked") - ActionListener listener = (ActionListener) invocation.getArguments()[4]; + ActionListener listener = (ActionListener) invocation.getArguments()[3]; requestSender.send( mock(Logger.class), RequestTests.mockRequest(inferenceEntityId), @@ -55,7 +54,7 @@ public static RequestManager createMock(RequestSender requestSender, String infe ); return Void.TYPE; - }).when(mockManager).execute(any(), anyList(), any(), any(), any()); + }).when(mockManager).execute(any(), any(), any(), any()); // just return something consistent so the hashing works when(mockManager.rateLimitGrouping()).thenReturn(inferenceEntityId); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestTaskTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestTaskTests.java index 13c395180cd16..c839c266e9320 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestTaskTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/RequestTaskTests.java @@ -59,7 +59,7 @@ public void testExecuting_DoesNotCallOnFailureForTimeout_AfterIllegalArgumentExc ActionListener listener = mock(ActionListener.class); var requestTask = new RequestTask( - OpenAiEmbeddingsExecutableRequestCreatorTests.makeCreator("url", null, "key", "model", null, "id", threadPool), + OpenAiEmbeddingsRequestManagerTests.makeCreator("url", null, "key", "model", null, "id", threadPool), new DocumentsOnlyInput(List.of("abc")), TimeValue.timeValueMillis(1), mockThreadPool, @@ -79,7 +79,7 @@ public void testRequest_ReturnsTimeoutException() { PlainActionFuture listener = new PlainActionFuture<>(); var requestTask = new RequestTask( - OpenAiEmbeddingsExecutableRequestCreatorTests.makeCreator("url", null, "key", "model", null, "id", threadPool), + OpenAiEmbeddingsRequestManagerTests.makeCreator("url", null, "key", "model", null, "id", threadPool), new DocumentsOnlyInput(List.of("abc")), TimeValue.timeValueMillis(1), threadPool, @@ -105,7 +105,7 @@ public void testRequest_DoesNotCallOnFailureTwiceWhenTimingOut() throws Exceptio }).when(listener).onFailure(any()); var requestTask = new RequestTask( - OpenAiEmbeddingsExecutableRequestCreatorTests.makeCreator("url", null, "key", "model", null, "id", threadPool), + OpenAiEmbeddingsRequestManagerTests.makeCreator("url", null, "key", "model", null, "id", threadPool), new DocumentsOnlyInput(List.of("abc")), TimeValue.timeValueMillis(1), threadPool, @@ -137,7 +137,7 @@ public void testRequest_DoesNotCallOnResponseAfterTimingOut() throws Exception { }).when(listener).onFailure(any()); var requestTask = new RequestTask( - OpenAiEmbeddingsExecutableRequestCreatorTests.makeCreator("url", null, "key", "model", null, "id", threadPool), + OpenAiEmbeddingsRequestManagerTests.makeCreator("url", null, "key", "model", null, "id", threadPool), new DocumentsOnlyInput(List.of("abc")), TimeValue.timeValueMillis(1), threadPool, @@ -167,7 +167,7 @@ public void testRequest_DoesNotCallOnFailureForTimeout_AfterAlreadyCallingOnResp ActionListener listener = mock(ActionListener.class); var requestTask = new RequestTask( - OpenAiEmbeddingsExecutableRequestCreatorTests.makeCreator("url", null, "key", "model", null, "id", threadPool), + OpenAiEmbeddingsRequestManagerTests.makeCreator("url", null, "key", "model", null, "id", threadPool), new DocumentsOnlyInput(List.of("abc")), TimeValue.timeValueMillis(1), mockThreadPool, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockAI21LabsCompletionRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockAI21LabsCompletionRequestEntityTests.java new file mode 100644 index 0000000000000..b91aab5410048 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockAI21LabsCompletionRequestEntityTests.java @@ -0,0 +1,70 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion; + +import org.elasticsearch.test.ESTestCase; + +import java.util.List; + +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHasMessage; +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveAnyMaxTokensInput; +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveAnyTemperatureInput; +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveAnyTopKInput; +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveAnyTopPInput; +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveMaxTokensInput; +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveTemperatureInput; +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveTopPInput; +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.getConverseRequest; +import static org.hamcrest.Matchers.is; + +public class AmazonBedrockAI21LabsCompletionRequestEntityTests extends ESTestCase { + public void testRequestEntity_CreatesProperRequest() { + var request = new AmazonBedrockAI21LabsCompletionRequestEntity(List.of("test message"), null, null, null); + var builtRequest = getConverseRequest("testmodel", request); + assertThat(builtRequest.getModelId(), is("testmodel")); + assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); + assertThat(builtRequest.getModelId(), is("testmodel")); + assertFalse(doesConverseRequestHaveAnyTemperatureInput(builtRequest)); + assertFalse(doesConverseRequestHaveAnyTopPInput(builtRequest)); + assertFalse(doesConverseRequestHaveAnyTopKInput(builtRequest)); + assertFalse(doesConverseRequestHaveAnyMaxTokensInput(builtRequest)); + } + + public void testRequestEntity_CreatesProperRequest_WithTemperature() { + var request = new AmazonBedrockAI21LabsCompletionRequestEntity(List.of("test message"), 1.0, null, null); + var builtRequest = getConverseRequest("testmodel", request); + assertThat(builtRequest.getModelId(), is("testmodel")); + assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); + assertTrue(doesConverseRequestHaveTemperatureInput(builtRequest, 1.0)); + assertFalse(doesConverseRequestHaveAnyTopPInput(builtRequest)); + assertFalse(doesConverseRequestHaveAnyTopKInput(builtRequest)); + assertFalse(doesConverseRequestHaveAnyMaxTokensInput(builtRequest)); + } + + public void testRequestEntity_CreatesProperRequest_WithTopP() { + var request = new AmazonBedrockAI21LabsCompletionRequestEntity(List.of("test message"), null, 1.0, null); + var builtRequest = getConverseRequest("testmodel", request); + assertThat(builtRequest.getModelId(), is("testmodel")); + assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); + assertFalse(doesConverseRequestHaveAnyTemperatureInput(builtRequest)); + assertTrue(doesConverseRequestHaveTopPInput(builtRequest, 1.0)); + assertFalse(doesConverseRequestHaveAnyTopKInput(builtRequest)); + assertFalse(doesConverseRequestHaveAnyMaxTokensInput(builtRequest)); + } + + public void testRequestEntity_CreatesProperRequest_WithMaxTokens() { + var request = new AmazonBedrockAI21LabsCompletionRequestEntity(List.of("test message"), null, null, 128); + var builtRequest = getConverseRequest("testmodel", request); + assertThat(builtRequest.getModelId(), is("testmodel")); + assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); + assertFalse(doesConverseRequestHaveAnyTemperatureInput(builtRequest)); + assertFalse(doesConverseRequestHaveAnyTopPInput(builtRequest)); + assertFalse(doesConverseRequestHaveAnyTopKInput(builtRequest)); + assertTrue(doesConverseRequestHaveMaxTokensInput(builtRequest, 128)); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockAnthropicCompletionRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockAnthropicCompletionRequestEntityTests.java new file mode 100644 index 0000000000000..89d5fec7efba6 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockAnthropicCompletionRequestEntityTests.java @@ -0,0 +1,82 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion; + +import org.elasticsearch.test.ESTestCase; + +import java.util.List; + +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHasMessage; +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveAnyMaxTokensInput; +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveAnyTemperatureInput; +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveAnyTopKInput; +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveAnyTopPInput; +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveMaxTokensInput; +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveTemperatureInput; +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveTopKInput; +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveTopPInput; +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.getConverseRequest; +import static org.hamcrest.Matchers.is; + +public class AmazonBedrockAnthropicCompletionRequestEntityTests extends ESTestCase { + public void testRequestEntity_CreatesProperRequest() { + var request = new AmazonBedrockAnthropicCompletionRequestEntity(List.of("test message"), null, null, null, null); + var builtRequest = getConverseRequest("testmodel", request); + assertThat(builtRequest.getModelId(), is("testmodel")); + assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); + assertThat(builtRequest.getModelId(), is("testmodel")); + assertFalse(doesConverseRequestHaveAnyTemperatureInput(builtRequest)); + assertFalse(doesConverseRequestHaveAnyTopPInput(builtRequest)); + assertFalse(doesConverseRequestHaveAnyTopKInput(builtRequest)); + assertFalse(doesConverseRequestHaveAnyMaxTokensInput(builtRequest)); + } + + public void testRequestEntity_CreatesProperRequest_WithTemperature() { + var request = new AmazonBedrockAnthropicCompletionRequestEntity(List.of("test message"), 1.0, null, null, null); + var builtRequest = getConverseRequest("testmodel", request); + assertThat(builtRequest.getModelId(), is("testmodel")); + assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); + assertTrue(doesConverseRequestHaveTemperatureInput(builtRequest, 1.0)); + assertFalse(doesConverseRequestHaveAnyTopPInput(builtRequest)); + assertFalse(doesConverseRequestHaveAnyTopKInput(builtRequest)); + assertFalse(doesConverseRequestHaveAnyMaxTokensInput(builtRequest)); + } + + public void testRequestEntity_CreatesProperRequest_WithTopP() { + var request = new AmazonBedrockAnthropicCompletionRequestEntity(List.of("test message"), null, 1.0, null, null); + var builtRequest = getConverseRequest("testmodel", request); + assertThat(builtRequest.getModelId(), is("testmodel")); + assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); + assertFalse(doesConverseRequestHaveAnyTemperatureInput(builtRequest)); + assertTrue(doesConverseRequestHaveTopPInput(builtRequest, 1.0)); + assertFalse(doesConverseRequestHaveAnyTopKInput(builtRequest)); + assertFalse(doesConverseRequestHaveAnyMaxTokensInput(builtRequest)); + } + + public void testRequestEntity_CreatesProperRequest_WithMaxTokens() { + var request = new AmazonBedrockAnthropicCompletionRequestEntity(List.of("test message"), null, null, null, 128); + var builtRequest = getConverseRequest("testmodel", request); + assertThat(builtRequest.getModelId(), is("testmodel")); + assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); + assertFalse(doesConverseRequestHaveAnyTemperatureInput(builtRequest)); + assertFalse(doesConverseRequestHaveAnyTopPInput(builtRequest)); + assertFalse(doesConverseRequestHaveAnyTopKInput(builtRequest)); + assertTrue(doesConverseRequestHaveMaxTokensInput(builtRequest, 128)); + } + + public void testRequestEntity_CreatesProperRequest_WithTopK() { + var request = new AmazonBedrockAnthropicCompletionRequestEntity(List.of("test message"), null, null, 1.0, null); + var builtRequest = getConverseRequest("testmodel", request); + assertThat(builtRequest.getModelId(), is("testmodel")); + assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); + assertFalse(doesConverseRequestHaveAnyTemperatureInput(builtRequest)); + assertFalse(doesConverseRequestHaveAnyTopPInput(builtRequest)); + assertTrue(doesConverseRequestHaveTopKInput(builtRequest, 1.0)); + assertFalse(doesConverseRequestHaveAnyMaxTokensInput(builtRequest)); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockCohereCompletionRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockCohereCompletionRequestEntityTests.java new file mode 100644 index 0000000000000..8df5c7f32e529 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockCohereCompletionRequestEntityTests.java @@ -0,0 +1,82 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion; + +import org.elasticsearch.test.ESTestCase; + +import java.util.List; + +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHasMessage; +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveAnyMaxTokensInput; +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveAnyTemperatureInput; +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveAnyTopKInput; +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveAnyTopPInput; +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveMaxTokensInput; +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveTemperatureInput; +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveTopKInput; +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveTopPInput; +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.getConverseRequest; +import static org.hamcrest.Matchers.is; + +public class AmazonBedrockCohereCompletionRequestEntityTests extends ESTestCase { + public void testRequestEntity_CreatesProperRequest() { + var request = new AmazonBedrockCohereCompletionRequestEntity(List.of("test message"), null, null, null, null); + var builtRequest = getConverseRequest("testmodel", request); + assertThat(builtRequest.getModelId(), is("testmodel")); + assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); + assertThat(builtRequest.getModelId(), is("testmodel")); + assertFalse(doesConverseRequestHaveAnyTemperatureInput(builtRequest)); + assertFalse(doesConverseRequestHaveAnyTopPInput(builtRequest)); + assertFalse(doesConverseRequestHaveAnyTopKInput(builtRequest)); + assertFalse(doesConverseRequestHaveAnyMaxTokensInput(builtRequest)); + } + + public void testRequestEntity_CreatesProperRequest_WithTemperature() { + var request = new AmazonBedrockCohereCompletionRequestEntity(List.of("test message"), 1.0, null, null, null); + var builtRequest = getConverseRequest("testmodel", request); + assertThat(builtRequest.getModelId(), is("testmodel")); + assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); + assertTrue(doesConverseRequestHaveTemperatureInput(builtRequest, 1.0)); + assertFalse(doesConverseRequestHaveAnyTopPInput(builtRequest)); + assertFalse(doesConverseRequestHaveAnyTopKInput(builtRequest)); + assertFalse(doesConverseRequestHaveAnyMaxTokensInput(builtRequest)); + } + + public void testRequestEntity_CreatesProperRequest_WithTopP() { + var request = new AmazonBedrockCohereCompletionRequestEntity(List.of("test message"), null, 1.0, null, null); + var builtRequest = getConverseRequest("testmodel", request); + assertThat(builtRequest.getModelId(), is("testmodel")); + assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); + assertFalse(doesConverseRequestHaveAnyTemperatureInput(builtRequest)); + assertTrue(doesConverseRequestHaveTopPInput(builtRequest, 1.0)); + assertFalse(doesConverseRequestHaveAnyTopKInput(builtRequest)); + assertFalse(doesConverseRequestHaveAnyMaxTokensInput(builtRequest)); + } + + public void testRequestEntity_CreatesProperRequest_WithMaxTokens() { + var request = new AmazonBedrockCohereCompletionRequestEntity(List.of("test message"), null, null, null, 128); + var builtRequest = getConverseRequest("testmodel", request); + assertThat(builtRequest.getModelId(), is("testmodel")); + assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); + assertFalse(doesConverseRequestHaveAnyTemperatureInput(builtRequest)); + assertFalse(doesConverseRequestHaveAnyTopPInput(builtRequest)); + assertFalse(doesConverseRequestHaveAnyTopKInput(builtRequest)); + assertTrue(doesConverseRequestHaveMaxTokensInput(builtRequest, 128)); + } + + public void testRequestEntity_CreatesProperRequest_WithTopK() { + var request = new AmazonBedrockCohereCompletionRequestEntity(List.of("test message"), null, null, 1.0, null); + var builtRequest = getConverseRequest("testmodel", request); + assertThat(builtRequest.getModelId(), is("testmodel")); + assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); + assertFalse(doesConverseRequestHaveAnyTemperatureInput(builtRequest)); + assertFalse(doesConverseRequestHaveAnyTopPInput(builtRequest)); + assertTrue(doesConverseRequestHaveTopKInput(builtRequest, 1.0)); + assertFalse(doesConverseRequestHaveAnyMaxTokensInput(builtRequest)); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockConverseRequestUtils.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockConverseRequestUtils.java new file mode 100644 index 0000000000000..cbbe3c5554967 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockConverseRequestUtils.java @@ -0,0 +1,94 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion; + +import com.amazonaws.services.bedrockruntime.model.ContentBlock; +import com.amazonaws.services.bedrockruntime.model.ConverseRequest; +import com.amazonaws.services.bedrockruntime.model.Message; + +import org.elasticsearch.core.Strings; + +public final class AmazonBedrockConverseRequestUtils { + public static ConverseRequest getConverseRequest(String modelId, AmazonBedrockConverseRequestEntity requestEntity) { + var converseRequest = new ConverseRequest().withModelId(modelId); + converseRequest = requestEntity.addMessages(converseRequest); + converseRequest = requestEntity.addInferenceConfig(converseRequest); + converseRequest = requestEntity.addAdditionalModelFields(converseRequest); + return converseRequest; + } + + public static boolean doesConverseRequestHasMessage(ConverseRequest converseRequest, String expectedMessage) { + for (Message message : converseRequest.getMessages()) { + var content = message.getContent(); + for (ContentBlock contentBlock : content) { + if (contentBlock.getText().equals(expectedMessage)) { + return true; + } + } + } + return false; + } + + public static boolean doesConverseRequestHaveAnyTemperatureInput(ConverseRequest converseRequest) { + return converseRequest.getInferenceConfig() != null + && converseRequest.getInferenceConfig().getTemperature() != null + && (converseRequest.getInferenceConfig().getTemperature().isNaN() == false); + } + + public static boolean doesConverseRequestHaveAnyTopPInput(ConverseRequest converseRequest) { + return converseRequest.getInferenceConfig() != null + && converseRequest.getInferenceConfig().getTopP() != null + && (converseRequest.getInferenceConfig().getTopP().isNaN() == false); + } + + public static boolean doesConverseRequestHaveAnyMaxTokensInput(ConverseRequest converseRequest) { + return converseRequest.getInferenceConfig() != null && converseRequest.getInferenceConfig().getMaxTokens() != null; + } + + public static boolean doesConverseRequestHaveTemperatureInput(ConverseRequest converseRequest, Double temperature) { + return doesConverseRequestHaveAnyTemperatureInput(converseRequest) + && converseRequest.getInferenceConfig().getTemperature().equals(temperature.floatValue()); + } + + public static boolean doesConverseRequestHaveTopPInput(ConverseRequest converseRequest, Double topP) { + return doesConverseRequestHaveAnyTopPInput(converseRequest) + && converseRequest.getInferenceConfig().getTopP().equals(topP.floatValue()); + } + + public static boolean doesConverseRequestHaveMaxTokensInput(ConverseRequest converseRequest, Integer maxTokens) { + return doesConverseRequestHaveAnyMaxTokensInput(converseRequest) + && converseRequest.getInferenceConfig().getMaxTokens().equals(maxTokens); + } + + public static boolean doesConverseRequestHaveAnyTopKInput(ConverseRequest converseRequest) { + if (converseRequest.getAdditionalModelResponseFieldPaths() == null) { + return false; + } + + for (String fieldPath : converseRequest.getAdditionalModelResponseFieldPaths()) { + if (fieldPath.contains("{\"top_k\":")) { + return true; + } + } + return false; + } + + public static boolean doesConverseRequestHaveTopKInput(ConverseRequest converseRequest, Double topK) { + if (doesConverseRequestHaveAnyTopKInput(converseRequest) == false) { + return false; + } + + var checkString = Strings.format("{\"top_k\":%f}", topK.floatValue()); + for (String fieldPath : converseRequest.getAdditionalModelResponseFieldPaths()) { + if (fieldPath.contains(checkString)) { + return true; + } + } + return false; + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockMetaCompletionRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockMetaCompletionRequestEntityTests.java new file mode 100644 index 0000000000000..fa482669a0bb2 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockMetaCompletionRequestEntityTests.java @@ -0,0 +1,70 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion; + +import org.elasticsearch.test.ESTestCase; + +import java.util.List; + +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHasMessage; +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveAnyMaxTokensInput; +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveAnyTemperatureInput; +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveAnyTopKInput; +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveAnyTopPInput; +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveMaxTokensInput; +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveTemperatureInput; +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveTopPInput; +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.getConverseRequest; +import static org.hamcrest.Matchers.is; + +public class AmazonBedrockMetaCompletionRequestEntityTests extends ESTestCase { + public void testRequestEntity_CreatesProperRequest() { + var request = new AmazonBedrockMetaCompletionRequestEntity(List.of("test message"), null, null, null); + var builtRequest = getConverseRequest("testmodel", request); + assertThat(builtRequest.getModelId(), is("testmodel")); + assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); + assertThat(builtRequest.getModelId(), is("testmodel")); + assertFalse(doesConverseRequestHaveAnyTemperatureInput(builtRequest)); + assertFalse(doesConverseRequestHaveAnyTopPInput(builtRequest)); + assertFalse(doesConverseRequestHaveAnyTopKInput(builtRequest)); + assertFalse(doesConverseRequestHaveAnyMaxTokensInput(builtRequest)); + } + + public void testRequestEntity_CreatesProperRequest_WithTemperature() { + var request = new AmazonBedrockMetaCompletionRequestEntity(List.of("test message"), 1.0, null, null); + var builtRequest = getConverseRequest("testmodel", request); + assertThat(builtRequest.getModelId(), is("testmodel")); + assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); + assertTrue(doesConverseRequestHaveTemperatureInput(builtRequest, 1.0)); + assertFalse(doesConverseRequestHaveAnyTopPInput(builtRequest)); + assertFalse(doesConverseRequestHaveAnyTopKInput(builtRequest)); + assertFalse(doesConverseRequestHaveAnyMaxTokensInput(builtRequest)); + } + + public void testRequestEntity_CreatesProperRequest_WithTopP() { + var request = new AmazonBedrockMetaCompletionRequestEntity(List.of("test message"), null, 1.0, null); + var builtRequest = getConverseRequest("testmodel", request); + assertThat(builtRequest.getModelId(), is("testmodel")); + assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); + assertFalse(doesConverseRequestHaveAnyTemperatureInput(builtRequest)); + assertTrue(doesConverseRequestHaveTopPInput(builtRequest, 1.0)); + assertFalse(doesConverseRequestHaveAnyTopKInput(builtRequest)); + assertFalse(doesConverseRequestHaveAnyMaxTokensInput(builtRequest)); + } + + public void testRequestEntity_CreatesProperRequest_WithMaxTokens() { + var request = new AmazonBedrockMetaCompletionRequestEntity(List.of("test message"), null, null, 128); + var builtRequest = getConverseRequest("testmodel", request); + assertThat(builtRequest.getModelId(), is("testmodel")); + assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); + assertFalse(doesConverseRequestHaveAnyTemperatureInput(builtRequest)); + assertFalse(doesConverseRequestHaveAnyTopPInput(builtRequest)); + assertFalse(doesConverseRequestHaveAnyTopKInput(builtRequest)); + assertTrue(doesConverseRequestHaveMaxTokensInput(builtRequest, 128)); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockMistralCompletionRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockMistralCompletionRequestEntityTests.java new file mode 100644 index 0000000000000..788625d3702b8 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockMistralCompletionRequestEntityTests.java @@ -0,0 +1,82 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion; + +import org.elasticsearch.test.ESTestCase; + +import java.util.List; + +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHasMessage; +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveAnyMaxTokensInput; +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveAnyTemperatureInput; +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveAnyTopKInput; +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveAnyTopPInput; +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveMaxTokensInput; +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveTemperatureInput; +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveTopKInput; +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveTopPInput; +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.getConverseRequest; +import static org.hamcrest.Matchers.is; + +public class AmazonBedrockMistralCompletionRequestEntityTests extends ESTestCase { + public void testRequestEntity_CreatesProperRequest() { + var request = new AmazonBedrockMistralCompletionRequestEntity(List.of("test message"), null, null, null, null); + var builtRequest = getConverseRequest("testmodel", request); + assertThat(builtRequest.getModelId(), is("testmodel")); + assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); + assertThat(builtRequest.getModelId(), is("testmodel")); + assertFalse(doesConverseRequestHaveAnyTemperatureInput(builtRequest)); + assertFalse(doesConverseRequestHaveAnyTopPInput(builtRequest)); + assertFalse(doesConverseRequestHaveAnyTopKInput(builtRequest)); + assertFalse(doesConverseRequestHaveAnyMaxTokensInput(builtRequest)); + } + + public void testRequestEntity_CreatesProperRequest_WithTemperature() { + var request = new AmazonBedrockMistralCompletionRequestEntity(List.of("test message"), 1.0, null, null, null); + var builtRequest = getConverseRequest("testmodel", request); + assertThat(builtRequest.getModelId(), is("testmodel")); + assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); + assertTrue(doesConverseRequestHaveTemperatureInput(builtRequest, 1.0)); + assertFalse(doesConverseRequestHaveAnyTopPInput(builtRequest)); + assertFalse(doesConverseRequestHaveAnyTopKInput(builtRequest)); + assertFalse(doesConverseRequestHaveAnyMaxTokensInput(builtRequest)); + } + + public void testRequestEntity_CreatesProperRequest_WithTopP() { + var request = new AmazonBedrockMistralCompletionRequestEntity(List.of("test message"), null, 1.0, null, null); + var builtRequest = getConverseRequest("testmodel", request); + assertThat(builtRequest.getModelId(), is("testmodel")); + assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); + assertFalse(doesConverseRequestHaveAnyTemperatureInput(builtRequest)); + assertTrue(doesConverseRequestHaveTopPInput(builtRequest, 1.0)); + assertFalse(doesConverseRequestHaveAnyTopKInput(builtRequest)); + assertFalse(doesConverseRequestHaveAnyMaxTokensInput(builtRequest)); + } + + public void testRequestEntity_CreatesProperRequest_WithMaxTokens() { + var request = new AmazonBedrockMistralCompletionRequestEntity(List.of("test message"), null, null, null, 128); + var builtRequest = getConverseRequest("testmodel", request); + assertThat(builtRequest.getModelId(), is("testmodel")); + assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); + assertFalse(doesConverseRequestHaveAnyTemperatureInput(builtRequest)); + assertFalse(doesConverseRequestHaveAnyTopPInput(builtRequest)); + assertFalse(doesConverseRequestHaveAnyTopKInput(builtRequest)); + assertTrue(doesConverseRequestHaveMaxTokensInput(builtRequest, 128)); + } + + public void testRequestEntity_CreatesProperRequest_WithTopK() { + var request = new AmazonBedrockMistralCompletionRequestEntity(List.of("test message"), null, null, 1.0, null); + var builtRequest = getConverseRequest("testmodel", request); + assertThat(builtRequest.getModelId(), is("testmodel")); + assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); + assertFalse(doesConverseRequestHaveAnyTemperatureInput(builtRequest)); + assertFalse(doesConverseRequestHaveAnyTopPInput(builtRequest)); + assertTrue(doesConverseRequestHaveTopKInput(builtRequest, 1.0)); + assertFalse(doesConverseRequestHaveAnyMaxTokensInput(builtRequest)); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockTitanCompletionRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockTitanCompletionRequestEntityTests.java new file mode 100644 index 0000000000000..79fa387876c8b --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/completion/AmazonBedrockTitanCompletionRequestEntityTests.java @@ -0,0 +1,70 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion; + +import org.elasticsearch.test.ESTestCase; + +import java.util.List; + +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHasMessage; +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveAnyMaxTokensInput; +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveAnyTemperatureInput; +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveAnyTopKInput; +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveAnyTopPInput; +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveMaxTokensInput; +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveTemperatureInput; +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.doesConverseRequestHaveTopPInput; +import static org.elasticsearch.xpack.inference.external.request.amazonbedrock.completion.AmazonBedrockConverseRequestUtils.getConverseRequest; +import static org.hamcrest.Matchers.is; + +public class AmazonBedrockTitanCompletionRequestEntityTests extends ESTestCase { + public void testRequestEntity_CreatesProperRequest() { + var request = new AmazonBedrockTitanCompletionRequestEntity(List.of("test message"), null, null, null); + var builtRequest = getConverseRequest("testmodel", request); + assertThat(builtRequest.getModelId(), is("testmodel")); + assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); + assertThat(builtRequest.getModelId(), is("testmodel")); + assertFalse(doesConverseRequestHaveAnyTemperatureInput(builtRequest)); + assertFalse(doesConverseRequestHaveAnyTopPInput(builtRequest)); + assertFalse(doesConverseRequestHaveAnyTopKInput(builtRequest)); + assertFalse(doesConverseRequestHaveAnyMaxTokensInput(builtRequest)); + } + + public void testRequestEntity_CreatesProperRequest_WithTemperature() { + var request = new AmazonBedrockTitanCompletionRequestEntity(List.of("test message"), 1.0, null, null); + var builtRequest = getConverseRequest("testmodel", request); + assertThat(builtRequest.getModelId(), is("testmodel")); + assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); + assertTrue(doesConverseRequestHaveTemperatureInput(builtRequest, 1.0)); + assertFalse(doesConverseRequestHaveAnyTopPInput(builtRequest)); + assertFalse(doesConverseRequestHaveAnyTopKInput(builtRequest)); + assertFalse(doesConverseRequestHaveAnyMaxTokensInput(builtRequest)); + } + + public void testRequestEntity_CreatesProperRequest_WithTopP() { + var request = new AmazonBedrockTitanCompletionRequestEntity(List.of("test message"), null, 1.0, null); + var builtRequest = getConverseRequest("testmodel", request); + assertThat(builtRequest.getModelId(), is("testmodel")); + assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); + assertFalse(doesConverseRequestHaveAnyTemperatureInput(builtRequest)); + assertTrue(doesConverseRequestHaveTopPInput(builtRequest, 1.0)); + assertFalse(doesConverseRequestHaveAnyTopKInput(builtRequest)); + assertFalse(doesConverseRequestHaveAnyMaxTokensInput(builtRequest)); + } + + public void testRequestEntity_CreatesProperRequest_WithMaxTokens() { + var request = new AmazonBedrockTitanCompletionRequestEntity(List.of("test message"), null, null, 128); + var builtRequest = getConverseRequest("testmodel", request); + assertThat(builtRequest.getModelId(), is("testmodel")); + assertThat(doesConverseRequestHasMessage(builtRequest, "test message"), is(true)); + assertFalse(doesConverseRequestHaveAnyTemperatureInput(builtRequest)); + assertFalse(doesConverseRequestHaveAnyTopPInput(builtRequest)); + assertFalse(doesConverseRequestHaveAnyTopKInput(builtRequest)); + assertTrue(doesConverseRequestHaveMaxTokensInput(builtRequest, 128)); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/embeddings/AmazonBedrockCohereEmbeddingsRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/embeddings/AmazonBedrockCohereEmbeddingsRequestEntityTests.java new file mode 100644 index 0000000000000..fd8114f889d6a --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/embeddings/AmazonBedrockCohereEmbeddingsRequestEntityTests.java @@ -0,0 +1,25 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.amazonbedrock.embeddings; + +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.external.request.amazonbedrock.AmazonBedrockJsonBuilder; + +import java.io.IOException; +import java.util.List; + +import static org.hamcrest.Matchers.is; + +public class AmazonBedrockCohereEmbeddingsRequestEntityTests extends ESTestCase { + public void testRequestEntity_GeneratesExpectedJsonBody() throws IOException { + var entity = new AmazonBedrockCohereEmbeddingsRequestEntity(List.of("test input")); + var builder = new AmazonBedrockJsonBuilder(entity); + var result = builder.getStringContent(); + assertThat(result, is("{\"texts\":[\"test input\"],\"input_type\":\"search_document\"}")); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/embeddings/AmazonBedrockTitanEmbeddingsRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/embeddings/AmazonBedrockTitanEmbeddingsRequestEntityTests.java new file mode 100644 index 0000000000000..da98fa251fdc8 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/amazonbedrock/embeddings/AmazonBedrockTitanEmbeddingsRequestEntityTests.java @@ -0,0 +1,24 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.amazonbedrock.embeddings; + +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.external.request.amazonbedrock.AmazonBedrockJsonBuilder; + +import java.io.IOException; + +import static org.hamcrest.Matchers.is; + +public class AmazonBedrockTitanEmbeddingsRequestEntityTests extends ESTestCase { + public void testRequestEntity_GeneratesExpectedJsonBody() throws IOException { + var entity = new AmazonBedrockTitanEmbeddingsRequestEntity("test input"); + var builder = new AmazonBedrockJsonBuilder(entity); + var result = builder.getStringContent(); + assertThat(result, is("{\"inputText\":\"test input\"}")); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankTests.java index 8cb9305edd057..7fbfe70dbcfe7 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rank/textsimilarity/TextSimilarityRankTests.java @@ -151,7 +151,6 @@ public void testRerankInferenceFailure() { ); } - @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/110398") public void testRerankInferenceResultMismatch() { ElasticsearchAssertions.assertFailures( // Execute search with text similarity reranking diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockSecretSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockSecretSettingsTests.java new file mode 100644 index 0000000000000..904851842a6c8 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockSecretSettingsTests.java @@ -0,0 +1,120 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.amazonbedrock; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; +import org.hamcrest.CoreMatchers; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.ACCESS_KEY_FIELD; +import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.SECRET_KEY_FIELD; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.is; + +public class AmazonBedrockSecretSettingsTests extends AbstractBWCWireSerializationTestCase { + + public void testIt_CreatesSettings_ReturnsNullFromMap_null() { + var secrets = AmazonBedrockSecretSettings.fromMap(null); + assertNull(secrets); + } + + public void testIt_CreatesSettings_FromMap_WithValues() { + var secrets = AmazonBedrockSecretSettings.fromMap( + new HashMap<>(Map.of(ACCESS_KEY_FIELD, "accesstest", SECRET_KEY_FIELD, "secrettest")) + ); + assertThat( + secrets, + is(new AmazonBedrockSecretSettings(new SecureString("accesstest".toCharArray()), new SecureString("secrettest".toCharArray()))) + ); + } + + public void testIt_CreatesSettings_FromMap_IgnoresExtraKeys() { + var secrets = AmazonBedrockSecretSettings.fromMap( + new HashMap<>(Map.of(ACCESS_KEY_FIELD, "accesstest", SECRET_KEY_FIELD, "secrettest", "extrakey", "extravalue")) + ); + assertThat( + secrets, + is(new AmazonBedrockSecretSettings(new SecureString("accesstest".toCharArray()), new SecureString("secrettest".toCharArray()))) + ); + } + + public void testIt_FromMap_ThrowsValidationException_AccessKeyMissing() { + var thrownException = expectThrows( + ValidationException.class, + () -> AmazonBedrockSecretSettings.fromMap(new HashMap<>(Map.of(SECRET_KEY_FIELD, "secrettest"))) + ); + + assertThat( + thrownException.getMessage(), + containsString(Strings.format("[secret_settings] does not contain the required setting [%s]", ACCESS_KEY_FIELD)) + ); + } + + public void testIt_FromMap_ThrowsValidationException_SecretKeyMissing() { + var thrownException = expectThrows( + ValidationException.class, + () -> AmazonBedrockSecretSettings.fromMap(new HashMap<>(Map.of(ACCESS_KEY_FIELD, "accesstest"))) + ); + + assertThat( + thrownException.getMessage(), + containsString(Strings.format("[secret_settings] does not contain the required setting [%s]", SECRET_KEY_FIELD)) + ); + } + + public void testToXContent_CreatesProperContent() throws IOException { + var secrets = AmazonBedrockSecretSettings.fromMap( + new HashMap<>(Map.of(ACCESS_KEY_FIELD, "accesstest", SECRET_KEY_FIELD, "secrettest")) + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + secrets.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + assertThat(xContentResult, CoreMatchers.is(""" + {"access_key":"accesstest","secret_key":"secrettest"}""")); + } + + public static Map getAmazonBedrockSecretSettingsMap(String accessKey, String secretKey) { + return new HashMap(Map.of(ACCESS_KEY_FIELD, accessKey, SECRET_KEY_FIELD, secretKey)); + } + + @Override + protected AmazonBedrockSecretSettings mutateInstanceForVersion(AmazonBedrockSecretSettings instance, TransportVersion version) { + return instance; + } + + @Override + protected Writeable.Reader instanceReader() { + return AmazonBedrockSecretSettings::new; + } + + @Override + protected AmazonBedrockSecretSettings createTestInstance() { + return createRandom(); + } + + @Override + protected AmazonBedrockSecretSettings mutateInstance(AmazonBedrockSecretSettings instance) throws IOException { + return randomValueOtherThan(instance, AmazonBedrockSecretSettingsTests::createRandom); + } + + private static AmazonBedrockSecretSettings createRandom() { + return new AmazonBedrockSecretSettings(new SecureString(randomAlphaOfLength(10)), new SecureString(randomAlphaOfLength(10))); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java new file mode 100644 index 0000000000000..ae413fc17425c --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java @@ -0,0 +1,1136 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.amazonbedrock; + +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkingOptions; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; +import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults; +import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults; +import org.elasticsearch.xpack.inference.Utils; +import org.elasticsearch.xpack.inference.external.amazonbedrock.AmazonBedrockMockRequestSender; +import org.elasticsearch.xpack.inference.external.amazonbedrock.AmazonBedrockRequestSender; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; +import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.services.ServiceComponentsTests; +import org.elasticsearch.xpack.inference.services.amazonbedrock.completion.AmazonBedrockChatCompletionModel; +import org.elasticsearch.xpack.inference.services.amazonbedrock.completion.AmazonBedrockChatCompletionModelTests; +import org.elasticsearch.xpack.inference.services.amazonbedrock.completion.AmazonBedrockChatCompletionServiceSettings; +import org.elasticsearch.xpack.inference.services.amazonbedrock.completion.AmazonBedrockChatCompletionTaskSettings; +import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsModelTests; +import org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsServiceSettings; +import org.hamcrest.CoreMatchers; +import org.hamcrest.MatcherAssert; +import org.hamcrest.Matchers; +import org.junit.After; +import org.junit.Before; + +import java.io.IOException; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.TimeUnit; + +import static org.elasticsearch.xpack.inference.Utils.getInvalidModel; +import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; +import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; +import static org.elasticsearch.xpack.inference.results.ChatCompletionResultsTests.buildExpectationCompletion; +import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat; +import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; +import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockSecretSettingsTests.getAmazonBedrockSecretSettingsMap; +import static org.elasticsearch.xpack.inference.services.amazonbedrock.completion.AmazonBedrockChatCompletionServiceSettingsTests.createChatCompletionRequestSettingsMap; +import static org.elasticsearch.xpack.inference.services.amazonbedrock.completion.AmazonBedrockChatCompletionTaskSettingsTests.getChatCompletionTaskSettingsMap; +import static org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsServiceSettingsTests.createEmbeddingsRequestSettingsMap; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.instanceOf; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; + +public class AmazonBedrockServiceTests extends ESTestCase { + private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); + private ThreadPool threadPool; + + @Before + public void init() throws Exception { + threadPool = createThreadPool(inferenceUtilityPool()); + } + + @After + public void shutdown() throws IOException { + terminate(threadPool); + } + + public void testParseRequestConfig_CreatesAnAmazonBedrockModel() throws IOException { + try (var service = createAmazonBedrockService()) { + ActionListener modelVerificationListener = ActionListener.wrap(model -> { + assertThat(model, instanceOf(AmazonBedrockEmbeddingsModel.class)); + + var settings = (AmazonBedrockEmbeddingsServiceSettings) model.getServiceSettings(); + assertThat(settings.region(), is("region")); + assertThat(settings.model(), is("model")); + assertThat(settings.provider(), is(AmazonBedrockProvider.AMAZONTITAN)); + var secretSettings = (AmazonBedrockSecretSettings) model.getSecretSettings(); + assertThat(secretSettings.accessKey.toString(), is("access")); + assertThat(secretSettings.secretKey.toString(), is("secret")); + }, exception -> fail("Unexpected exception: " + exception)); + + service.parseRequestConfig( + "id", + TaskType.TEXT_EMBEDDING, + getRequestConfigMap( + createEmbeddingsRequestSettingsMap("region", "model", "amazontitan", null, null, null, null), + Map.of(), + getAmazonBedrockSecretSettingsMap("access", "secret") + ), + Set.of(), + modelVerificationListener + ); + } + } + + public void testParseRequestConfig_ThrowsUnsupportedModelType() throws IOException { + try (var service = createAmazonBedrockService()) { + ActionListener modelVerificationListener = ActionListener.wrap( + model -> fail("Expected exception, but got model: " + model), + exception -> { + assertThat(exception, instanceOf(ElasticsearchStatusException.class)); + assertThat(exception.getMessage(), is("The [amazonbedrock] service does not support task type [sparse_embedding]")); + } + ); + + service.parseRequestConfig( + "id", + TaskType.SPARSE_EMBEDDING, + getRequestConfigMap( + createEmbeddingsRequestSettingsMap("region", "model", "amazontitan", null, false, null, null), + Map.of(), + getAmazonBedrockSecretSettingsMap("access", "secret") + ), + Set.of(), + modelVerificationListener + ); + } + } + + public void testCreateModel_ForEmbeddingsTask_InvalidProvider() throws IOException { + try (var service = createAmazonBedrockService()) { + ActionListener modelVerificationListener = ActionListener.wrap( + model -> fail("Expected exception, but got model: " + model), + exception -> { + assertThat(exception, instanceOf(ElasticsearchStatusException.class)); + assertThat(exception.getMessage(), is("The [text_embedding] task type for provider [anthropic] is not available")); + } + ); + + service.parseRequestConfig( + "id", + TaskType.TEXT_EMBEDDING, + getRequestConfigMap( + createEmbeddingsRequestSettingsMap("region", "model", "anthropic", null, null, null, null), + Map.of(), + getAmazonBedrockSecretSettingsMap("access", "secret") + ), + Set.of(), + modelVerificationListener + ); + } + } + + public void testCreateModel_TopKParameter_NotAvailable() throws IOException { + try (var service = createAmazonBedrockService()) { + ActionListener modelVerificationListener = ActionListener.wrap( + model -> fail("Expected exception, but got model: " + model), + exception -> { + assertThat(exception, instanceOf(ElasticsearchStatusException.class)); + assertThat(exception.getMessage(), is("The [top_k] task parameter is not available for provider [amazontitan]")); + } + ); + + service.parseRequestConfig( + "id", + TaskType.COMPLETION, + getRequestConfigMap( + createChatCompletionRequestSettingsMap("region", "model", "amazontitan"), + getChatCompletionTaskSettingsMap(1.0, 0.5, 0.2, 128), + getAmazonBedrockSecretSettingsMap("access", "secret") + ), + Set.of(), + modelVerificationListener + ); + } + } + + public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig() throws IOException { + try (var service = createAmazonBedrockService()) { + var config = getRequestConfigMap( + createEmbeddingsRequestSettingsMap("region", "model", "amazontitan", null, null, null, null), + Map.of(), + getAmazonBedrockSecretSettingsMap("access", "secret") + ); + + config.put("extra_key", "value"); + + ActionListener modelVerificationListener = ActionListener.wrap( + model -> fail("Expected exception, but got model: " + model), + exception -> { + assertThat(exception, instanceOf(ElasticsearchStatusException.class)); + assertThat( + exception.getMessage(), + is("Model configuration contains settings [{extra_key=value}] unknown to the [amazonbedrock] service") + ); + } + ); + + service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, Set.of(), modelVerificationListener); + } + } + + public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInServiceSettingsMap() throws IOException { + try (var service = createAmazonBedrockService()) { + var serviceSettings = createEmbeddingsRequestSettingsMap("region", "model", "amazontitan", null, null, null, null); + serviceSettings.put("extra_key", "value"); + + var config = getRequestConfigMap(serviceSettings, Map.of(), getAmazonBedrockSecretSettingsMap("access", "secret")); + + ActionListener modelVerificationListener = ActionListener.wrap((model) -> { + fail("Expected exception, but got model: " + model); + }, e -> { + assertThat(e, instanceOf(ElasticsearchStatusException.class)); + assertThat( + e.getMessage(), + is("Model configuration contains settings [{extra_key=value}] unknown to the [amazonbedrock] service") + ); + }); + + service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, Set.of(), modelVerificationListener); + } + } + + public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInTaskSettingsMap() throws IOException { + try (var service = createAmazonBedrockService()) { + var settingsMap = createChatCompletionRequestSettingsMap("region", "model", "anthropic"); + var taskSettingsMap = getChatCompletionTaskSettingsMap(1.0, 0.5, 0.2, 128); + var secretSettingsMap = getAmazonBedrockSecretSettingsMap("access", "secret"); + + taskSettingsMap.put("extra_key", "value"); + + var config = getRequestConfigMap(settingsMap, taskSettingsMap, secretSettingsMap); + + ActionListener modelVerificationListener = ActionListener.wrap((model) -> { + fail("Expected exception, but got model: " + model); + }, e -> { + assertThat(e, instanceOf(ElasticsearchStatusException.class)); + assertThat( + e.getMessage(), + is("Model configuration contains settings [{extra_key=value}] unknown to the [amazonbedrock] service") + ); + }); + + service.parseRequestConfig("id", TaskType.COMPLETION, config, Set.of(), modelVerificationListener); + } + } + + public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInSecretSettingsMap() throws IOException { + try (var service = createAmazonBedrockService()) { + var settingsMap = createChatCompletionRequestSettingsMap("region", "model", "anthropic"); + var taskSettingsMap = getChatCompletionTaskSettingsMap(1.0, 0.5, 0.2, 128); + var secretSettingsMap = getAmazonBedrockSecretSettingsMap("access", "secret"); + + secretSettingsMap.put("extra_key", "value"); + + var config = getRequestConfigMap(settingsMap, taskSettingsMap, secretSettingsMap); + + ActionListener modelVerificationListener = ActionListener.wrap((model) -> { + fail("Expected exception, but got model: " + model); + }, e -> { + assertThat(e, instanceOf(ElasticsearchStatusException.class)); + assertThat( + e.getMessage(), + is("Model configuration contains settings [{extra_key=value}] unknown to the [amazonbedrock] service") + ); + }); + + service.parseRequestConfig("id", TaskType.COMPLETION, config, Set.of(), modelVerificationListener); + } + } + + public void testParseRequestConfig_MovesModel() throws IOException { + try (var service = createAmazonBedrockService()) { + ActionListener modelVerificationListener = ActionListener.wrap(model -> { + assertThat(model, instanceOf(AmazonBedrockEmbeddingsModel.class)); + + var settings = (AmazonBedrockEmbeddingsServiceSettings) model.getServiceSettings(); + assertThat(settings.region(), is("region")); + assertThat(settings.model(), is("model")); + assertThat(settings.provider(), is(AmazonBedrockProvider.AMAZONTITAN)); + var secretSettings = (AmazonBedrockSecretSettings) model.getSecretSettings(); + assertThat(secretSettings.accessKey.toString(), is("access")); + assertThat(secretSettings.secretKey.toString(), is("secret")); + }, exception -> fail("Unexpected exception: " + exception)); + + service.parseRequestConfig( + "id", + TaskType.TEXT_EMBEDDING, + getRequestConfigMap( + createEmbeddingsRequestSettingsMap("region", "model", "amazontitan", null, null, null, null), + Map.of(), + getAmazonBedrockSecretSettingsMap("access", "secret") + ), + Set.of(), + modelVerificationListener + ); + } + } + + public void testCreateModel_ForEmbeddingsTask_DimensionsIsNotAllowed() throws IOException { + try (var service = createAmazonBedrockService()) { + ActionListener modelVerificationListener = ActionListener.wrap( + model -> fail("Expected exception, but got model: " + model), + exception -> { + assertThat(exception, instanceOf(ValidationException.class)); + assertThat(exception.getMessage(), containsString("[service_settings] does not allow the setting [dimensions]")); + } + ); + + service.parseRequestConfig( + "id", + TaskType.TEXT_EMBEDDING, + getRequestConfigMap( + createEmbeddingsRequestSettingsMap("region", "model", "amazontitan", 512, null, null, null), + Map.of(), + getAmazonBedrockSecretSettingsMap("access", "secret") + ), + Set.of(), + modelVerificationListener + ); + } + } + + public void testParsePersistedConfigWithSecrets_CreatesAnAmazonBedrockEmbeddingsModel() throws IOException { + try (var service = createAmazonBedrockService()) { + var settingsMap = createEmbeddingsRequestSettingsMap("region", "model", "amazontitan", null, false, null, null); + var secretSettingsMap = getAmazonBedrockSecretSettingsMap("access", "secret"); + + var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap(Map.of()), secretSettingsMap); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.TEXT_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ); + + assertThat(model, instanceOf(AmazonBedrockEmbeddingsModel.class)); + + var settings = (AmazonBedrockEmbeddingsServiceSettings) model.getServiceSettings(); + assertThat(settings.region(), is("region")); + assertThat(settings.model(), is("model")); + assertThat(settings.provider(), is(AmazonBedrockProvider.AMAZONTITAN)); + var secretSettings = (AmazonBedrockSecretSettings) model.getSecretSettings(); + assertThat(secretSettings.accessKey.toString(), is("access")); + assertThat(secretSettings.secretKey.toString(), is("secret")); + } + } + + public void testParsePersistedConfigWithSecrets_ThrowsErrorTryingToParseInvalidModel() throws IOException { + try (var service = createAmazonBedrockService()) { + var settingsMap = createChatCompletionRequestSettingsMap("region", "model", "amazontitan"); + var secretSettingsMap = getAmazonBedrockSecretSettingsMap("access", "secret"); + + var persistedConfig = getPersistedConfigMap(settingsMap, Map.of(), secretSettingsMap); + + var thrownException = expectThrows( + ElasticsearchStatusException.class, + () -> service.parsePersistedConfigWithSecrets( + "id", + TaskType.SPARSE_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ) + ); + + assertThat( + thrownException.getMessage(), + is("Failed to parse stored model [id] for [amazonbedrock] service, please delete and add the service again") + ); + } + } + + public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInConfig() throws IOException { + try (var service = createAmazonBedrockService()) { + var settingsMap = createEmbeddingsRequestSettingsMap("region", "model", "amazontitan", null, false, null, null); + var secretSettingsMap = getAmazonBedrockSecretSettingsMap("access", "secret"); + + var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap(Map.of()), secretSettingsMap); + persistedConfig.config().put("extra_key", "value"); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.TEXT_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ); + + assertThat(model, instanceOf(AmazonBedrockEmbeddingsModel.class)); + + var settings = (AmazonBedrockEmbeddingsServiceSettings) model.getServiceSettings(); + assertThat(settings.region(), is("region")); + assertThat(settings.model(), is("model")); + assertThat(settings.provider(), is(AmazonBedrockProvider.AMAZONTITAN)); + var secretSettings = (AmazonBedrockSecretSettings) model.getSecretSettings(); + assertThat(secretSettings.accessKey.toString(), is("access")); + assertThat(secretSettings.secretKey.toString(), is("secret")); + } + } + + public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInSecretsSettings() throws IOException { + try (var service = createAmazonBedrockService()) { + var settingsMap = createEmbeddingsRequestSettingsMap("region", "model", "amazontitan", null, false, null, null); + var secretSettingsMap = getAmazonBedrockSecretSettingsMap("access", "secret"); + secretSettingsMap.put("extra_key", "value"); + + var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap(Map.of()), secretSettingsMap); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.TEXT_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ); + + assertThat(model, instanceOf(AmazonBedrockEmbeddingsModel.class)); + + var settings = (AmazonBedrockEmbeddingsServiceSettings) model.getServiceSettings(); + assertThat(settings.region(), is("region")); + assertThat(settings.model(), is("model")); + assertThat(settings.provider(), is(AmazonBedrockProvider.AMAZONTITAN)); + var secretSettings = (AmazonBedrockSecretSettings) model.getSecretSettings(); + assertThat(secretSettings.accessKey.toString(), is("access")); + assertThat(secretSettings.secretKey.toString(), is("secret")); + } + } + + public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInSecrets() throws IOException { + try (var service = createAmazonBedrockService()) { + var settingsMap = createEmbeddingsRequestSettingsMap("region", "model", "amazontitan", null, false, null, null); + var secretSettingsMap = getAmazonBedrockSecretSettingsMap("access", "secret"); + + var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap(Map.of()), secretSettingsMap); + persistedConfig.secrets().put("extra_key", "value"); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.TEXT_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ); + + assertThat(model, instanceOf(AmazonBedrockEmbeddingsModel.class)); + + var settings = (AmazonBedrockEmbeddingsServiceSettings) model.getServiceSettings(); + assertThat(settings.region(), is("region")); + assertThat(settings.model(), is("model")); + assertThat(settings.provider(), is(AmazonBedrockProvider.AMAZONTITAN)); + var secretSettings = (AmazonBedrockSecretSettings) model.getSecretSettings(); + assertThat(secretSettings.accessKey.toString(), is("access")); + assertThat(secretSettings.secretKey.toString(), is("secret")); + } + } + + public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInServiceSettings() throws IOException { + try (var service = createAmazonBedrockService()) { + var settingsMap = createEmbeddingsRequestSettingsMap("region", "model", "amazontitan", null, false, null, null); + settingsMap.put("extra_key", "value"); + var secretSettingsMap = getAmazonBedrockSecretSettingsMap("access", "secret"); + + var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap(Map.of()), secretSettingsMap); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.TEXT_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ); + + assertThat(model, instanceOf(AmazonBedrockEmbeddingsModel.class)); + + var settings = (AmazonBedrockEmbeddingsServiceSettings) model.getServiceSettings(); + assertThat(settings.region(), is("region")); + assertThat(settings.model(), is("model")); + assertThat(settings.provider(), is(AmazonBedrockProvider.AMAZONTITAN)); + var secretSettings = (AmazonBedrockSecretSettings) model.getSecretSettings(); + assertThat(secretSettings.accessKey.toString(), is("access")); + assertThat(secretSettings.secretKey.toString(), is("secret")); + } + } + + public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInTaskSettings() throws IOException { + try (var service = createAmazonBedrockService()) { + var settingsMap = createChatCompletionRequestSettingsMap("region", "model", "anthropic"); + var taskSettingsMap = getChatCompletionTaskSettingsMap(1.0, 0.5, 0.2, 128); + var secretSettingsMap = getAmazonBedrockSecretSettingsMap("access", "secret"); + taskSettingsMap.put("extra_key", "value"); + + var persistedConfig = getPersistedConfigMap(settingsMap, taskSettingsMap, secretSettingsMap); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.COMPLETION, + persistedConfig.config(), + persistedConfig.secrets() + ); + + assertThat(model, instanceOf(AmazonBedrockChatCompletionModel.class)); + + var settings = (AmazonBedrockChatCompletionServiceSettings) model.getServiceSettings(); + assertThat(settings.region(), is("region")); + assertThat(settings.model(), is("model")); + assertThat(settings.provider(), is(AmazonBedrockProvider.ANTHROPIC)); + var taskSettings = (AmazonBedrockChatCompletionTaskSettings) model.getTaskSettings(); + assertThat(taskSettings.temperature(), is(1.0)); + assertThat(taskSettings.topP(), is(0.5)); + assertThat(taskSettings.topK(), is(0.2)); + assertThat(taskSettings.maxNewTokens(), is(128)); + var secretSettings = (AmazonBedrockSecretSettings) model.getSecretSettings(); + assertThat(secretSettings.accessKey.toString(), is("access")); + assertThat(secretSettings.secretKey.toString(), is("secret")); + } + } + + public void testParsePersistedConfig_CreatesAnAmazonBedrockEmbeddingsModel() throws IOException { + try (var service = createAmazonBedrockService()) { + var settingsMap = createEmbeddingsRequestSettingsMap("region", "model", "amazontitan", null, false, null, null); + var secretSettingsMap = getAmazonBedrockSecretSettingsMap("access", "secret"); + + var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap(Map.of()), secretSettingsMap); + + var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); + + assertThat(model, instanceOf(AmazonBedrockEmbeddingsModel.class)); + + var settings = (AmazonBedrockEmbeddingsServiceSettings) model.getServiceSettings(); + assertThat(settings.region(), is("region")); + assertThat(settings.model(), is("model")); + assertThat(settings.provider(), is(AmazonBedrockProvider.AMAZONTITAN)); + assertNull(model.getSecretSettings()); + } + } + + public void testParsePersistedConfig_CreatesAnAmazonBedrockChatCompletionModel() throws IOException { + try (var service = createAmazonBedrockService()) { + var settingsMap = createChatCompletionRequestSettingsMap("region", "model", "anthropic"); + var taskSettingsMap = getChatCompletionTaskSettingsMap(1.0, 0.5, 0.2, 128); + var secretSettingsMap = getAmazonBedrockSecretSettingsMap("access", "secret"); + + var persistedConfig = getPersistedConfigMap(settingsMap, taskSettingsMap, secretSettingsMap); + var model = service.parsePersistedConfig("id", TaskType.COMPLETION, persistedConfig.config()); + + assertThat(model, instanceOf(AmazonBedrockChatCompletionModel.class)); + + var settings = (AmazonBedrockChatCompletionServiceSettings) model.getServiceSettings(); + assertThat(settings.region(), is("region")); + assertThat(settings.model(), is("model")); + assertThat(settings.provider(), is(AmazonBedrockProvider.ANTHROPIC)); + var taskSettings = (AmazonBedrockChatCompletionTaskSettings) model.getTaskSettings(); + assertThat(taskSettings.temperature(), is(1.0)); + assertThat(taskSettings.topP(), is(0.5)); + assertThat(taskSettings.topK(), is(0.2)); + assertThat(taskSettings.maxNewTokens(), is(128)); + assertNull(model.getSecretSettings()); + } + } + + public void testParsePersistedConfig_ThrowsErrorTryingToParseInvalidModel() throws IOException { + try (var service = createAmazonBedrockService()) { + var settingsMap = createEmbeddingsRequestSettingsMap("region", "model", "amazontitan", null, false, null, null); + var secretSettingsMap = getAmazonBedrockSecretSettingsMap("access", "secret"); + + var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap(Map.of()), secretSettingsMap); + + var thrownException = expectThrows( + ElasticsearchStatusException.class, + () -> service.parsePersistedConfig("id", TaskType.SPARSE_EMBEDDING, persistedConfig.config()) + ); + + assertThat( + thrownException.getMessage(), + is("Failed to parse stored model [id] for [amazonbedrock] service, please delete and add the service again") + ); + } + } + + public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInConfig() throws IOException { + try (var service = createAmazonBedrockService()) { + var settingsMap = createEmbeddingsRequestSettingsMap("region", "model", "amazontitan", null, false, null, null); + var secretSettingsMap = getAmazonBedrockSecretSettingsMap("access", "secret"); + + var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap(Map.of()), secretSettingsMap); + persistedConfig.config().put("extra_key", "value"); + + var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); + + assertThat(model, instanceOf(AmazonBedrockEmbeddingsModel.class)); + + var settings = (AmazonBedrockEmbeddingsServiceSettings) model.getServiceSettings(); + assertThat(settings.region(), is("region")); + assertThat(settings.model(), is("model")); + assertThat(settings.provider(), is(AmazonBedrockProvider.AMAZONTITAN)); + assertNull(model.getSecretSettings()); + } + } + + public void testParsePersistedConfig_NotThrowWhenAnExtraKeyExistsInServiceSettings() throws IOException { + try (var service = createAmazonBedrockService()) { + var settingsMap = createEmbeddingsRequestSettingsMap("region", "model", "amazontitan", null, false, null, null); + settingsMap.put("extra_key", "value"); + var secretSettingsMap = getAmazonBedrockSecretSettingsMap("access", "secret"); + + var persistedConfig = getPersistedConfigMap(settingsMap, new HashMap(Map.of()), secretSettingsMap); + persistedConfig.config().put("extra_key", "value"); + + var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); + + assertThat(model, instanceOf(AmazonBedrockEmbeddingsModel.class)); + + var settings = (AmazonBedrockEmbeddingsServiceSettings) model.getServiceSettings(); + assertThat(settings.region(), is("region")); + assertThat(settings.model(), is("model")); + assertThat(settings.provider(), is(AmazonBedrockProvider.AMAZONTITAN)); + assertNull(model.getSecretSettings()); + } + } + + public void testParsePersistedConfig_NotThrowWhenAnExtraKeyExistsInTaskSettings() throws IOException { + try (var service = createAmazonBedrockService()) { + var settingsMap = createChatCompletionRequestSettingsMap("region", "model", "anthropic"); + var taskSettingsMap = getChatCompletionTaskSettingsMap(1.0, 0.5, 0.2, 128); + taskSettingsMap.put("extra_key", "value"); + var secretSettingsMap = getAmazonBedrockSecretSettingsMap("access", "secret"); + + var persistedConfig = getPersistedConfigMap(settingsMap, taskSettingsMap, secretSettingsMap); + var model = service.parsePersistedConfig("id", TaskType.COMPLETION, persistedConfig.config()); + + assertThat(model, instanceOf(AmazonBedrockChatCompletionModel.class)); + + var settings = (AmazonBedrockChatCompletionServiceSettings) model.getServiceSettings(); + assertThat(settings.region(), is("region")); + assertThat(settings.model(), is("model")); + assertThat(settings.provider(), is(AmazonBedrockProvider.ANTHROPIC)); + var taskSettings = (AmazonBedrockChatCompletionTaskSettings) model.getTaskSettings(); + assertThat(taskSettings.temperature(), is(1.0)); + assertThat(taskSettings.topP(), is(0.5)); + assertThat(taskSettings.topK(), is(0.2)); + assertThat(taskSettings.maxNewTokens(), is(128)); + assertNull(model.getSecretSettings()); + } + } + + public void testInfer_ThrowsErrorWhenModelIsNotAmazonBedrockModel() throws IOException { + var sender = mock(Sender.class); + var factory = mock(HttpRequestSender.Factory.class); + when(factory.createSender()).thenReturn(sender); + + var amazonBedrockFactory = new AmazonBedrockMockRequestSender.Factory( + ServiceComponentsTests.createWithSettings(threadPool, Settings.EMPTY), + mockClusterServiceEmpty() + ); + var mockModel = getInvalidModel("model_id", "service_name"); + + try (var service = new AmazonBedrockService(factory, amazonBedrockFactory, createWithEmptySettings(threadPool))) { + PlainActionFuture listener = new PlainActionFuture<>(); + service.infer( + mockModel, + null, + List.of(""), + new HashMap<>(), + InputType.INGEST, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); + assertThat( + thrownException.getMessage(), + is("The internal model was invalid, please delete the service [service_name] with id [model_id] and add it again.") + ); + + verify(factory, times(1)).createSender(); + verify(sender, times(1)).start(); + } + verify(sender, times(1)).close(); + verifyNoMoreInteractions(factory); + verifyNoMoreInteractions(sender); + } + + public void testInfer_SendsRequest_ForEmbeddingsModel() throws IOException { + var sender = mock(Sender.class); + var factory = mock(HttpRequestSender.Factory.class); + when(factory.createSender()).thenReturn(sender); + + var amazonBedrockFactory = new AmazonBedrockMockRequestSender.Factory( + ServiceComponentsTests.createWithSettings(threadPool, Settings.EMPTY), + mockClusterServiceEmpty() + ); + + try (var service = new AmazonBedrockService(factory, amazonBedrockFactory, createWithEmptySettings(threadPool))) { + try (var requestSender = (AmazonBedrockMockRequestSender) amazonBedrockFactory.createSender()) { + var results = new InferenceTextEmbeddingFloatResults( + List.of(new InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding(new float[] { 0.123F, 0.678F })) + ); + requestSender.enqueue(results); + + var model = AmazonBedrockEmbeddingsModelTests.createModel( + "id", + "region", + "model", + AmazonBedrockProvider.AMAZONTITAN, + "access", + "secret" + ); + PlainActionFuture listener = new PlainActionFuture<>(); + service.infer( + model, + null, + List.of("abc"), + new HashMap<>(), + InputType.INGEST, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var result = listener.actionGet(TIMEOUT); + + assertThat(result.asMap(), Matchers.is(buildExpectationFloat(List.of(new float[] { 0.123F, 0.678F })))); + } + } + } + + public void testInfer_SendsRequest_ForChatCompletionModel() throws IOException { + var sender = mock(Sender.class); + var factory = mock(HttpRequestSender.Factory.class); + when(factory.createSender()).thenReturn(sender); + + var amazonBedrockFactory = new AmazonBedrockMockRequestSender.Factory( + ServiceComponentsTests.createWithSettings(threadPool, Settings.EMPTY), + mockClusterServiceEmpty() + ); + + try (var service = new AmazonBedrockService(factory, amazonBedrockFactory, createWithEmptySettings(threadPool))) { + try (var requestSender = (AmazonBedrockMockRequestSender) amazonBedrockFactory.createSender()) { + var mockResults = new ChatCompletionResults(List.of(new ChatCompletionResults.Result("test result"))); + requestSender.enqueue(mockResults); + + var model = AmazonBedrockChatCompletionModelTests.createModel( + "id", + "region", + "model", + AmazonBedrockProvider.AMAZONTITAN, + "access", + "secret" + ); + PlainActionFuture listener = new PlainActionFuture<>(); + service.infer( + model, + null, + List.of("abc"), + new HashMap<>(), + InputType.INGEST, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var result = listener.actionGet(TIMEOUT); + + assertThat(result.asMap(), Matchers.is(buildExpectationCompletion(List.of("test result")))); + } + } + } + + public void testCheckModelConfig_IncludesMaxTokens_ForEmbeddingsModel() throws IOException { + var sender = mock(Sender.class); + var factory = mock(HttpRequestSender.Factory.class); + when(factory.createSender()).thenReturn(sender); + + var amazonBedrockFactory = new AmazonBedrockMockRequestSender.Factory( + ServiceComponentsTests.createWithSettings(threadPool, Settings.EMPTY), + mockClusterServiceEmpty() + ); + + try (var service = new AmazonBedrockService(factory, amazonBedrockFactory, createWithEmptySettings(threadPool))) { + try (var requestSender = (AmazonBedrockMockRequestSender) amazonBedrockFactory.createSender()) { + var results = new InferenceTextEmbeddingFloatResults( + List.of(new InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding(new float[] { 0.123F, 0.678F })) + ); + requestSender.enqueue(results); + + var model = AmazonBedrockEmbeddingsModelTests.createModel( + "id", + "region", + "model", + AmazonBedrockProvider.AMAZONTITAN, + null, + false, + 100, + null, + null, + "access", + "secret" + ); + + PlainActionFuture listener = new PlainActionFuture<>(); + service.checkModelConfig(model, listener); + var result = listener.actionGet(TIMEOUT); + assertThat( + result, + is( + AmazonBedrockEmbeddingsModelTests.createModel( + "id", + "region", + "model", + AmazonBedrockProvider.AMAZONTITAN, + 2, + false, + 100, + SimilarityMeasure.COSINE, + null, + "access", + "secret" + ) + ) + ); + var inputStrings = requestSender.getInputs(); + + MatcherAssert.assertThat(inputStrings, Matchers.is(List.of("how big"))); + } + } + } + + public void testCheckModelConfig_HasSimilarity_ForEmbeddingsModel() throws IOException { + var sender = mock(Sender.class); + var factory = mock(HttpRequestSender.Factory.class); + when(factory.createSender()).thenReturn(sender); + + var amazonBedrockFactory = new AmazonBedrockMockRequestSender.Factory( + ServiceComponentsTests.createWithSettings(threadPool, Settings.EMPTY), + mockClusterServiceEmpty() + ); + + try (var service = new AmazonBedrockService(factory, amazonBedrockFactory, createWithEmptySettings(threadPool))) { + try (var requestSender = (AmazonBedrockMockRequestSender) amazonBedrockFactory.createSender()) { + var results = new InferenceTextEmbeddingFloatResults( + List.of(new InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding(new float[] { 0.123F, 0.678F })) + ); + requestSender.enqueue(results); + + var model = AmazonBedrockEmbeddingsModelTests.createModel( + "id", + "region", + "model", + AmazonBedrockProvider.AMAZONTITAN, + null, + false, + null, + SimilarityMeasure.COSINE, + null, + "access", + "secret" + ); + + PlainActionFuture listener = new PlainActionFuture<>(); + service.checkModelConfig(model, listener); + var result = listener.actionGet(TIMEOUT); + assertThat( + result, + is( + AmazonBedrockEmbeddingsModelTests.createModel( + "id", + "region", + "model", + AmazonBedrockProvider.AMAZONTITAN, + 2, + false, + null, + SimilarityMeasure.COSINE, + null, + "access", + "secret" + ) + ) + ); + var inputStrings = requestSender.getInputs(); + + MatcherAssert.assertThat(inputStrings, Matchers.is(List.of("how big"))); + } + } + } + + public void testCheckModelConfig_ThrowsIfEmbeddingSizeDoesNotMatchValueSetByUser() throws IOException { + var sender = mock(Sender.class); + var factory = mock(HttpRequestSender.Factory.class); + when(factory.createSender()).thenReturn(sender); + + var amazonBedrockFactory = new AmazonBedrockMockRequestSender.Factory( + ServiceComponentsTests.createWithSettings(threadPool, Settings.EMPTY), + mockClusterServiceEmpty() + ); + + try (var service = new AmazonBedrockService(factory, amazonBedrockFactory, createWithEmptySettings(threadPool))) { + try (var requestSender = (AmazonBedrockMockRequestSender) amazonBedrockFactory.createSender()) { + var results = new InferenceTextEmbeddingFloatResults( + List.of(new InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding(new float[] { 0.123F, 0.678F })) + ); + requestSender.enqueue(results); + + var model = AmazonBedrockEmbeddingsModelTests.createModel( + "id", + "region", + "model", + AmazonBedrockProvider.AMAZONTITAN, + 3, + true, + null, + null, + null, + "access", + "secret" + ); + + PlainActionFuture listener = new PlainActionFuture<>(); + service.checkModelConfig(model, listener); + + var exception = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); + assertThat( + exception.getMessage(), + is( + "The retrieved embeddings size [2] does not match the size specified in the settings [3]. " + + "Please recreate the [id] configuration with the correct dimensions" + ) + ); + + var inputStrings = requestSender.getInputs(); + MatcherAssert.assertThat(inputStrings, Matchers.is(List.of("how big"))); + } + } + } + + public void testCheckModelConfig_ReturnsNewModelReference_AndDoesNotSendDimensionsField_WhenNotSetByUser() throws IOException { + var sender = mock(Sender.class); + var factory = mock(HttpRequestSender.Factory.class); + when(factory.createSender()).thenReturn(sender); + + var amazonBedrockFactory = new AmazonBedrockMockRequestSender.Factory( + ServiceComponentsTests.createWithSettings(threadPool, Settings.EMPTY), + mockClusterServiceEmpty() + ); + + try (var service = new AmazonBedrockService(factory, amazonBedrockFactory, createWithEmptySettings(threadPool))) { + try (var requestSender = (AmazonBedrockMockRequestSender) amazonBedrockFactory.createSender()) { + var results = new InferenceTextEmbeddingFloatResults( + List.of(new InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding(new float[] { 0.123F, 0.678F })) + ); + requestSender.enqueue(results); + + var model = AmazonBedrockEmbeddingsModelTests.createModel( + "id", + "region", + "model", + AmazonBedrockProvider.AMAZONTITAN, + 100, + false, + null, + SimilarityMeasure.COSINE, + null, + "access", + "secret" + ); + + PlainActionFuture listener = new PlainActionFuture<>(); + service.checkModelConfig(model, listener); + var result = listener.actionGet(TIMEOUT); + assertThat( + result, + is( + AmazonBedrockEmbeddingsModelTests.createModel( + "id", + "region", + "model", + AmazonBedrockProvider.AMAZONTITAN, + 2, + false, + null, + SimilarityMeasure.COSINE, + null, + "access", + "secret" + ) + ) + ); + var inputStrings = requestSender.getInputs(); + + MatcherAssert.assertThat(inputStrings, Matchers.is(List.of("how big"))); + } + } + } + + public void testInfer_UnauthorizedResponse() throws IOException { + var sender = mock(Sender.class); + var factory = mock(HttpRequestSender.Factory.class); + when(factory.createSender()).thenReturn(sender); + + var amazonBedrockFactory = new AmazonBedrockRequestSender.Factory( + ServiceComponentsTests.createWithSettings(threadPool, Settings.EMPTY), + mockClusterServiceEmpty() + ); + + try (var service = new AmazonBedrockService(factory, amazonBedrockFactory, createWithEmptySettings(threadPool))) { + var model = AmazonBedrockEmbeddingsModelTests.createModel( + "id", + "us-east-1", + "amazon.titan-embed-text-v1", + AmazonBedrockProvider.AMAZONTITAN, + "_INVALID_AWS_ACCESS_KEY_", + "_INVALID_AWS_SECRET_KEY_" + ); + PlainActionFuture listener = new PlainActionFuture<>(); + service.infer( + model, + null, + List.of("abc"), + new HashMap<>(), + InputType.INGEST, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var exceptionThrown = assertThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); + assertThat(exceptionThrown.getCause().getMessage(), containsString("The security token included in the request is invalid")); + } + } + + public void testChunkedInfer_CallsInfer_ConvertsFloatResponse_ForEmbeddings() throws IOException { + var sender = mock(Sender.class); + var factory = mock(HttpRequestSender.Factory.class); + when(factory.createSender()).thenReturn(sender); + + var amazonBedrockFactory = new AmazonBedrockMockRequestSender.Factory( + ServiceComponentsTests.createWithSettings(threadPool, Settings.EMPTY), + mockClusterServiceEmpty() + ); + + try (var service = new AmazonBedrockService(factory, amazonBedrockFactory, createWithEmptySettings(threadPool))) { + try (var requestSender = (AmazonBedrockMockRequestSender) amazonBedrockFactory.createSender()) { + { + var mockResults1 = new InferenceTextEmbeddingFloatResults( + List.of(new InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding(new float[] { 0.123F, 0.678F })) + ); + requestSender.enqueue(mockResults1); + } + { + var mockResults2 = new InferenceTextEmbeddingFloatResults( + List.of(new InferenceTextEmbeddingFloatResults.InferenceFloatEmbedding(new float[] { 0.223F, 0.278F })) + ); + requestSender.enqueue(mockResults2); + } + + var model = AmazonBedrockEmbeddingsModelTests.createModel( + "id", + "region", + "model", + AmazonBedrockProvider.AMAZONTITAN, + "access", + "secret" + ); + PlainActionFuture> listener = new PlainActionFuture<>(); + service.chunkedInfer( + model, + List.of("abc", "xyz"), + new HashMap<>(), + InputType.INGEST, + new ChunkingOptions(null, null), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var results = listener.actionGet(TIMEOUT); + assertThat(results, hasSize(2)); + { + assertThat(results.get(0), CoreMatchers.instanceOf(InferenceChunkedTextEmbeddingFloatResults.class)); + var floatResult = (InferenceChunkedTextEmbeddingFloatResults) results.get(0); + assertThat(floatResult.chunks(), hasSize(1)); + assertEquals("abc", floatResult.chunks().get(0).matchedText()); + assertArrayEquals(new float[] { 0.123F, 0.678F }, floatResult.chunks().get(0).embedding(), 0.0f); + } + { + assertThat(results.get(1), CoreMatchers.instanceOf(InferenceChunkedTextEmbeddingFloatResults.class)); + var floatResult = (InferenceChunkedTextEmbeddingFloatResults) results.get(1); + assertThat(floatResult.chunks(), hasSize(1)); + assertEquals("xyz", floatResult.chunks().get(0).matchedText()); + assertArrayEquals(new float[] { 0.223F, 0.278F }, floatResult.chunks().get(0).embedding(), 0.0f); + } + } + } + } + + private AmazonBedrockService createAmazonBedrockService() { + var amazonBedrockFactory = new AmazonBedrockMockRequestSender.Factory( + ServiceComponentsTests.createWithSettings(threadPool, Settings.EMPTY), + mockClusterServiceEmpty() + ); + return new AmazonBedrockService(mock(HttpRequestSender.Factory.class), amazonBedrockFactory, createWithEmptySettings(threadPool)); + } + + private Map getRequestConfigMap( + Map serviceSettings, + Map taskSettings, + Map secretSettings + ) { + var builtServiceSettings = new HashMap<>(); + builtServiceSettings.putAll(serviceSettings); + builtServiceSettings.putAll(secretSettings); + + return new HashMap<>( + Map.of(ModelConfigurations.SERVICE_SETTINGS, builtServiceSettings, ModelConfigurations.TASK_SETTINGS, taskSettings) + ); + } + + private Utils.PersistedConfig getPersistedConfigMap( + Map serviceSettings, + Map taskSettings, + Map secretSettings + ) { + + return new Utils.PersistedConfig( + new HashMap<>(Map.of(ModelConfigurations.SERVICE_SETTINGS, serviceSettings, ModelConfigurations.TASK_SETTINGS, taskSettings)), + new HashMap<>(Map.of(ModelSecrets.SECRET_SETTINGS, secretSettings)) + ); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/completion/AmazonBedrockChatCompletionModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/completion/AmazonBedrockChatCompletionModelTests.java new file mode 100644 index 0000000000000..22173943ff432 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/completion/AmazonBedrockChatCompletionModelTests.java @@ -0,0 +1,221 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.amazonbedrock.completion; + +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockProvider; +import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockSecretSettings; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +import static org.elasticsearch.xpack.inference.services.amazonbedrock.completion.AmazonBedrockChatCompletionTaskSettingsTests.getChatCompletionTaskSettingsMap; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.sameInstance; + +public class AmazonBedrockChatCompletionModelTests extends ESTestCase { + public void testOverrideWith_OverridesWithoutValues() { + var model = createModel( + "id", + "region", + "model", + AmazonBedrockProvider.AMAZONTITAN, + 1.0, + 0.5, + 0.6, + 512, + null, + "access_key", + "secret_key" + ); + var requestTaskSettingsMap = getChatCompletionTaskSettingsMap(null, null, null, null); + var overriddenModel = AmazonBedrockChatCompletionModel.of(model, requestTaskSettingsMap); + + assertThat(overriddenModel, sameInstance(overriddenModel)); + } + + public void testOverrideWith_temperature() { + var model = createModel( + "id", + "region", + "model", + AmazonBedrockProvider.AMAZONTITAN, + 1.0, + null, + null, + null, + null, + "access_key", + "secret_key" + ); + var requestTaskSettings = getChatCompletionTaskSettingsMap(0.5, null, null, null); + var overriddenModel = AmazonBedrockChatCompletionModel.of(model, requestTaskSettings); + assertThat( + overriddenModel, + is( + createModel( + "id", + "region", + "model", + AmazonBedrockProvider.AMAZONTITAN, + 0.5, + null, + null, + null, + null, + "access_key", + "secret_key" + ) + ) + ); + } + + public void testOverrideWith_topP() { + var model = createModel( + "id", + "region", + "model", + AmazonBedrockProvider.AMAZONTITAN, + null, + 0.8, + null, + null, + null, + "access_key", + "secret_key" + ); + var requestTaskSettings = getChatCompletionTaskSettingsMap(null, 0.5, null, null); + var overriddenModel = AmazonBedrockChatCompletionModel.of(model, requestTaskSettings); + assertThat( + overriddenModel, + is( + createModel( + "id", + "region", + "model", + AmazonBedrockProvider.AMAZONTITAN, + null, + 0.5, + null, + null, + null, + "access_key", + "secret_key" + ) + ) + ); + } + + public void testOverrideWith_topK() { + var model = createModel( + "id", + "region", + "model", + AmazonBedrockProvider.AMAZONTITAN, + null, + null, + 1.0, + null, + null, + "access_key", + "secret_key" + ); + var requestTaskSettings = getChatCompletionTaskSettingsMap(null, null, 0.8, null); + var overriddenModel = AmazonBedrockChatCompletionModel.of(model, requestTaskSettings); + assertThat( + overriddenModel, + is( + createModel( + "id", + "region", + "model", + AmazonBedrockProvider.AMAZONTITAN, + null, + null, + 0.8, + null, + null, + "access_key", + "secret_key" + ) + ) + ); + } + + public void testOverrideWith_maxNewTokens() { + var model = createModel( + "id", + "region", + "model", + AmazonBedrockProvider.AMAZONTITAN, + null, + null, + null, + 512, + null, + "access_key", + "secret_key" + ); + var requestTaskSettings = getChatCompletionTaskSettingsMap(null, null, null, 128); + var overriddenModel = AmazonBedrockChatCompletionModel.of(model, requestTaskSettings); + assertThat( + overriddenModel, + is( + createModel( + "id", + "region", + "model", + AmazonBedrockProvider.AMAZONTITAN, + null, + null, + null, + 128, + null, + "access_key", + "secret_key" + ) + ) + ); + } + + public static AmazonBedrockChatCompletionModel createModel( + String id, + String region, + String model, + AmazonBedrockProvider provider, + String accessKey, + String secretKey + ) { + return createModel(id, region, model, provider, null, null, null, null, null, accessKey, secretKey); + } + + public static AmazonBedrockChatCompletionModel createModel( + String id, + String region, + String model, + AmazonBedrockProvider provider, + @Nullable Double temperature, + @Nullable Double topP, + @Nullable Double topK, + @Nullable Integer maxNewTokens, + @Nullable RateLimitSettings rateLimitSettings, + String accessKey, + String secretKey + ) { + return new AmazonBedrockChatCompletionModel( + id, + TaskType.COMPLETION, + "amazonbedrock", + new AmazonBedrockChatCompletionServiceSettings(region, model, provider, rateLimitSettings), + new AmazonBedrockChatCompletionTaskSettings(temperature, topP, topK, maxNewTokens), + new AmazonBedrockSecretSettings(new SecureString(accessKey), new SecureString(secretKey)) + ); + } + +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/completion/AmazonBedrockChatCompletionRequestTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/completion/AmazonBedrockChatCompletionRequestTaskSettingsTests.java new file mode 100644 index 0000000000000..681088c786b6b --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/completion/AmazonBedrockChatCompletionRequestTaskSettingsTests.java @@ -0,0 +1,107 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.amazonbedrock.completion; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.test.ESTestCase; +import org.hamcrest.MatcherAssert; + +import java.util.HashMap; +import java.util.Map; + +import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.MAX_NEW_TOKENS_FIELD; +import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.TEMPERATURE_FIELD; +import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.TOP_K_FIELD; +import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.TOP_P_FIELD; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.is; + +public class AmazonBedrockChatCompletionRequestTaskSettingsTests extends ESTestCase { + public void testFromMap_ReturnsEmptySettings_WhenTheMapIsEmpty() { + var settings = AmazonBedrockChatCompletionRequestTaskSettings.fromMap(new HashMap<>(Map.of())); + assertThat(settings, is(AmazonBedrockChatCompletionRequestTaskSettings.EMPTY_SETTINGS)); + } + + public void testFromMap_ReturnsEmptySettings_WhenTheMapDoesNotContainTheFields() { + var settings = AmazonBedrockChatCompletionRequestTaskSettings.fromMap(new HashMap<>(Map.of("key", "model"))); + assertThat(settings, is(AmazonBedrockChatCompletionRequestTaskSettings.EMPTY_SETTINGS)); + } + + public void testFromMap_ReturnsTemperature() { + var settings = AmazonBedrockChatCompletionRequestTaskSettings.fromMap(new HashMap<>(Map.of(TEMPERATURE_FIELD, 0.1))); + assertThat(settings.temperature(), is(0.1)); + } + + public void testFromMap_ReturnsTopP() { + var settings = AmazonBedrockChatCompletionRequestTaskSettings.fromMap(new HashMap<>(Map.of(TOP_P_FIELD, 0.1))); + assertThat(settings.topP(), is(0.1)); + } + + public void testFromMap_ReturnsDoSample() { + var settings = AmazonBedrockChatCompletionRequestTaskSettings.fromMap(new HashMap<>(Map.of(TOP_K_FIELD, 0.3))); + assertThat(settings.topK(), is(0.3)); + } + + public void testFromMap_ReturnsMaxNewTokens() { + var settings = AmazonBedrockChatCompletionRequestTaskSettings.fromMap(new HashMap<>(Map.of(MAX_NEW_TOKENS_FIELD, 512))); + assertThat(settings.maxNewTokens(), is(512)); + } + + public void testFromMap_TemperatureIsInvalidValue_ThrowsValidationException() { + var thrownException = expectThrows( + ValidationException.class, + () -> AmazonBedrockChatCompletionRequestTaskSettings.fromMap(new HashMap<>(Map.of(TEMPERATURE_FIELD, "invalid"))) + ); + + MatcherAssert.assertThat( + thrownException.getMessage(), + containsString( + Strings.format("field [temperature] is not of the expected type. The value [invalid] cannot be converted to a [Double]") + ) + ); + } + + public void testFromMap_TopPIsInvalidValue_ThrowsValidationException() { + var thrownException = expectThrows( + ValidationException.class, + () -> AmazonBedrockChatCompletionRequestTaskSettings.fromMap(new HashMap<>(Map.of(TOP_P_FIELD, "invalid"))) + ); + + MatcherAssert.assertThat( + thrownException.getMessage(), + containsString( + Strings.format("field [top_p] is not of the expected type. The value [invalid] cannot be converted to a [Double]") + ) + ); + } + + public void testFromMap_TopKIsInvalidValue_ThrowsValidationException() { + var thrownException = expectThrows( + ValidationException.class, + () -> AmazonBedrockChatCompletionRequestTaskSettings.fromMap(new HashMap<>(Map.of(TOP_K_FIELD, "invalid"))) + ); + + MatcherAssert.assertThat( + thrownException.getMessage(), + containsString("field [top_k] is not of the expected type. The value [invalid] cannot be converted to a [Double]") + ); + } + + public void testFromMap_MaxTokensIsInvalidValue_ThrowsStatusException() { + var thrownException = expectThrows( + ValidationException.class, + () -> AmazonBedrockChatCompletionRequestTaskSettings.fromMap(new HashMap<>(Map.of(MAX_NEW_TOKENS_FIELD, "invalid"))) + ); + + MatcherAssert.assertThat( + thrownException.getMessage(), + containsString("field [max_new_tokens] is not of the expected type. The value [invalid] cannot be converted to a [Integer]") + ); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/completion/AmazonBedrockChatCompletionServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/completion/AmazonBedrockChatCompletionServiceSettingsTests.java new file mode 100644 index 0000000000000..90868530d8df8 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/completion/AmazonBedrockChatCompletionServiceSettingsTests.java @@ -0,0 +1,131 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.amazonbedrock.completion; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockProvider; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests; +import org.hamcrest.CoreMatchers; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.MODEL_FIELD; +import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.PROVIDER_FIELD; +import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.REGION_FIELD; +import static org.hamcrest.Matchers.is; + +public class AmazonBedrockChatCompletionServiceSettingsTests extends AbstractBWCWireSerializationTestCase< + AmazonBedrockChatCompletionServiceSettings> { + + public void testFromMap_Request_CreatesSettingsCorrectly() { + var region = "region"; + var model = "model-id"; + var provider = "amazontitan"; + var serviceSettings = AmazonBedrockChatCompletionServiceSettings.fromMap( + createChatCompletionRequestSettingsMap(region, model, provider), + ConfigurationParseContext.REQUEST + ); + + assertThat( + serviceSettings, + is(new AmazonBedrockChatCompletionServiceSettings(region, model, AmazonBedrockProvider.AMAZONTITAN, null)) + ); + } + + public void testFromMap_RequestWithRateLimit_CreatesSettingsCorrectly() { + var region = "region"; + var model = "model-id"; + var provider = "amazontitan"; + var settingsMap = createChatCompletionRequestSettingsMap(region, model, provider); + settingsMap.put(RateLimitSettings.FIELD_NAME, new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, 3))); + + var serviceSettings = AmazonBedrockChatCompletionServiceSettings.fromMap(settingsMap, ConfigurationParseContext.REQUEST); + + assertThat( + serviceSettings, + is(new AmazonBedrockChatCompletionServiceSettings(region, model, AmazonBedrockProvider.AMAZONTITAN, new RateLimitSettings(3))) + ); + } + + public void testFromMap_Persistent_CreatesSettingsCorrectly() { + var region = "region"; + var model = "model-id"; + var provider = "amazontitan"; + var settingsMap = createChatCompletionRequestSettingsMap(region, model, provider); + var serviceSettings = AmazonBedrockChatCompletionServiceSettings.fromMap(settingsMap, ConfigurationParseContext.PERSISTENT); + + assertThat( + serviceSettings, + is(new AmazonBedrockChatCompletionServiceSettings(region, model, AmazonBedrockProvider.AMAZONTITAN, null)) + ); + } + + public void testToXContent_WritesAllValues() throws IOException { + var entity = new AmazonBedrockChatCompletionServiceSettings( + "testregion", + "testmodel", + AmazonBedrockProvider.AMAZONTITAN, + new RateLimitSettings(3) + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, CoreMatchers.is(""" + {"region":"testregion","model":"testmodel","provider":"AMAZONTITAN",""" + """ + "rate_limit":{"requests_per_minute":3}}""")); + } + + public static HashMap createChatCompletionRequestSettingsMap(String region, String model, String provider) { + return new HashMap(Map.of(REGION_FIELD, region, MODEL_FIELD, model, PROVIDER_FIELD, provider)); + } + + @Override + protected AmazonBedrockChatCompletionServiceSettings mutateInstanceForVersion( + AmazonBedrockChatCompletionServiceSettings instance, + TransportVersion version + ) { + return instance; + } + + @Override + protected Writeable.Reader instanceReader() { + return AmazonBedrockChatCompletionServiceSettings::new; + } + + @Override + protected AmazonBedrockChatCompletionServiceSettings createTestInstance() { + return createRandom(); + } + + @Override + protected AmazonBedrockChatCompletionServiceSettings mutateInstance(AmazonBedrockChatCompletionServiceSettings instance) + throws IOException { + return randomValueOtherThan(instance, AmazonBedrockChatCompletionServiceSettingsTests::createRandom); + } + + private static AmazonBedrockChatCompletionServiceSettings createRandom() { + return new AmazonBedrockChatCompletionServiceSettings( + randomAlphaOfLength(10), + randomAlphaOfLength(10), + randomFrom(AmazonBedrockProvider.values()), + RateLimitSettingsTests.createRandom() + ); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/completion/AmazonBedrockChatCompletionTaskSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/completion/AmazonBedrockChatCompletionTaskSettingsTests.java new file mode 100644 index 0000000000000..0d5440c6d2cf8 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/completion/AmazonBedrockChatCompletionTaskSettingsTests.java @@ -0,0 +1,226 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.amazonbedrock.completion; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; +import org.hamcrest.MatcherAssert; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.MAX_NEW_TOKENS_FIELD; +import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.TEMPERATURE_FIELD; +import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.TOP_K_FIELD; +import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.TOP_P_FIELD; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.is; + +public class AmazonBedrockChatCompletionTaskSettingsTests extends AbstractBWCWireSerializationTestCase< + AmazonBedrockChatCompletionTaskSettings> { + + public void testFromMap_AllValues() { + var taskMap = getChatCompletionTaskSettingsMap(1.0, 0.5, 0.6, 512); + assertEquals( + new AmazonBedrockChatCompletionTaskSettings(1.0, 0.5, 0.6, 512), + AmazonBedrockChatCompletionTaskSettings.fromMap(taskMap) + ); + } + + public void testFromMap_TemperatureIsInvalidValue_ThrowsValidationException() { + var taskMap = getChatCompletionTaskSettingsMap(1.0, 0.5, 0.6, 512); + taskMap.put(TEMPERATURE_FIELD, "invalid"); + + var thrownException = expectThrows(ValidationException.class, () -> AmazonBedrockChatCompletionTaskSettings.fromMap(taskMap)); + + MatcherAssert.assertThat( + thrownException.getMessage(), + containsString( + Strings.format("field [temperature] is not of the expected type. The value [invalid] cannot be converted to a [Double]") + ) + ); + } + + public void testFromMap_TopPIsInvalidValue_ThrowsValidationException() { + var taskMap = getChatCompletionTaskSettingsMap(1.0, 0.5, 0.6, 512); + taskMap.put(TOP_P_FIELD, "invalid"); + + var thrownException = expectThrows(ValidationException.class, () -> AmazonBedrockChatCompletionTaskSettings.fromMap(taskMap)); + + MatcherAssert.assertThat( + thrownException.getMessage(), + containsString( + Strings.format("field [top_p] is not of the expected type. The value [invalid] cannot be converted to a [Double]") + ) + ); + } + + public void testFromMap_TopKIsInvalidValue_ThrowsValidationException() { + var taskMap = getChatCompletionTaskSettingsMap(1.0, 0.5, 0.6, 512); + taskMap.put(TOP_K_FIELD, "invalid"); + + var thrownException = expectThrows(ValidationException.class, () -> AmazonBedrockChatCompletionTaskSettings.fromMap(taskMap)); + + MatcherAssert.assertThat( + thrownException.getMessage(), + containsString("field [top_k] is not of the expected type. The value [invalid] cannot be converted to a [Double]") + ); + } + + public void testFromMap_MaxNewTokensIsInvalidValue_ThrowsValidationException() { + var taskMap = getChatCompletionTaskSettingsMap(1.0, 0.5, 0.6, 512); + taskMap.put(MAX_NEW_TOKENS_FIELD, "invalid"); + + var thrownException = expectThrows(ValidationException.class, () -> AmazonBedrockChatCompletionTaskSettings.fromMap(taskMap)); + + MatcherAssert.assertThat( + thrownException.getMessage(), + containsString( + Strings.format("field [max_new_tokens] is not of the expected type. The value [invalid] cannot be converted to a [Integer]") + ) + ); + } + + public void testFromMap_WithNoValues_DoesNotThrowException() { + var taskMap = AmazonBedrockChatCompletionTaskSettings.fromMap(new HashMap(Map.of())); + assertNull(taskMap.temperature()); + assertNull(taskMap.topP()); + assertNull(taskMap.topK()); + assertNull(taskMap.maxNewTokens()); + } + + public void testOverrideWith_KeepsOriginalValuesWithOverridesAreNull() { + var settings = AmazonBedrockChatCompletionTaskSettings.fromMap(getChatCompletionTaskSettingsMap(1.0, 0.5, 0.6, 512)); + var overrideSettings = AmazonBedrockChatCompletionTaskSettings.of(settings, AmazonBedrockChatCompletionTaskSettings.EMPTY_SETTINGS); + MatcherAssert.assertThat(overrideSettings, is(settings)); + } + + public void testOverrideWith_UsesTemperatureOverride() { + var settings = AmazonBedrockChatCompletionTaskSettings.fromMap(getChatCompletionTaskSettingsMap(1.0, 0.5, 0.6, 512)); + var overrideSettings = AmazonBedrockChatCompletionRequestTaskSettings.fromMap( + getChatCompletionTaskSettingsMap(0.3, null, null, null) + ); + var overriddenTaskSettings = AmazonBedrockChatCompletionTaskSettings.of(settings, overrideSettings); + MatcherAssert.assertThat(overriddenTaskSettings, is(new AmazonBedrockChatCompletionTaskSettings(0.3, 0.5, 0.6, 512))); + } + + public void testOverrideWith_UsesTopPOverride() { + var settings = AmazonBedrockChatCompletionTaskSettings.fromMap(getChatCompletionTaskSettingsMap(1.0, 0.5, 0.6, 512)); + var overrideSettings = AmazonBedrockChatCompletionRequestTaskSettings.fromMap( + getChatCompletionTaskSettingsMap(null, 0.2, null, null) + ); + var overriddenTaskSettings = AmazonBedrockChatCompletionTaskSettings.of(settings, overrideSettings); + MatcherAssert.assertThat(overriddenTaskSettings, is(new AmazonBedrockChatCompletionTaskSettings(1.0, 0.2, 0.6, 512))); + } + + public void testOverrideWith_UsesDoSampleOverride() { + var settings = AmazonBedrockChatCompletionTaskSettings.fromMap(getChatCompletionTaskSettingsMap(1.0, 0.5, 0.6, 512)); + var overrideSettings = AmazonBedrockChatCompletionRequestTaskSettings.fromMap( + getChatCompletionTaskSettingsMap(null, null, 0.1, null) + ); + var overriddenTaskSettings = AmazonBedrockChatCompletionTaskSettings.of(settings, overrideSettings); + MatcherAssert.assertThat(overriddenTaskSettings, is(new AmazonBedrockChatCompletionTaskSettings(1.0, 0.5, 0.1, 512))); + } + + public void testOverrideWith_UsesMaxNewTokensOverride() { + var settings = AmazonBedrockChatCompletionTaskSettings.fromMap(getChatCompletionTaskSettingsMap(1.0, 0.5, 0.6, 512)); + var overrideSettings = AmazonBedrockChatCompletionRequestTaskSettings.fromMap( + getChatCompletionTaskSettingsMap(null, null, null, 128) + ); + var overriddenTaskSettings = AmazonBedrockChatCompletionTaskSettings.of(settings, overrideSettings); + MatcherAssert.assertThat(overriddenTaskSettings, is(new AmazonBedrockChatCompletionTaskSettings(1.0, 0.5, 0.6, 128))); + } + + public void testToXContent_WithoutParameters() throws IOException { + var settings = AmazonBedrockChatCompletionTaskSettings.fromMap(getChatCompletionTaskSettingsMap(null, null, null, null)); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + settings.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, is("{}")); + } + + public void testToXContent_WithParameters() throws IOException { + var settings = AmazonBedrockChatCompletionTaskSettings.fromMap(getChatCompletionTaskSettingsMap(1.0, 0.5, 0.6, 512)); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + settings.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, is(""" + {"temperature":1.0,"top_p":0.5,"top_k":0.6,"max_new_tokens":512}""")); + } + + public static Map getChatCompletionTaskSettingsMap( + @Nullable Double temperature, + @Nullable Double topP, + @Nullable Double topK, + @Nullable Integer maxNewTokens + ) { + var map = new HashMap(); + + if (temperature != null) { + map.put(TEMPERATURE_FIELD, temperature); + } + + if (topP != null) { + map.put(TOP_P_FIELD, topP); + } + + if (topK != null) { + map.put(TOP_K_FIELD, topK); + } + + if (maxNewTokens != null) { + map.put(MAX_NEW_TOKENS_FIELD, maxNewTokens); + } + + return map; + } + + @Override + protected AmazonBedrockChatCompletionTaskSettings mutateInstanceForVersion( + AmazonBedrockChatCompletionTaskSettings instance, + TransportVersion version + ) { + return instance; + } + + @Override + protected Writeable.Reader instanceReader() { + return AmazonBedrockChatCompletionTaskSettings::new; + } + + @Override + protected AmazonBedrockChatCompletionTaskSettings createTestInstance() { + return createRandom(); + } + + @Override + protected AmazonBedrockChatCompletionTaskSettings mutateInstance(AmazonBedrockChatCompletionTaskSettings instance) throws IOException { + return randomValueOtherThan(instance, AmazonBedrockChatCompletionTaskSettingsTests::createRandom); + } + + private static AmazonBedrockChatCompletionTaskSettings createRandom() { + return new AmazonBedrockChatCompletionTaskSettings( + randomFrom(new Double[] { null, randomDouble() }), + randomFrom(new Double[] { null, randomDouble() }), + randomFrom(new Double[] { null, randomDouble() }), + randomFrom(new Integer[] { null, randomNonNegativeInt() }) + ); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/embeddings/AmazonBedrockEmbeddingsModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/embeddings/AmazonBedrockEmbeddingsModelTests.java new file mode 100644 index 0000000000000..711e3cbb5a511 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/embeddings/AmazonBedrockEmbeddingsModelTests.java @@ -0,0 +1,81 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings; + +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.EmptyTaskSettings; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockProvider; +import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockSecretSettings; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +import java.util.Map; + +import static org.hamcrest.Matchers.containsString; + +public class AmazonBedrockEmbeddingsModelTests extends ESTestCase { + + public void testCreateModel_withTaskSettings_shouldFail() { + var baseModel = createModel("id", "region", "model", AmazonBedrockProvider.AMAZONTITAN, "accesskey", "secretkey"); + var thrownException = assertThrows( + ValidationException.class, + () -> AmazonBedrockEmbeddingsModel.of(baseModel, Map.of("testkey", "testvalue")) + ); + assertThat(thrownException.getMessage(), containsString("Amazon Bedrock embeddings model cannot have task settings")); + } + + // model creation only - no tests to define, but we want to have the public createModel + // method available + + public static AmazonBedrockEmbeddingsModel createModel( + String inferenceId, + String region, + String model, + AmazonBedrockProvider provider, + String accessKey, + String secretKey + ) { + return createModel(inferenceId, region, model, provider, null, false, null, null, new RateLimitSettings(240), accessKey, secretKey); + } + + public static AmazonBedrockEmbeddingsModel createModel( + String inferenceId, + String region, + String model, + AmazonBedrockProvider provider, + @Nullable Integer dimensions, + boolean dimensionsSetByUser, + @Nullable Integer maxTokens, + @Nullable SimilarityMeasure similarity, + RateLimitSettings rateLimitSettings, + String accessKey, + String secretKey + ) { + return new AmazonBedrockEmbeddingsModel( + inferenceId, + TaskType.TEXT_EMBEDDING, + "amazonbedrock", + new AmazonBedrockEmbeddingsServiceSettings( + region, + model, + provider, + dimensions, + dimensionsSetByUser, + maxTokens, + similarity, + rateLimitSettings + ), + new EmptyTaskSettings(), + new AmazonBedrockSecretSettings(new SecureString(accessKey), new SecureString(secretKey)) + ); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/embeddings/AmazonBedrockEmbeddingsServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/embeddings/AmazonBedrockEmbeddingsServiceSettingsTests.java new file mode 100644 index 0000000000000..a100b89e1db6e --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/embeddings/AmazonBedrockEmbeddingsServiceSettingsTests.java @@ -0,0 +1,404 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.ServiceFields; +import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockProvider; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests; +import org.hamcrest.CoreMatchers; +import org.hamcrest.MatcherAssert; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS; +import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS; +import static org.elasticsearch.xpack.inference.services.ServiceFields.SIMILARITY; +import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.MODEL_FIELD; +import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.PROVIDER_FIELD; +import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.REGION_FIELD; +import static org.elasticsearch.xpack.inference.services.amazonbedrock.embeddings.AmazonBedrockEmbeddingsServiceSettings.DIMENSIONS_SET_BY_USER; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.is; + +public class AmazonBedrockEmbeddingsServiceSettingsTests extends AbstractBWCWireSerializationTestCase< + AmazonBedrockEmbeddingsServiceSettings> { + + public void testFromMap_Request_CreatesSettingsCorrectly() { + var region = "region"; + var model = "model-id"; + var provider = "amazontitan"; + var maxInputTokens = 512; + var serviceSettings = AmazonBedrockEmbeddingsServiceSettings.fromMap( + createEmbeddingsRequestSettingsMap(region, model, provider, null, null, maxInputTokens, SimilarityMeasure.COSINE), + ConfigurationParseContext.REQUEST + ); + + assertThat( + serviceSettings, + is( + new AmazonBedrockEmbeddingsServiceSettings( + region, + model, + AmazonBedrockProvider.AMAZONTITAN, + null, + false, + maxInputTokens, + SimilarityMeasure.COSINE, + null + ) + ) + ); + } + + public void testFromMap_RequestWithRateLimit_CreatesSettingsCorrectly() { + var region = "region"; + var model = "model-id"; + var provider = "amazontitan"; + var maxInputTokens = 512; + var settingsMap = createEmbeddingsRequestSettingsMap(region, model, provider, null, null, maxInputTokens, SimilarityMeasure.COSINE); + settingsMap.put(RateLimitSettings.FIELD_NAME, new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, 3))); + + var serviceSettings = AmazonBedrockEmbeddingsServiceSettings.fromMap(settingsMap, ConfigurationParseContext.REQUEST); + + assertThat( + serviceSettings, + is( + new AmazonBedrockEmbeddingsServiceSettings( + region, + model, + AmazonBedrockProvider.AMAZONTITAN, + null, + false, + maxInputTokens, + SimilarityMeasure.COSINE, + new RateLimitSettings(3) + ) + ) + ); + } + + public void testFromMap_Request_DimensionsSetByUser_IsFalse_WhenDimensionsAreNotPresent() { + var region = "region"; + var model = "model-id"; + var provider = "amazontitan"; + var maxInputTokens = 512; + var settingsMap = createEmbeddingsRequestSettingsMap(region, model, provider, null, null, maxInputTokens, SimilarityMeasure.COSINE); + var serviceSettings = AmazonBedrockEmbeddingsServiceSettings.fromMap(settingsMap, ConfigurationParseContext.REQUEST); + + assertThat( + serviceSettings, + is( + new AmazonBedrockEmbeddingsServiceSettings( + region, + model, + AmazonBedrockProvider.AMAZONTITAN, + null, + false, + maxInputTokens, + SimilarityMeasure.COSINE, + null + ) + ) + ); + } + + public void testFromMap_Request_DimensionsSetByUser_ShouldThrowWhenPresent() { + var region = "region"; + var model = "model-id"; + var provider = "amazontitan"; + var maxInputTokens = 512; + + var settingsMap = createEmbeddingsRequestSettingsMap(region, model, provider, null, true, maxInputTokens, SimilarityMeasure.COSINE); + + var thrownException = expectThrows( + ValidationException.class, + () -> AmazonBedrockEmbeddingsServiceSettings.fromMap(settingsMap, ConfigurationParseContext.REQUEST) + ); + + MatcherAssert.assertThat( + thrownException.getMessage(), + containsString( + Strings.format("Validation Failed: 1: [service_settings] does not allow the setting [%s];", DIMENSIONS_SET_BY_USER) + ) + ); + } + + public void testFromMap_Request_Dimensions_ShouldThrowWhenPresent() { + var region = "region"; + var model = "model-id"; + var provider = "amazontitan"; + var dims = 128; + + var settingsMap = createEmbeddingsRequestSettingsMap(region, model, provider, dims, null, null, null); + + var thrownException = expectThrows( + ValidationException.class, + () -> AmazonBedrockEmbeddingsServiceSettings.fromMap(settingsMap, ConfigurationParseContext.REQUEST) + ); + + MatcherAssert.assertThat( + thrownException.getMessage(), + containsString(Strings.format("[service_settings] does not allow the setting [%s]", DIMENSIONS)) + ); + } + + public void testFromMap_Request_MaxTokensShouldBePositiveInteger() { + var region = "region"; + var model = "model-id"; + var provider = "amazontitan"; + var maxInputTokens = -128; + + var settingsMap = createEmbeddingsRequestSettingsMap(region, model, provider, null, null, maxInputTokens, null); + + var thrownException = expectThrows( + ValidationException.class, + () -> AmazonBedrockEmbeddingsServiceSettings.fromMap(settingsMap, ConfigurationParseContext.REQUEST) + ); + + MatcherAssert.assertThat( + thrownException.getMessage(), + containsString(Strings.format("[%s] must be a positive integer", MAX_INPUT_TOKENS)) + ); + } + + public void testFromMap_Persistent_CreatesSettingsCorrectly() { + var region = "region"; + var model = "model-id"; + var provider = "amazontitan"; + var dims = 1536; + var maxInputTokens = 512; + + var settingsMap = createEmbeddingsRequestSettingsMap( + region, + model, + provider, + dims, + false, + maxInputTokens, + SimilarityMeasure.COSINE + ); + var serviceSettings = AmazonBedrockEmbeddingsServiceSettings.fromMap(settingsMap, ConfigurationParseContext.PERSISTENT); + + assertThat( + serviceSettings, + is( + new AmazonBedrockEmbeddingsServiceSettings( + region, + model, + AmazonBedrockProvider.AMAZONTITAN, + dims, + false, + maxInputTokens, + SimilarityMeasure.COSINE, + null + ) + ) + ); + } + + public void testFromMap_PersistentContext_DoesNotThrowException_WhenDimensionsIsNull() { + var region = "region"; + var model = "model-id"; + var provider = "amazontitan"; + + var settingsMap = createEmbeddingsRequestSettingsMap(region, model, provider, null, true, null, null); + var serviceSettings = AmazonBedrockEmbeddingsServiceSettings.fromMap(settingsMap, ConfigurationParseContext.PERSISTENT); + + assertThat( + serviceSettings, + is(new AmazonBedrockEmbeddingsServiceSettings(region, model, AmazonBedrockProvider.AMAZONTITAN, null, true, null, null, null)) + ); + } + + public void testFromMap_PersistentContext_DoesNotThrowException_WhenSimilarityIsPresent() { + var region = "region"; + var model = "model-id"; + var provider = "amazontitan"; + + var settingsMap = createEmbeddingsRequestSettingsMap(region, model, provider, null, true, null, SimilarityMeasure.DOT_PRODUCT); + var serviceSettings = AmazonBedrockEmbeddingsServiceSettings.fromMap(settingsMap, ConfigurationParseContext.PERSISTENT); + + assertThat( + serviceSettings, + is( + new AmazonBedrockEmbeddingsServiceSettings( + region, + model, + AmazonBedrockProvider.AMAZONTITAN, + null, + true, + null, + SimilarityMeasure.DOT_PRODUCT, + null + ) + ) + ); + } + + public void testFromMap_PersistentContext_ThrowsException_WhenDimensionsSetByUserIsNull() { + var region = "region"; + var model = "model-id"; + var provider = "amazontitan"; + + var settingsMap = createEmbeddingsRequestSettingsMap(region, model, provider, 1, null, null, null); + + var exception = expectThrows( + ValidationException.class, + () -> AmazonBedrockEmbeddingsServiceSettings.fromMap(settingsMap, ConfigurationParseContext.PERSISTENT) + ); + + assertThat( + exception.getMessage(), + containsString("Validation Failed: 1: [service_settings] does not contain the required setting [dimensions_set_by_user];") + ); + } + + public void testToXContent_WritesDimensionsSetByUserTrue() throws IOException { + var entity = new AmazonBedrockEmbeddingsServiceSettings( + "testregion", + "testmodel", + AmazonBedrockProvider.AMAZONTITAN, + null, + true, + null, + null, + new RateLimitSettings(2) + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, CoreMatchers.is(""" + {"region":"testregion","model":"testmodel","provider":"AMAZONTITAN",""" + """ + "rate_limit":{"requests_per_minute":2},"dimensions_set_by_user":true}""")); + } + + public void testToXContent_WritesAllValues() throws IOException { + var entity = new AmazonBedrockEmbeddingsServiceSettings( + "testregion", + "testmodel", + AmazonBedrockProvider.AMAZONTITAN, + 1024, + false, + 512, + null, + new RateLimitSettings(3) + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, CoreMatchers.is(""" + {"region":"testregion","model":"testmodel","provider":"AMAZONTITAN",""" + """ + "rate_limit":{"requests_per_minute":3},"dimensions":1024,"max_input_tokens":512,"dimensions_set_by_user":false}""")); + } + + public void testToFilteredXContent_WritesAllValues_ExceptDimensionsSetByUser() throws IOException { + var entity = new AmazonBedrockEmbeddingsServiceSettings( + "testregion", + "testmodel", + AmazonBedrockProvider.AMAZONTITAN, + 1024, + false, + 512, + null, + new RateLimitSettings(3) + ); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + var filteredXContent = entity.getFilteredXContentObject(); + filteredXContent.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, CoreMatchers.is(""" + {"region":"testregion","model":"testmodel","provider":"AMAZONTITAN",""" + """ + "rate_limit":{"requests_per_minute":3},"dimensions":1024,"max_input_tokens":512}""")); + } + + public static HashMap createEmbeddingsRequestSettingsMap( + String region, + String model, + String provider, + @Nullable Integer dimensions, + @Nullable Boolean dimensionsSetByUser, + @Nullable Integer maxTokens, + @Nullable SimilarityMeasure similarityMeasure + ) { + var map = new HashMap(Map.of(REGION_FIELD, region, MODEL_FIELD, model, PROVIDER_FIELD, provider)); + + if (dimensions != null) { + map.put(ServiceFields.DIMENSIONS, dimensions); + } + + if (dimensionsSetByUser != null) { + map.put(DIMENSIONS_SET_BY_USER, dimensionsSetByUser.equals(Boolean.TRUE)); + } + + if (maxTokens != null) { + map.put(ServiceFields.MAX_INPUT_TOKENS, maxTokens); + } + + if (similarityMeasure != null) { + map.put(SIMILARITY, similarityMeasure.toString()); + } + + return map; + } + + @Override + protected AmazonBedrockEmbeddingsServiceSettings mutateInstanceForVersion( + AmazonBedrockEmbeddingsServiceSettings instance, + TransportVersion version + ) { + return instance; + } + + @Override + protected Writeable.Reader instanceReader() { + return AmazonBedrockEmbeddingsServiceSettings::new; + } + + @Override + protected AmazonBedrockEmbeddingsServiceSettings createTestInstance() { + return createRandom(); + } + + @Override + protected AmazonBedrockEmbeddingsServiceSettings mutateInstance(AmazonBedrockEmbeddingsServiceSettings instance) throws IOException { + return randomValueOtherThan(instance, AmazonBedrockEmbeddingsServiceSettingsTests::createRandom); + } + + private static AmazonBedrockEmbeddingsServiceSettings createRandom() { + return new AmazonBedrockEmbeddingsServiceSettings( + randomAlphaOfLength(10), + randomAlphaOfLength(10), + randomFrom(AmazonBedrockProvider.values()), + randomFrom(new Integer[] { null, randomNonNegativeInt() }), + randomBoolean(), + randomFrom(new Integer[] { null, randomNonNegativeInt() }), + randomFrom(new SimilarityMeasure[] { null, randomFrom(SimilarityMeasure.values()) }), + RateLimitSettingsTests.createRandom() + ); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/embeddings/AzureAiStudioEmbeddingsServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/embeddings/AzureAiStudioEmbeddingsServiceSettingsTests.java index 05388192b2f14..c857a22e52996 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/embeddings/AzureAiStudioEmbeddingsServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/embeddings/AzureAiStudioEmbeddingsServiceSettingsTests.java @@ -170,6 +170,92 @@ public void testFromMap_Persistent_CreatesSettingsCorrectly() { ); } + public void testFromMap_ThrowsException_WhenDimensionsAreZero() { + var target = "http://sometarget.local"; + var provider = "openai"; + var endpointType = "token"; + var dimensions = 0; + + var settingsMap = createRequestSettingsMap(target, provider, endpointType, dimensions, true, null, SimilarityMeasure.COSINE); + + var thrownException = expectThrows( + ValidationException.class, + () -> AzureAiStudioEmbeddingsServiceSettings.fromMap(settingsMap, ConfigurationParseContext.REQUEST) + ); + + assertThat( + thrownException.getMessage(), + containsString("Validation Failed: 1: [service_settings] Invalid value [0]. [dimensions] must be a positive integer;") + ); + } + + public void testFromMap_ThrowsException_WhenDimensionsAreNegative() { + var target = "http://sometarget.local"; + var provider = "openai"; + var endpointType = "token"; + var dimensions = randomNegativeInt(); + + var settingsMap = createRequestSettingsMap(target, provider, endpointType, dimensions, true, null, SimilarityMeasure.COSINE); + + var thrownException = expectThrows( + ValidationException.class, + () -> AzureAiStudioEmbeddingsServiceSettings.fromMap(settingsMap, ConfigurationParseContext.REQUEST) + ); + + assertThat( + thrownException.getMessage(), + containsString( + Strings.format( + "Validation Failed: 1: [service_settings] Invalid value [%d]. [dimensions] must be a positive integer;", + dimensions + ) + ) + ); + } + + public void testFromMap_ThrowsException_WhenMaxInputTokensAreZero() { + var target = "http://sometarget.local"; + var provider = "openai"; + var endpointType = "token"; + var maxInputTokens = 0; + + var settingsMap = createRequestSettingsMap(target, provider, endpointType, null, true, maxInputTokens, SimilarityMeasure.COSINE); + + var thrownException = expectThrows( + ValidationException.class, + () -> AzureAiStudioEmbeddingsServiceSettings.fromMap(settingsMap, ConfigurationParseContext.REQUEST) + ); + + assertThat( + thrownException.getMessage(), + containsString("Validation Failed: 1: [service_settings] Invalid value [0]. [max_input_tokens] must be a positive integer;") + ); + } + + public void testFromMap_ThrowsException_WhenMaxInputTokensAreNegative() { + var target = "http://sometarget.local"; + var provider = "openai"; + var endpointType = "token"; + var maxInputTokens = randomNegativeInt(); + + var settingsMap = createRequestSettingsMap(target, provider, endpointType, null, true, maxInputTokens, SimilarityMeasure.COSINE); + + var thrownException = expectThrows( + ValidationException.class, + () -> AzureAiStudioEmbeddingsServiceSettings.fromMap(settingsMap, ConfigurationParseContext.REQUEST) + ); + + assertThat( + thrownException.getMessage(), + containsString( + Strings.format( + "Validation Failed: 1: [service_settings] Invalid value [%d]. [max_input_tokens] must be a positive integer;", + maxInputTokens + ) + ) + ); + } + public void testFromMap_PersistentContext_DoesNotThrowException_WhenDimensionsIsNull() { var target = "http://sometarget.local"; var provider = "openai"; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsServiceSettingsTests.java index cbb9eea223802..8b754257e9d83 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/embeddings/AzureOpenAiEmbeddingsServiceSettingsTests.java @@ -203,6 +203,92 @@ public void testFromMap_Request_DimensionsSetByUser_ShouldThrowWhenPresent() { ); } + public void testFromMap_ThrowsException_WhenDimensionsAreZero() { + var resourceName = "this-resource"; + var deploymentId = "this-deployment"; + var apiVersion = "2024-01-01"; + var dimensions = 0; + + var settingsMap = getRequestAzureOpenAiServiceSettingsMap(resourceName, deploymentId, apiVersion, dimensions, null); + + var thrownException = expectThrows( + ValidationException.class, + () -> AzureOpenAiEmbeddingsServiceSettings.fromMap(settingsMap, ConfigurationParseContext.REQUEST) + ); + + assertThat( + thrownException.getMessage(), + containsString("Validation Failed: 1: [service_settings] Invalid value [0]. [dimensions] must be a positive integer;") + ); + } + + public void testFromMap_ThrowsException_WhenDimensionsAreNegative() { + var resourceName = "this-resource"; + var deploymentId = "this-deployment"; + var apiVersion = "2024-01-01"; + var dimensions = randomNegativeInt(); + + var settingsMap = getRequestAzureOpenAiServiceSettingsMap(resourceName, deploymentId, apiVersion, dimensions, null); + + var thrownException = expectThrows( + ValidationException.class, + () -> AzureOpenAiEmbeddingsServiceSettings.fromMap(settingsMap, ConfigurationParseContext.REQUEST) + ); + + assertThat( + thrownException.getMessage(), + containsString( + Strings.format( + "Validation Failed: 1: [service_settings] Invalid value [%d]. [dimensions] must be a positive integer;", + dimensions + ) + ) + ); + } + + public void testFromMap_ThrowsException_WhenMaxInputTokensAreZero() { + var resourceName = "this-resource"; + var deploymentId = "this-deployment"; + var apiVersion = "2024-01-01"; + var maxInputTokens = 0; + + var settingsMap = getRequestAzureOpenAiServiceSettingsMap(resourceName, deploymentId, apiVersion, null, maxInputTokens); + + var thrownException = expectThrows( + ValidationException.class, + () -> AzureOpenAiEmbeddingsServiceSettings.fromMap(settingsMap, ConfigurationParseContext.REQUEST) + ); + + assertThat( + thrownException.getMessage(), + containsString("Validation Failed: 1: [service_settings] Invalid value [0]. [max_input_tokens] must be a positive integer;") + ); + } + + public void testFromMap_ThrowsException_WhenMaxInputTokensAreNegative() { + var resourceName = "this-resource"; + var deploymentId = "this-deployment"; + var apiVersion = "2024-01-01"; + var maxInputTokens = randomNegativeInt(); + + var settingsMap = getRequestAzureOpenAiServiceSettingsMap(resourceName, deploymentId, apiVersion, null, maxInputTokens); + + var thrownException = expectThrows( + ValidationException.class, + () -> AzureOpenAiEmbeddingsServiceSettings.fromMap(settingsMap, ConfigurationParseContext.REQUEST) + ); + + assertThat( + thrownException.getMessage(), + containsString( + Strings.format( + "Validation Failed: 1: [service_settings] Invalid value [%d]. [max_input_tokens] must be a positive integer;", + maxInputTokens + ) + ) + ); + } + public void testFromMap_Persistent_CreatesSettingsCorrectly() { var resourceName = "this-resource"; var deploymentId = "this-deployment"; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/CustomElandInternalTextEmbeddingServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/CustomElandInternalTextEmbeddingServiceSettingsTests.java index 0cc3e6698388d..8e8a1db76da14 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/CustomElandInternalTextEmbeddingServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/CustomElandInternalTextEmbeddingServiceSettingsTests.java @@ -47,6 +47,7 @@ public static CustomElandInternalTextEmbeddingServiceSettings createRandom() { numAllocations, numThreads, modelId, + null, dims, similarityMeasure, elementType @@ -84,6 +85,7 @@ public void testFromMap_Request_CreatesSettingsCorrectly() { numThreads, modelId, null, + null, SimilarityMeasure.DOT_PRODUCT, DenseVectorFieldMapper.ElementType.FLOAT ) @@ -108,6 +110,7 @@ public void testFromMap_Request_DoesNotDefaultSimilarityElementType() { numThreads, modelId, null, + null, SimilarityMeasure.COSINE, DenseVectorFieldMapper.ElementType.FLOAT ) @@ -148,6 +151,7 @@ public void testFromMap_Request_IgnoresDimensions() { numThreads, modelId, null, + null, SimilarityMeasure.DOT_PRODUCT, DenseVectorFieldMapper.ElementType.FLOAT ) @@ -187,6 +191,7 @@ public void testFromMap_Persistent_CreatesSettingsCorrectly() { numAllocations, numThreads, modelId, + null, 1, SimilarityMeasure.DOT_PRODUCT, DenseVectorFieldMapper.ElementType.FLOAT @@ -200,6 +205,7 @@ public void testToXContent_WritesAllValues() throws IOException { 1, 1, "model_id", + null, 100, SimilarityMeasure.COSINE, DenseVectorFieldMapper.ElementType.BYTE @@ -210,7 +216,8 @@ public void testToXContent_WritesAllValues() throws IOException { String xContentResult = Strings.toString(builder); assertThat(xContentResult, is(""" - {"num_allocations":1,"num_threads":1,"model_id":"model_id","dimensions":100,"similarity":"cosine","element_type":"byte"}""")); + {"num_allocations":1,"num_threads":1,"model_id":"model_id","adaptive_allocations":null,"dimensions":100,""" + """ + "similarity":"cosine","element_type":"byte"}""")); } @Override diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java index 3bec202ed9e5e..ad1910cb9fc0a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalServiceTests.java @@ -134,7 +134,8 @@ public void testParseRequestConfig() { var e5ServiceSettings = new MultilingualE5SmallInternalServiceSettings( 1, 4, - ElasticsearchInternalService.MULTILINGUAL_E5_SMALL_MODEL_ID + ElasticsearchInternalService.MULTILINGUAL_E5_SMALL_MODEL_ID, + null ); service.parseRequestConfig( @@ -400,7 +401,7 @@ public void testParsePersistedConfig() { taskType, settings ); - var elandServiceSettings = new CustomElandInternalTextEmbeddingServiceSettings(1, 4, "invalid"); + var elandServiceSettings = new CustomElandInternalTextEmbeddingServiceSettings(1, 4, "invalid", null); assertEquals( new CustomElandEmbeddingModel(randomInferenceEntityId, taskType, ElasticsearchInternalService.NAME, elandServiceSettings), parsedModel @@ -430,7 +431,8 @@ public void testParsePersistedConfig() { var e5ServiceSettings = new MultilingualE5SmallInternalServiceSettings( 1, 4, - ElasticsearchInternalService.MULTILINGUAL_E5_SMALL_MODEL_ID + ElasticsearchInternalService.MULTILINGUAL_E5_SMALL_MODEL_ID, + null ); MultilingualE5SmallModel parsedModel = (MultilingualE5SmallModel) service.parsePersistedConfig( @@ -500,7 +502,7 @@ public void testChunkInfer() { "foo", TaskType.TEXT_EMBEDDING, "e5", - new MultilingualE5SmallInternalServiceSettings(1, 1, "cross-platform") + new MultilingualE5SmallInternalServiceSettings(1, 1, "cross-platform", null) ); var service = createService(client); @@ -594,7 +596,7 @@ public void testChunkInferSetsTokenization() { "foo", TaskType.TEXT_EMBEDDING, "e5", - new MultilingualE5SmallInternalServiceSettings(1, 1, "cross-platform") + new MultilingualE5SmallInternalServiceSettings(1, 1, "cross-platform", null) ); var service = createService(client); @@ -726,11 +728,11 @@ private CustomElandModel getCustomElandModel(TaskType taskType) { randomInferenceEntityId, taskType, ElasticsearchInternalService.NAME, - new CustomElandInternalServiceSettings(1, 4, "custom-model"), + new CustomElandInternalServiceSettings(1, 4, "custom-model", null), CustomElandRerankTaskSettings.DEFAULT_SETTINGS ); } else if (taskType == TaskType.TEXT_EMBEDDING) { - var serviceSettings = new CustomElandInternalTextEmbeddingServiceSettings(1, 4, "custom-model"); + var serviceSettings = new CustomElandInternalTextEmbeddingServiceSettings(1, 4, "custom-model", null); expectedModel = new CustomElandEmbeddingModel( randomInferenceEntityId, @@ -786,7 +788,7 @@ public void testPutModel() { "my-e5", TaskType.TEXT_EMBEDDING, "e5", - new MultilingualE5SmallInternalServiceSettings(1, 1, ".multilingual-e5-small") + new MultilingualE5SmallInternalServiceSettings(1, 1, ".multilingual-e5-small", null) ); service.putModel(model, new ActionListener<>() { @@ -827,6 +829,7 @@ public void testParseRequestConfigEland_SetsDimensionsToOne() { 1, 4, "custom-model", + null, 1, SimilarityMeasure.COSINE, DenseVectorFieldMapper.ElementType.FLOAT @@ -850,6 +853,7 @@ public void testParseRequestConfigEland_SetsDimensionsToOne() { 4, "custom-model", null, + null, SimilarityMeasure.COSINE, DenseVectorFieldMapper.ElementType.FLOAT ) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/MultilingualE5SmallInternalServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/MultilingualE5SmallInternalServiceSettingsTests.java index fbff04efe6883..927d53360a2c5 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/MultilingualE5SmallInternalServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elasticsearch/MultilingualE5SmallInternalServiceSettingsTests.java @@ -24,7 +24,8 @@ public static MultilingualE5SmallInternalServiceSettings createRandom() { return new MultilingualE5SmallInternalServiceSettings( randomIntBetween(1, 4), randomIntBetween(1, 4), - randomFrom(ElasticsearchInternalService.MULTILINGUAL_E5_SMALL_VALID_IDS) + randomFrom(ElasticsearchInternalService.MULTILINGUAL_E5_SMALL_VALID_IDS), + null ); } @@ -56,7 +57,7 @@ public void testFromMap() { ) ) ).build(); - assertEquals(new MultilingualE5SmallInternalServiceSettings(1, 4, randomModelVariant), serviceSettings); + assertEquals(new MultilingualE5SmallInternalServiceSettings(1, 4, randomModelVariant, null), serviceSettings); } public void testFromMapInvalidVersion() { @@ -130,12 +131,14 @@ protected MultilingualE5SmallInternalServiceSettings mutateInstance(Multilingual case 0 -> new MultilingualE5SmallInternalServiceSettings( instance.getNumAllocations() + 1, instance.getNumThreads(), - instance.getModelId() + instance.getModelId(), + null ); case 1 -> new MultilingualE5SmallInternalServiceSettings( instance.getNumAllocations(), instance.getNumThreads() + 1, - instance.getModelId() + instance.getModelId(), + null ); case 2 -> { var versions = new HashSet<>(ElasticsearchInternalService.MULTILINGUAL_E5_SMALL_VALID_IDS); @@ -143,7 +146,8 @@ protected MultilingualE5SmallInternalServiceSettings mutateInstance(Multilingual yield new MultilingualE5SmallInternalServiceSettings( instance.getNumAllocations(), instance.getNumThreads(), - versions.iterator().next() + versions.iterator().next(), + null ); } default -> throw new IllegalStateException(); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalServiceSettingsTests.java index c0e425144a618..e7fbbffa2d3fe 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalServiceSettingsTests.java @@ -25,7 +25,8 @@ public static ElserInternalServiceSettings createRandom() { return new ElserInternalServiceSettings( randomIntBetween(1, 4), randomIntBetween(1, 2), - randomFrom(ElserInternalService.VALID_ELSER_MODEL_IDS) + randomFrom(ElserInternalService.VALID_ELSER_MODEL_IDS), + null ); } @@ -49,7 +50,7 @@ public void testFromMap() { ) ) ).build(); - assertEquals(new ElserInternalServiceSettings(1, 4, ".elser_model_1"), serviceSettings); + assertEquals(new ElserInternalServiceSettings(1, 4, ".elser_model_1", null), serviceSettings); } public void testFromMapInvalidVersion() { @@ -89,12 +90,12 @@ public void testFromMapMissingOptions() { public void testBwcWrite() throws IOException { { - var settings = new ElserInternalServiceSettings(1, 1, ".elser_model_1"); + var settings = new ElserInternalServiceSettings(1, 1, ".elser_model_1", null); var copy = copyInstance(settings, TransportVersions.V_8_12_0); assertEquals(settings, copy); } { - var settings = new ElserInternalServiceSettings(1, 1, ".elser_model_1"); + var settings = new ElserInternalServiceSettings(1, 1, ".elser_model_1", null); var copy = copyInstance(settings, TransportVersions.V_8_11_X); assertEquals(settings, copy); } @@ -123,12 +124,27 @@ protected ElserInternalServiceSettings createTestInstance() { @Override protected ElserInternalServiceSettings mutateInstance(ElserInternalServiceSettings instance) { return switch (randomIntBetween(0, 2)) { - case 0 -> new ElserInternalServiceSettings(instance.getNumAllocations() + 1, instance.getNumThreads(), instance.getModelId()); - case 1 -> new ElserInternalServiceSettings(instance.getNumAllocations(), instance.getNumThreads() + 1, instance.getModelId()); + case 0 -> new ElserInternalServiceSettings( + instance.getNumAllocations() + 1, + instance.getNumThreads(), + instance.getModelId(), + null + ); + case 1 -> new ElserInternalServiceSettings( + instance.getNumAllocations(), + instance.getNumThreads() + 1, + instance.getModelId(), + null + ); case 2 -> { var versions = new HashSet<>(ElserInternalService.VALID_ELSER_MODEL_IDS); versions.remove(instance.getModelId()); - yield new ElserInternalServiceSettings(instance.getNumAllocations(), instance.getNumThreads(), versions.iterator().next()); + yield new ElserInternalServiceSettings( + instance.getNumAllocations(), + instance.getNumThreads(), + versions.iterator().next(), + null + ); } default -> throw new IllegalStateException(); }; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalServiceTests.java index bc7dca4f11960..5ee55003e7fe1 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalServiceTests.java @@ -108,7 +108,7 @@ public void testParseConfigStrict() { "foo", TaskType.SPARSE_EMBEDDING, ElserInternalService.NAME, - new ElserInternalServiceSettings(1, 4, ".elser_model_1"), + new ElserInternalServiceSettings(1, 4, ".elser_model_1", null), ElserMlNodeTaskSettings.DEFAULT ); @@ -141,7 +141,7 @@ public void testParseConfigLooseWithOldModelId() { "foo", TaskType.SPARSE_EMBEDDING, ElserInternalService.NAME, - new ElserInternalServiceSettings(1, 4, ".elser_model_1"), + new ElserInternalServiceSettings(1, 4, ".elser_model_1", null), ElserMlNodeTaskSettings.DEFAULT ); @@ -171,7 +171,7 @@ public void testParseConfigStrictWithNoTaskSettings() { "foo", TaskType.SPARSE_EMBEDDING, ElserInternalService.NAME, - new ElserInternalServiceSettings(1, 4, ElserInternalService.ELSER_V2_MODEL), + new ElserInternalServiceSettings(1, 4, ElserInternalService.ELSER_V2_MODEL, null), ElserMlNodeTaskSettings.DEFAULT ); @@ -373,7 +373,7 @@ public void testChunkInfer() { "foo", TaskType.SPARSE_EMBEDDING, "elser", - new ElserInternalServiceSettings(1, 1, "elser"), + new ElserInternalServiceSettings(1, 1, "elser", null), new ElserMlNodeTaskSettings() ); var service = createService(client); @@ -437,7 +437,7 @@ public void testChunkInferSetsTokenization() { "foo", TaskType.SPARSE_EMBEDDING, "elser", - new ElserInternalServiceSettings(1, 1, "elser"), + new ElserInternalServiceSettings(1, 1, "elser", null), new ElserMlNodeTaskSettings() ); var service = createService(client); @@ -489,7 +489,7 @@ public void testPutModel() { "my-elser", TaskType.SPARSE_EMBEDDING, "elser", - new ElserInternalServiceSettings(1, 1, ".elser_model_2"), + new ElserInternalServiceSettings(1, 1, ".elser_model_2", null), ElserMlNodeTaskSettings.DEFAULT ); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingsServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingsServiceSettingsTests.java index 076986acdcee6..009a6dbdeb793 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingsServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingsServiceSettingsTests.java @@ -8,6 +8,7 @@ package org.elasticsearch.xpack.inference.services.mistral.embeddings; import org.elasticsearch.common.Strings; +import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.io.stream.ByteArrayStreamInput; import org.elasticsearch.common.io.stream.BytesStreamOutput; import org.elasticsearch.core.Nullable; @@ -27,6 +28,7 @@ import java.util.Map; import static org.elasticsearch.xpack.inference.services.ServiceFields.SIMILARITY; +import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.is; public class MistralEmbeddingsServiceSettingsTests extends ESTestCase { @@ -77,6 +79,84 @@ public void testFromMap_PersistentContext_DoesNotThrowException_WhenDimensionsIs assertThat(serviceSettings, is(new MistralEmbeddingsServiceSettings(model, null, null, null, null))); } + public void testFromMap_ThrowsException_WhenDimensionsAreZero() { + var model = "mistral-embed"; + var dimensions = 0; + + var settingsMap = createRequestSettingsMap(model, dimensions, null, SimilarityMeasure.COSINE); + + var thrownException = expectThrows( + ValidationException.class, + () -> MistralEmbeddingsServiceSettings.fromMap(settingsMap, ConfigurationParseContext.REQUEST) + ); + + assertThat( + thrownException.getMessage(), + containsString("Validation Failed: 1: [service_settings] Invalid value [0]. [dimensions] must be a positive integer;") + ); + } + + public void testFromMap_ThrowsException_WhenDimensionsAreNegative() { + var model = "mistral-embed"; + var dimensions = randomNegativeInt(); + + var settingsMap = createRequestSettingsMap(model, dimensions, null, SimilarityMeasure.COSINE); + + var thrownException = expectThrows( + ValidationException.class, + () -> MistralEmbeddingsServiceSettings.fromMap(settingsMap, ConfigurationParseContext.REQUEST) + ); + + assertThat( + thrownException.getMessage(), + containsString( + Strings.format( + "Validation Failed: 1: [service_settings] Invalid value [%d]. [dimensions] must be a positive integer;", + dimensions + ) + ) + ); + } + + public void testFromMap_ThrowsException_WhenMaxInputTokensAreZero() { + var model = "mistral-embed"; + var maxInputTokens = 0; + + var settingsMap = createRequestSettingsMap(model, null, maxInputTokens, SimilarityMeasure.COSINE); + + var thrownException = expectThrows( + ValidationException.class, + () -> MistralEmbeddingsServiceSettings.fromMap(settingsMap, ConfigurationParseContext.REQUEST) + ); + + assertThat( + thrownException.getMessage(), + containsString("Validation Failed: 1: [service_settings] Invalid value [0]. [max_input_tokens] must be a positive integer;") + ); + } + + public void testFromMap_ThrowsException_WhenMaxInputTokensAreNegative() { + var model = "mistral-embed"; + var maxInputTokens = randomNegativeInt(); + + var settingsMap = createRequestSettingsMap(model, null, maxInputTokens, SimilarityMeasure.COSINE); + + var thrownException = expectThrows( + ValidationException.class, + () -> MistralEmbeddingsServiceSettings.fromMap(settingsMap, ConfigurationParseContext.REQUEST) + ); + + assertThat( + thrownException.getMessage(), + containsString( + Strings.format( + "Validation Failed: 1: [service_settings] Invalid value [%d]. [max_input_tokens] must be a positive integer;", + maxInputTokens + ) + ) + ); + } + public void testFromMap_PersistentContext_DoesNotThrowException_WhenSimilarityIsPresent() { var model = "mistral-embed"; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java index 9e35180547bf2..9ff175ca9685e 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java @@ -342,7 +342,7 @@ public void testParseRequestConfig_MovesModel() throws IOException { public void testParsePersistedConfigWithSecrets_CreatesAnOpenAiEmbeddingsModel() throws IOException { try (var service = createOpenAiService()) { var persistedConfig = getPersistedConfigMap( - getServiceSettingsMap("model", "url", "org", 100, false), + getServiceSettingsMap("model", "url", "org", 100, null, false), getTaskSettingsMap("user"), getSecretSettingsMap("secret") ); @@ -393,7 +393,7 @@ public void testParsePersistedConfigWithSecrets_ThrowsErrorTryingToParseInvalidM public void testParsePersistedConfigWithSecrets_CreatesAnOpenAiEmbeddingsModelWithoutUserUrlOrganization() throws IOException { try (var service = createOpenAiService()) { var persistedConfig = getPersistedConfigMap( - getServiceSettingsMap("model", null, null, null, true), + getServiceSettingsMap("model", null, null, null, null, true), getTaskSettingsMap(null), getSecretSettingsMap("secret") ); @@ -419,7 +419,7 @@ public void testParsePersistedConfigWithSecrets_CreatesAnOpenAiEmbeddingsModelWi public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInConfig() throws IOException { try (var service = createOpenAiService()) { var persistedConfig = getPersistedConfigMap( - getServiceSettingsMap("model", "url", "org", null, true), + getServiceSettingsMap("model", "url", "org", null, null, true), getTaskSettingsMap("user"), getSecretSettingsMap("secret") ); @@ -450,7 +450,7 @@ public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExists secretSettingsMap.put("extra_key", "value"); var persistedConfig = getPersistedConfigMap( - getServiceSettingsMap("model", "url", "org", null, true), + getServiceSettingsMap("model", "url", "org", null, null, true), getTaskSettingsMap("user"), secretSettingsMap ); @@ -476,7 +476,7 @@ public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExists public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInSecrets() throws IOException { try (var service = createOpenAiService()) { var persistedConfig = getPersistedConfigMap( - getServiceSettingsMap("model", "url", "org", null, true), + getServiceSettingsMap("model", "url", "org", null, null, true), getTaskSettingsMap("user"), getSecretSettingsMap("secret") ); @@ -503,7 +503,7 @@ public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInSe public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInServiceSettings() throws IOException { try (var service = createOpenAiService()) { - var serviceSettingsMap = getServiceSettingsMap("model", "url", "org", null, true); + var serviceSettingsMap = getServiceSettingsMap("model", "url", "org", null, null, true); serviceSettingsMap.put("extra_key", "value"); var persistedConfig = getPersistedConfigMap(serviceSettingsMap, getTaskSettingsMap("user"), getSecretSettingsMap("secret")); @@ -532,7 +532,7 @@ public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInTa taskSettingsMap.put("extra_key", "value"); var persistedConfig = getPersistedConfigMap( - getServiceSettingsMap("model", "url", "org", null, true), + getServiceSettingsMap("model", "url", "org", null, null, true), taskSettingsMap, getSecretSettingsMap("secret") ); @@ -558,7 +558,7 @@ public void testParsePersistedConfigWithSecrets_NotThrowWhenAnExtraKeyExistsInTa public void testParsePersistedConfig_CreatesAnOpenAiEmbeddingsModel() throws IOException { try (var service = createOpenAiService()) { var persistedConfig = getPersistedConfigMap( - getServiceSettingsMap("model", "url", "org", null, true), + getServiceSettingsMap("model", "url", "org", null, null, true), getTaskSettingsMap("user") ); @@ -593,7 +593,10 @@ public void testParsePersistedConfig_ThrowsErrorTryingToParseInvalidModel() thro public void testParsePersistedConfig_CreatesAnOpenAiEmbeddingsModelWithoutUserUrlOrganization() throws IOException { try (var service = createOpenAiService()) { - var persistedConfig = getPersistedConfigMap(getServiceSettingsMap("model", null, null, null, true), getTaskSettingsMap(null)); + var persistedConfig = getPersistedConfigMap( + getServiceSettingsMap("model", null, null, null, null, true), + getTaskSettingsMap(null) + ); var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); @@ -611,7 +614,7 @@ public void testParsePersistedConfig_CreatesAnOpenAiEmbeddingsModelWithoutUserUr public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInConfig() throws IOException { try (var service = createOpenAiService()) { var persistedConfig = getPersistedConfigMap( - getServiceSettingsMap("model", "url", "org", null, true), + getServiceSettingsMap("model", "url", "org", null, null, true), getTaskSettingsMap("user") ); persistedConfig.config().put("extra_key", "value"); @@ -631,7 +634,7 @@ public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInConfig() public void testParsePersistedConfig_NotThrowWhenAnExtraKeyExistsInServiceSettings() throws IOException { try (var service = createOpenAiService()) { - var serviceSettingsMap = getServiceSettingsMap("model", "url", "org", null, true); + var serviceSettingsMap = getServiceSettingsMap("model", "url", "org", null, null, true); serviceSettingsMap.put("extra_key", "value"); var persistedConfig = getPersistedConfigMap(serviceSettingsMap, getTaskSettingsMap("user")); @@ -654,7 +657,7 @@ public void testParsePersistedConfig_NotThrowWhenAnExtraKeyExistsInTaskSettings( var taskSettingsMap = getTaskSettingsMap("user"); taskSettingsMap.put("extra_key", "value"); - var persistedConfig = getPersistedConfigMap(getServiceSettingsMap("model", "url", "org", null, true), taskSettingsMap); + var persistedConfig = getPersistedConfigMap(getServiceSettingsMap("model", "url", "org", null, null, true), taskSettingsMap); var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, persistedConfig.config()); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsServiceSettingsTests.java index cc0004a2d678c..10ccbb4eb39f6 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsServiceSettingsTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/embeddings/OpenAiEmbeddingsServiceSettingsTests.java @@ -257,6 +257,92 @@ public void testFromMap_PersistentContext_DoesNotThrowException_WhenDimensionsIs assertThat(settings, is(new OpenAiEmbeddingsServiceSettings("m", (URI) null, null, null, null, null, true, null))); } + public void testFromMap_ThrowsException_WhenDimensionsAreZero() { + var modelId = "model-foo"; + var url = "https://www.abc.com"; + var org = "organization"; + var dimensions = 0; + + var settingsMap = getServiceSettingsMap(modelId, url, org, dimensions, null, null); + + var thrownException = expectThrows( + ValidationException.class, + () -> OpenAiEmbeddingsServiceSettings.fromMap(settingsMap, ConfigurationParseContext.REQUEST) + ); + + assertThat( + thrownException.getMessage(), + containsString("Validation Failed: 1: [service_settings] Invalid value [0]. [dimensions] must be a positive integer;") + ); + } + + public void testFromMap_ThrowsException_WhenDimensionsAreNegative() { + var modelId = "model-foo"; + var url = "https://www.abc.com"; + var org = "organization"; + var dimensions = randomNegativeInt(); + + var settingsMap = getServiceSettingsMap(modelId, url, org, dimensions, null, null); + + var thrownException = expectThrows( + ValidationException.class, + () -> OpenAiEmbeddingsServiceSettings.fromMap(settingsMap, ConfigurationParseContext.REQUEST) + ); + + assertThat( + thrownException.getMessage(), + containsString( + Strings.format( + "Validation Failed: 1: [service_settings] Invalid value [%d]. [dimensions] must be a positive integer;", + dimensions + ) + ) + ); + } + + public void testFromMap_ThrowsException_WhenMaxInputTokensAreZero() { + var modelId = "model-foo"; + var url = "https://www.abc.com"; + var org = "organization"; + var maxInputTokens = 0; + + var settingsMap = getServiceSettingsMap(modelId, url, org, null, maxInputTokens, null); + + var thrownException = expectThrows( + ValidationException.class, + () -> OpenAiEmbeddingsServiceSettings.fromMap(settingsMap, ConfigurationParseContext.REQUEST) + ); + + assertThat( + thrownException.getMessage(), + containsString("Validation Failed: 1: [service_settings] Invalid value [0]. [max_input_tokens] must be a positive integer;") + ); + } + + public void testFromMap_ThrowsException_WhenMaxInputTokensAreNegative() { + var modelId = "model-foo"; + var url = "https://www.abc.com"; + var org = "organization"; + var maxInputTokens = randomNegativeInt(); + + var settingsMap = getServiceSettingsMap(modelId, url, org, null, maxInputTokens, null); + + var thrownException = expectThrows( + ValidationException.class, + () -> OpenAiEmbeddingsServiceSettings.fromMap(settingsMap, ConfigurationParseContext.REQUEST) + ); + + assertThat( + thrownException.getMessage(), + containsString( + Strings.format( + "Validation Failed: 1: [service_settings] Invalid value [%d]. [max_input_tokens] must be a positive integer;", + maxInputTokens + ) + ) + ); + } + public void testFromMap_PersistentContext_DoesNotThrowException_WhenDimensionsSetByUserIsNull() { OpenAiEmbeddingsServiceSettings.fromMap( new HashMap<>(Map.of(ServiceFields.DIMENSIONS, 1, ServiceFields.MODEL_ID, "m")), @@ -464,6 +550,7 @@ public static Map getServiceSettingsMap( @Nullable String url, @Nullable String org, @Nullable Integer dimensions, + @Nullable Integer maxInputTokens, @Nullable Boolean dimensionsSetByUser ) { var map = new HashMap(); @@ -481,6 +568,10 @@ public static Map getServiceSettingsMap( map.put(ServiceFields.DIMENSIONS, dimensions); } + if (maxInputTokens != null) { + map.put(ServiceFields.MAX_INPUT_TOKENS, maxInputTokens); + } + if (dimensionsSetByUser != null) { map.put(OpenAiEmbeddingsServiceSettings.DIMENSIONS_SET_BY_USER, dimensionsSetByUser); } diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/50_semantic_text_query_inference_endpoint_changes.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/50_semantic_text_query_inference_endpoint_changes.yml index fd656c9d5d950..f6a7073914609 100644 --- a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/50_semantic_text_query_inference_endpoint_changes.yml +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/50_semantic_text_query_inference_endpoint_changes.yml @@ -81,6 +81,7 @@ setup: - do: inference.delete: inference_id: sparse-inference-id + force: true - do: inference.put: @@ -119,6 +120,7 @@ setup: - do: inference.delete: inference_id: dense-inference-id + force: true - do: inference.put: @@ -155,6 +157,7 @@ setup: - do: inference.delete: inference_id: dense-inference-id + force: true - do: inference.put: diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/DetectionRulesIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/DetectionRulesIT.java index 8cb13398a70ae..fec85730aaf2b 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/DetectionRulesIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/DetectionRulesIT.java @@ -95,6 +95,9 @@ public void testCondition() throws Exception { closeJob(job.getId()); List records = getRecords(job.getId()); + // remove records that are not anomalies + records.removeIf(record -> record.getInitialRecordScore() < 1e-5); + assertThat(records.size(), equalTo(1)); assertThat(records.get(0).getByFieldValue(), equalTo("high")); long firstRecordTimestamp = records.get(0).getTimestamp().getTime(); diff --git a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/MlInitializationServiceIT.java b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/MlInitializationServiceIT.java index 30f84a97bcfb0..1d67639f712a0 100644 --- a/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/MlInitializationServiceIT.java +++ b/x-pack/plugin/ml/qa/native-multi-node-tests/src/javaRestTest/java/org/elasticsearch/xpack/ml/integration/MlInitializationServiceIT.java @@ -21,6 +21,7 @@ import org.elasticsearch.xpack.ml.MachineLearning; import org.elasticsearch.xpack.ml.MlDailyMaintenanceService; import org.elasticsearch.xpack.ml.MlInitializationService; +import org.elasticsearch.xpack.ml.inference.adaptiveallocations.AdaptiveAllocationsScalerService; import org.junit.Before; import java.util.List; @@ -47,7 +48,14 @@ public void setUpMocks() { when(threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME)).thenReturn(EsExecutors.DIRECT_EXECUTOR_SERVICE); MlDailyMaintenanceService mlDailyMaintenanceService = mock(MlDailyMaintenanceService.class); ClusterService clusterService = mock(ClusterService.class); - mlInitializationService = new MlInitializationService(client(), threadPool, mlDailyMaintenanceService, clusterService); + AdaptiveAllocationsScalerService adaptiveAllocationsScalerService = mock(AdaptiveAllocationsScalerService.class); + mlInitializationService = new MlInitializationService( + client(), + threadPool, + mlDailyMaintenanceService, + adaptiveAllocationsScalerService, + clusterService + ); } public void testThatMlIndicesBecomeHiddenWhenTheNodeBecomesMaster() throws Exception { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java index 6fdc4e73e184f..22a9c2dbcc281 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MachineLearning.java @@ -1282,6 +1282,7 @@ public Collection createComponents(PluginServices services) { threadPool, clusterService, client, + inferenceAuditor, mlAssignmentNotifier, machineLearningExtension.get().isAnomalyDetectionEnabled(), machineLearningExtension.get().isDataFrameAnalyticsEnabled(), diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MlInitializationService.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MlInitializationService.java index a2d8fd1d60316..a1664b7023fc0 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MlInitializationService.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/MlInitializationService.java @@ -32,6 +32,9 @@ import org.elasticsearch.gateway.GatewayService; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.core.ml.annotations.AnnotationIndex; +import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsFeatureFlag; +import org.elasticsearch.xpack.ml.inference.adaptiveallocations.AdaptiveAllocationsScalerService; +import org.elasticsearch.xpack.ml.notifications.InferenceAuditor; import java.util.Collections; import java.util.Map; @@ -55,6 +58,8 @@ public final class MlInitializationService implements ClusterStateListener { private final MlDailyMaintenanceService mlDailyMaintenanceService; + private final AdaptiveAllocationsScalerService adaptiveAllocationsScalerService; + private boolean isMaster = false; MlInitializationService( @@ -62,6 +67,7 @@ public final class MlInitializationService implements ClusterStateListener { ThreadPool threadPool, ClusterService clusterService, Client client, + InferenceAuditor inferenceAuditor, MlAssignmentNotifier mlAssignmentNotifier, boolean isAnomalyDetectionEnabled, boolean isDataFrameAnalyticsEnabled, @@ -81,6 +87,7 @@ public final class MlInitializationService implements ClusterStateListener { isDataFrameAnalyticsEnabled, isNlpEnabled ), + new AdaptiveAllocationsScalerService(threadPool, clusterService, client, inferenceAuditor, isNlpEnabled), clusterService ); } @@ -90,11 +97,13 @@ public MlInitializationService( Client client, ThreadPool threadPool, MlDailyMaintenanceService dailyMaintenanceService, + AdaptiveAllocationsScalerService adaptiveAllocationsScalerService, ClusterService clusterService ) { this.client = Objects.requireNonNull(client); this.threadPool = threadPool; this.mlDailyMaintenanceService = dailyMaintenanceService; + this.adaptiveAllocationsScalerService = adaptiveAllocationsScalerService; clusterService.addListener(this); clusterService.addLifecycleListener(new LifecycleListener() { @Override @@ -115,11 +124,17 @@ public void beforeStop() { public void onMaster() { mlDailyMaintenanceService.start(); + if (AdaptiveAllocationsFeatureFlag.isEnabled()) { + adaptiveAllocationsScalerService.start(); + } threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME).execute(this::makeMlInternalIndicesHidden); } public void offMaster() { mlDailyMaintenanceService.stop(); + if (AdaptiveAllocationsFeatureFlag.isEnabled()) { + adaptiveAllocationsScalerService.stop(); + } } @Override diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportCreateTrainedModelAssignmentAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportCreateTrainedModelAssignmentAction.java index 348cb396f9c9f..30371fcbe115a 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportCreateTrainedModelAssignmentAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportCreateTrainedModelAssignmentAction.java @@ -75,7 +75,7 @@ public TransportCreateTrainedModelAssignmentAction( @Override protected void masterOperation(Task task, Request request, ClusterState state, ActionListener listener) throws Exception { trainedModelAssignmentClusterService.createNewModelAssignment( - request.getTaskParams(), + request, listener.delegateFailureAndWrap((l, trainedModelAssignment) -> l.onResponse(new Response(trainedModelAssignment))) ); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetDeploymentStatsAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetDeploymentStatsAction.java index 04b597292dad6..590aeded2b674 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetDeploymentStatsAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportGetDeploymentStatsAction.java @@ -238,6 +238,7 @@ static GetDeploymentStatsAction.Response addFailedRoutes( stat.getModelId(), stat.getThreadsPerAllocation(), stat.getNumberOfAllocations(), + stat.getAdaptiveAllocationsSettings(), stat.getQueueCapacity(), stat.getCacheSize(), stat.getStartTime(), @@ -277,6 +278,7 @@ static GetDeploymentStatsAction.Response addFailedRoutes( assignment.getModelId(), assignment.getTaskParams().getThreadsPerAllocation(), assignment.getTaskParams().getNumberOfAllocations(), + assignment.getAdaptiveAllocationsSettings(), assignment.getTaskParams().getQueueCapacity(), assignment.getTaskParams().getCacheSize().orElse(null), assignment.getStartTime(), @@ -346,6 +348,7 @@ protected void taskOperation( task.getParams().getModelId(), task.getParams().getThreadsPerAllocation(), assignment == null ? task.getParams().getNumberOfAllocations() : assignment.getTaskParams().getNumberOfAllocations(), + assignment == null ? null : assignment.getAdaptiveAllocationsSettings(), task.getParams().getQueueCapacity(), task.getParams().getCacheSize().orElse(null), TrainedModelAssignmentMetadata.fromState(clusterService.state()) diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartTrainedModelDeploymentAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartTrainedModelDeploymentAction.java index de93a41fb7296..ae0da7dc9cc69 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartTrainedModelDeploymentAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportStartTrainedModelDeploymentAction.java @@ -207,7 +207,7 @@ protected void masterOperation( modelIdAndSizeInBytes.v1(), request.getDeploymentId(), modelIdAndSizeInBytes.v2(), - request.getNumberOfAllocations(), + request.computeNumberOfAllocations(), request.getThreadsPerAllocation(), request.getQueueCapacity(), Optional.ofNullable(request.getCacheSize()).orElse(ByteSizeValue.ofBytes(modelIdAndSizeInBytes.v2())), @@ -219,7 +219,10 @@ protected void masterOperation( memoryTracker.refresh( persistentTasks, ActionListener.wrap( - aVoid -> trainedModelAssignmentService.createNewModelAssignment(taskParams, waitForDeploymentToStart), + aVoid -> trainedModelAssignmentService.createNewModelAssignment( + new CreateTrainedModelAssignmentAction.Request(taskParams, request.getAdaptiveAllocationsSettings()), + waitForDeploymentToStart + ), listener::onFailure ) ); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportUpdateTrainedModelDeploymentAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportUpdateTrainedModelDeploymentAction.java index 7d4143d9e722a..fa38b30ae8b84 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportUpdateTrainedModelDeploymentAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/action/TransportUpdateTrainedModelDeploymentAction.java @@ -81,9 +81,11 @@ protected void masterOperation( ) ); - trainedModelAssignmentClusterService.updateNumberOfAllocations( + trainedModelAssignmentClusterService.updateDeployment( request.getDeploymentId(), request.getNumberOfAllocations(), + request.getAdaptiveAllocationsSettings(), + request.isInternal(), ActionListener.wrap(updatedAssignment -> { auditor.info( request.getDeploymentId(), diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/adaptiveallocations/AdaptiveAllocationsScaler.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/adaptiveallocations/AdaptiveAllocationsScaler.java new file mode 100644 index 0000000000000..b33e86d434f95 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/adaptiveallocations/AdaptiveAllocationsScaler.java @@ -0,0 +1,154 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.inference.adaptiveallocations; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.common.Strings; + +/** + * Processes measured requests counts and inference times and decides whether + * the number of allocations should be scaled up or down. + */ +public class AdaptiveAllocationsScaler { + + // visible for testing + static final double SCALE_UP_THRESHOLD = 0.9; + private static final double SCALE_DOWN_THRESHOLD = 0.85; + + private static final Logger logger = LogManager.getLogger(AdaptiveAllocationsScaler.class); + + private final String deploymentId; + private final KalmanFilter1d requestRateEstimator; + private final KalmanFilter1d inferenceTimeEstimator; + + private int numberOfAllocations; + private Integer minNumberOfAllocations; + private Integer maxNumberOfAllocations; + private boolean dynamicsChanged; + + AdaptiveAllocationsScaler(String deploymentId, int numberOfAllocations) { + this.deploymentId = deploymentId; + // A smoothing factor of 100 roughly means the last 100 measurements have an effect + // on the estimated values. The sampling time is 10 seconds, so approximately the + // last 15 minutes are taken into account. + // For the request rate, use auto-detection for dynamics changes, because the request + // rate maybe change due to changed user behaviour. + // For the inference time, don't use this auto-detection. The dynamics may change when + // the number of allocations changes, which is passed explicitly to the estimator. + requestRateEstimator = new KalmanFilter1d(deploymentId + ":rate", 100, true); + inferenceTimeEstimator = new KalmanFilter1d(deploymentId + ":time", 100, false); + this.numberOfAllocations = numberOfAllocations; + this.minNumberOfAllocations = null; + this.maxNumberOfAllocations = null; + this.dynamicsChanged = false; + } + + void setMinMaxNumberOfAllocations(Integer minNumberOfAllocations, Integer maxNumberOfAllocations) { + this.minNumberOfAllocations = minNumberOfAllocations; + this.maxNumberOfAllocations = maxNumberOfAllocations; + } + + void process(AdaptiveAllocationsScalerService.Stats stats, double timeIntervalSeconds, int numberOfAllocations) { + // The request rate (per second) is the request count divided by the time. + // Assuming a Poisson process for the requests, the variance in the request + // count equals the mean request count, and the variance in the request rate + // equals that variance divided by the time interval squared. + // The minimum request count is set to 1, because lower request counts can't + // be reliably measured. + // The estimated request rate should be used for the variance calculations, + // because the measured request rate gives biased estimates. + double requestRate = (double) stats.requestCount() / timeIntervalSeconds; + double requestRateEstimate = requestRateEstimator.hasValue() ? requestRateEstimator.estimate() : requestRate; + double requestRateVariance = Math.max(1.0, requestRateEstimate * timeIntervalSeconds) / Math.pow(timeIntervalSeconds, 2); + requestRateEstimator.add(requestRate, requestRateVariance, false); + + if (stats.requestCount() > 0 && Double.isNaN(stats.inferenceTime()) == false) { + // The inference time distribution is unknown. For simplicity, we assume + // a std.error equal to the mean, so that the variance equals the mean + // value squared. The variance of the mean is inversely proportional to + // the number of inference measurements it contains. + // Again, the estimated inference time should be used for the variance + // calculations to prevent biased estimates. + double inferenceTime = stats.inferenceTime(); + double inferenceTimeEstimate = inferenceTimeEstimator.hasValue() ? inferenceTimeEstimator.estimate() : inferenceTime; + double inferenceTimeVariance = Math.pow(inferenceTimeEstimate, 2) / stats.requestCount(); + inferenceTimeEstimator.add(inferenceTime, inferenceTimeVariance, dynamicsChanged); + } + + this.numberOfAllocations = numberOfAllocations; + dynamicsChanged = false; + } + + double getLoadLower() { + double requestRateLower = Math.max(0.0, requestRateEstimator.lower()); + double inferenceTimeLower = Math.max(0.0, inferenceTimeEstimator.hasValue() ? inferenceTimeEstimator.lower() : 1.0); + return requestRateLower * inferenceTimeLower; + } + + double getLoadUpper() { + double requestRateUpper = requestRateEstimator.upper(); + double inferenceTimeUpper = inferenceTimeEstimator.hasValue() ? inferenceTimeEstimator.upper() : 1.0; + return requestRateUpper * inferenceTimeUpper; + } + + Integer scale() { + if (requestRateEstimator.hasValue() == false) { + return null; + } + + int oldNumberOfAllocations = numberOfAllocations; + + double loadLower = getLoadLower(); + while (loadLower / numberOfAllocations > SCALE_UP_THRESHOLD) { + numberOfAllocations++; + } + + double loadUpper = getLoadUpper(); + while (numberOfAllocations > 1 && loadUpper / (numberOfAllocations - 1) < SCALE_DOWN_THRESHOLD) { + numberOfAllocations--; + } + + if (minNumberOfAllocations != null) { + numberOfAllocations = Math.max(numberOfAllocations, minNumberOfAllocations); + } + if (maxNumberOfAllocations != null) { + numberOfAllocations = Math.min(numberOfAllocations, maxNumberOfAllocations); + } + + if (numberOfAllocations != oldNumberOfAllocations) { + logger.debug( + () -> Strings.format( + "[%s] adaptive allocations scaler: load in [%.3f, %.3f], scaling from %d to %d allocations.", + deploymentId, + loadLower, + loadUpper, + oldNumberOfAllocations, + numberOfAllocations + ) + ); + } else { + logger.debug( + () -> Strings.format( + "[%s] adaptive allocations scaler: load in [%.3f, %.3f], keeping %d allocations.", + deploymentId, + loadLower, + loadUpper, + numberOfAllocations + ) + ); + } + + if (numberOfAllocations != oldNumberOfAllocations) { + this.dynamicsChanged = true; + return numberOfAllocations; + } else { + return null; + } + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/adaptiveallocations/AdaptiveAllocationsScalerService.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/adaptiveallocations/AdaptiveAllocationsScalerService.java new file mode 100644 index 0000000000000..30e3871ad5ad0 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/adaptiveallocations/AdaptiveAllocationsScalerService.java @@ -0,0 +1,340 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.inference.adaptiveallocations; + +import org.apache.logging.log4j.Level; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.client.internal.Client; +import org.elasticsearch.cluster.ClusterChangedEvent; +import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.ClusterStateListener; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.util.concurrent.EsRejectedExecutionException; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.threadpool.Scheduler; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.core.ClientHelper; +import org.elasticsearch.xpack.core.ml.action.GetDeploymentStatsAction; +import org.elasticsearch.xpack.core.ml.action.UpdateTrainedModelDeploymentAction; +import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentStats; +import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment; +import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignmentMetadata; +import org.elasticsearch.xpack.ml.MachineLearning; +import org.elasticsearch.xpack.ml.notifications.InferenceAuditor; + +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.atomic.AtomicBoolean; + +/** + * Periodically schedules adaptive allocations scaling. This process consists + * of calling the trained model stats API, processing the results, determining + * whether scaling should be applied, and potentially calling the trained + * model update API. + */ +public class AdaptiveAllocationsScalerService implements ClusterStateListener { + + record Stats(long successCount, long pendingCount, long failedCount, double inferenceTime) { + + long requestCount() { + return successCount + pendingCount + failedCount; + } + + double totalInferenceTime() { + return successCount * inferenceTime; + } + + Stats add(Stats value) { + long newSuccessCount = successCount + value.successCount; + long newPendingCount = pendingCount + value.pendingCount; + long newFailedCount = failedCount + value.failedCount; + double newInferenceTime = newSuccessCount > 0 + ? (totalInferenceTime() + value.totalInferenceTime()) / newSuccessCount + : Double.NaN; + return new Stats(newSuccessCount, newPendingCount, newFailedCount, newInferenceTime); + } + + Stats sub(Stats value) { + long newSuccessCount = Math.max(0, successCount - value.successCount); + long newPendingCount = Math.max(0, pendingCount - value.pendingCount); + long newFailedCount = Math.max(0, failedCount - value.failedCount); + double newInferenceTime = newSuccessCount > 0 + ? (totalInferenceTime() - value.totalInferenceTime()) / newSuccessCount + : Double.NaN; + return new Stats(newSuccessCount, newPendingCount, newFailedCount, newInferenceTime); + } + } + + /** + * The time interval between the adaptive allocations triggers. + */ + private static final int DEFAULT_TIME_INTERVAL_SECONDS = 10; + /** + * The time that has to pass after scaling up, before scaling down is allowed. + * Note that the ML autoscaling has its own cooldown time to release the hardware. + */ + private static final long SCALE_UP_COOLDOWN_TIME_MILLIS = TimeValue.timeValueMinutes(5).getMillis(); + + private static final Logger logger = LogManager.getLogger(AdaptiveAllocationsScalerService.class); + + private final int timeIntervalSeconds; + private final ThreadPool threadPool; + private final ClusterService clusterService; + private final Client client; + private final InferenceAuditor inferenceAuditor; + private final boolean isNlpEnabled; + private final Map> lastInferenceStatsByDeploymentAndNode; + private Long lastInferenceStatsTimestampMillis; + private final Map scalers; + private final Map lastScaleUpTimesMillis; + + private volatile Scheduler.Cancellable cancellable; + private final AtomicBoolean busy; + + public AdaptiveAllocationsScalerService( + ThreadPool threadPool, + ClusterService clusterService, + Client client, + InferenceAuditor inferenceAuditor, + boolean isNlpEnabled + ) { + this(threadPool, clusterService, client, inferenceAuditor, isNlpEnabled, DEFAULT_TIME_INTERVAL_SECONDS); + } + + // visible for testing + AdaptiveAllocationsScalerService( + ThreadPool threadPool, + ClusterService clusterService, + Client client, + InferenceAuditor inferenceAuditor, + boolean isNlpEnabled, + int timeIntervalSeconds + ) { + this.threadPool = threadPool; + this.clusterService = clusterService; + this.client = client; + this.inferenceAuditor = inferenceAuditor; + this.isNlpEnabled = isNlpEnabled; + this.timeIntervalSeconds = timeIntervalSeconds; + + lastInferenceStatsByDeploymentAndNode = new HashMap<>(); + lastInferenceStatsTimestampMillis = null; + lastScaleUpTimesMillis = new HashMap<>(); + scalers = new HashMap<>(); + busy = new AtomicBoolean(false); + } + + public synchronized void start() { + updateAutoscalers(clusterService.state()); + clusterService.addListener(this); + if (scalers.isEmpty() == false) { + startScheduling(); + } + } + + public synchronized void stop() { + stopScheduling(); + } + + @Override + public void clusterChanged(ClusterChangedEvent event) { + updateAutoscalers(event.state()); + if (scalers.isEmpty() == false) { + startScheduling(); + } else { + stopScheduling(); + } + } + + private synchronized void updateAutoscalers(ClusterState state) { + if (isNlpEnabled == false) { + return; + } + Set deploymentIds = new HashSet<>(); + TrainedModelAssignmentMetadata assignments = TrainedModelAssignmentMetadata.fromState(state); + for (TrainedModelAssignment assignment : assignments.allAssignments().values()) { + deploymentIds.add(assignment.getDeploymentId()); + if (assignment.getAdaptiveAllocationsSettings() != null && assignment.getAdaptiveAllocationsSettings().getEnabled()) { + AdaptiveAllocationsScaler adaptiveAllocationsScaler = scalers.computeIfAbsent( + assignment.getDeploymentId(), + key -> new AdaptiveAllocationsScaler(assignment.getDeploymentId(), assignment.totalTargetAllocations()) + ); + adaptiveAllocationsScaler.setMinMaxNumberOfAllocations( + assignment.getAdaptiveAllocationsSettings().getMinNumberOfAllocations(), + assignment.getAdaptiveAllocationsSettings().getMaxNumberOfAllocations() + ); + } else { + scalers.remove(assignment.getDeploymentId()); + lastInferenceStatsByDeploymentAndNode.remove(assignment.getDeploymentId()); + } + } + scalers.keySet().removeIf(key -> deploymentIds.contains(key) == false); + } + + private synchronized void startScheduling() { + if (cancellable == null) { + logger.debug("Starting ML adaptive allocations scaler"); + try { + cancellable = threadPool.scheduleWithFixedDelay( + this::trigger, + TimeValue.timeValueSeconds(timeIntervalSeconds), + threadPool.generic() + ); + } catch (EsRejectedExecutionException e) { + if (e.isExecutorShutdown() == false) { + throw e; + } + } + } + } + + private synchronized void stopScheduling() { + if (cancellable != null && cancellable.isCancelled() == false) { + logger.debug("Stopping ML adaptive allocations scaler"); + cancellable.cancel(); + cancellable = null; + } + } + + private void trigger() { + if (busy.getAndSet(true)) { + logger.debug("Skipping inference adaptive allocations scaling, because it's still busy."); + return; + } + ActionListener listener = ActionListener.runAfter( + ActionListener.wrap(this::processDeploymentStats, e -> logger.warn("Error in inference adaptive allocations scaling", e)), + () -> busy.set(false) + ); + getDeploymentStats(listener); + } + + private void getDeploymentStats(ActionListener processDeploymentStats) { + String deploymentIds = String.join(",", scalers.keySet()); + ClientHelper.executeAsyncWithOrigin( + client, + ClientHelper.ML_ORIGIN, + GetDeploymentStatsAction.INSTANCE, + // TODO(dave/jan): create a lightweight version of this request, because the current one + // collects too much data for the adaptive allocations scaler. + new GetDeploymentStatsAction.Request(deploymentIds), + processDeploymentStats + ); + } + + private void processDeploymentStats(GetDeploymentStatsAction.Response statsResponse) { + Double statsTimeInterval; + long now = System.currentTimeMillis(); + if (lastInferenceStatsTimestampMillis != null) { + statsTimeInterval = (now - lastInferenceStatsTimestampMillis) / 1000.0; + } else { + statsTimeInterval = null; + } + lastInferenceStatsTimestampMillis = now; + + Map recentStatsByDeployment = new HashMap<>(); + Map numberOfAllocations = new HashMap<>(); + + for (AssignmentStats assignmentStats : statsResponse.getStats().results()) { + String deploymentId = assignmentStats.getDeploymentId(); + numberOfAllocations.put(deploymentId, assignmentStats.getNumberOfAllocations()); + Map deploymentStats = lastInferenceStatsByDeploymentAndNode.computeIfAbsent( + deploymentId, + key -> new HashMap<>() + ); + for (AssignmentStats.NodeStats nodeStats : assignmentStats.getNodeStats()) { + String nodeId = nodeStats.getNode().getId(); + Stats lastStats = deploymentStats.get(nodeId); + Stats nextStats = new Stats( + nodeStats.getInferenceCount().orElse(0L), + nodeStats.getPendingCount() == null ? 0 : nodeStats.getPendingCount(), + nodeStats.getErrorCount() + nodeStats.getTimeoutCount() + nodeStats.getRejectedExecutionCount(), + nodeStats.getAvgInferenceTime().orElse(0.0) / 1000.0 + ); + deploymentStats.put(nodeId, nextStats); + if (lastStats != null) { + Stats recentStats = nextStats.sub(lastStats); + recentStatsByDeployment.compute( + assignmentStats.getDeploymentId(), + (key, value) -> value == null ? recentStats : value.add(recentStats) + ); + } + } + } + + if (statsTimeInterval == null) { + return; + } + + for (Map.Entry deploymentAndStats : recentStatsByDeployment.entrySet()) { + String deploymentId = deploymentAndStats.getKey(); + Stats stats = deploymentAndStats.getValue(); + AdaptiveAllocationsScaler adaptiveAllocationsScaler = scalers.get(deploymentId); + adaptiveAllocationsScaler.process(stats, statsTimeInterval, numberOfAllocations.get(deploymentId)); + Integer newNumberOfAllocations = adaptiveAllocationsScaler.scale(); + if (newNumberOfAllocations != null) { + Long lastScaleUpTimeMillis = lastScaleUpTimesMillis.get(deploymentId); + if (newNumberOfAllocations < numberOfAllocations.get(deploymentId) + && lastScaleUpTimeMillis != null + && now < lastScaleUpTimeMillis + SCALE_UP_COOLDOWN_TIME_MILLIS) { + logger.debug("adaptive allocations scaler: skipping scaling down [{}] because of recent scaleup.", deploymentId); + continue; + } + if (newNumberOfAllocations > numberOfAllocations.get(deploymentId)) { + lastScaleUpTimesMillis.put(deploymentId, now); + } + UpdateTrainedModelDeploymentAction.Request updateRequest = new UpdateTrainedModelDeploymentAction.Request(deploymentId); + updateRequest.setNumberOfAllocations(newNumberOfAllocations); + updateRequest.setIsInternal(true); + ClientHelper.executeAsyncWithOrigin( + client, + ClientHelper.ML_ORIGIN, + UpdateTrainedModelDeploymentAction.INSTANCE, + updateRequest, + ActionListener.wrap(updateResponse -> { + logger.info("adaptive allocations scaler: scaled [{}] to [{}] allocations.", deploymentId, newNumberOfAllocations); + threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME) + .execute( + () -> inferenceAuditor.info( + deploymentId, + Strings.format( + "adaptive allocations scaler: scaled [%s] to [%s] allocations.", + deploymentId, + newNumberOfAllocations + ) + ) + ); + }, e -> { + logger.atLevel(Level.WARN) + .withThrowable(e) + .log( + "adaptive allocations scaler: scaling [{}] to [{}] allocations failed.", + deploymentId, + newNumberOfAllocations + ); + threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME) + .execute( + () -> inferenceAuditor.warning( + deploymentId, + Strings.format( + "adaptive allocations scaler: scaling [%s] to [%s] allocations failed.", + deploymentId, + newNumberOfAllocations + ) + ) + ); + }) + ); + } + } + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/adaptiveallocations/KalmanFilter1d.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/adaptiveallocations/KalmanFilter1d.java new file mode 100644 index 0000000000000..ad3e66fc3e8e2 --- /dev/null +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/adaptiveallocations/KalmanFilter1d.java @@ -0,0 +1,121 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.inference.adaptiveallocations; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.common.Strings; + +/** + * Estimator for the mean value and stderr of a series of measurements. + *
+ * This implements a 1d Kalman filter with manoeuvre detection. Rather than a derived + * dynamics model we simply fix how much we want to smooth in the steady state. + * See also: Wikipedia. + */ +class KalmanFilter1d { + + private static final Logger logger = LogManager.getLogger(KalmanFilter1d.class); + + private final String name; + private final double smoothingFactor; + private final boolean autodetectDynamicsChange; + + private double value; + private double variance; + private boolean dynamicsChangedLastTime; + + KalmanFilter1d(String name, double smoothingFactor, boolean autodetectDynamicsChange) { + this.name = name; + this.smoothingFactor = smoothingFactor; + this.autodetectDynamicsChange = autodetectDynamicsChange; + this.value = Double.MAX_VALUE; + this.variance = Double.MAX_VALUE; + this.dynamicsChangedLastTime = false; + } + + /** + * Adds a measurement (value, variance) to the estimator. + * dynamicChangedExternal indicates whether the underlying possibly changed before this measurement. + */ + void add(double value, double variance, boolean dynamicChangedExternal) { + boolean dynamicChanged; + if (hasValue() == false) { + dynamicChanged = true; + this.value = value; + this.variance = variance; + } else { + double processVariance = variance / smoothingFactor; + dynamicChanged = dynamicChangedExternal || detectDynamicsChange(value, variance); + if (dynamicChanged || dynamicsChangedLastTime) { + // If we know we likely had a change in the quantity we're estimating or the prediction + // is 10 stddev off, we inject extra noise in the dynamics for this step. + processVariance = Math.pow(value, 2); + } + + double gain = (this.variance + processVariance) / (this.variance + processVariance + variance); + this.value += gain * (value - this.value); + this.variance = (1 - gain) * (this.variance + processVariance); + } + dynamicsChangedLastTime = dynamicChanged; + logger.debug( + () -> Strings.format( + "[%s] measurement %.3f ± %.3f: estimate %.3f ± %.3f (dynamic changed: %s).", + name, + value, + Math.sqrt(variance), + this.value, + Math.sqrt(this.variance), + dynamicChanged + ) + ); + } + + /** + * Returns whether the estimator has received data and contains a value. + */ + boolean hasValue() { + return this.value < Double.MAX_VALUE && this.variance < Double.MAX_VALUE; + } + + /** + * Returns the estimate of the mean value. + */ + double estimate() { + return value; + } + + /** + * Returns the stderr of the estimate. + */ + double error() { + return Math.sqrt(this.variance); + } + + /** + * Returns the lowerbound of the 1 stddev confidence interval of the estimate. + */ + double lower() { + return value - error(); + } + + /** + * Returns the upperbound of the 1 stddev confidence interval of the estimate. + */ + double upper() { + return value + error(); + } + + /** + * Returns whether (value, variance) is very unlikely, indicating that + * the underlying dynamics have changed. + */ + private boolean detectDynamicsChange(double value, double variance) { + return hasValue() && autodetectDynamicsChange && Math.pow(Math.abs(value - this.value), 2) / (variance + this.variance) > 100.0; + } +} diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentClusterService.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentClusterService.java index f468e5239fd29..e86a9cfe94045 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentClusterService.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentClusterService.java @@ -14,6 +14,7 @@ import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.ActionRequestValidationException; import org.elasticsearch.action.support.master.AcknowledgedResponse; import org.elasticsearch.client.internal.Client; import org.elasticsearch.cluster.ClusterChangedEvent; @@ -26,6 +27,7 @@ import org.elasticsearch.cluster.node.DiscoveryNodes; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.Strings; +import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.common.util.set.Sets; @@ -38,8 +40,10 @@ import org.elasticsearch.xpack.core.ml.MachineLearningField; import org.elasticsearch.xpack.core.ml.MlMetadata; import org.elasticsearch.xpack.core.ml.MlTasks; +import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction; import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction; import org.elasticsearch.xpack.core.ml.action.UpdateTrainedModelAssignmentRoutingInfoAction; +import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings; import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentState; import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingInfo; import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingState; @@ -68,6 +72,7 @@ import java.util.stream.Collectors; import static org.elasticsearch.core.Strings.format; +import static org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction.Request.NUMBER_OF_ALLOCATIONS; import static org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignmentUtils.NODES_CHANGED_REASON; import static org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignmentUtils.createShuttingDownRoute; @@ -393,7 +398,7 @@ public void clusterStateProcessed(ClusterState oldState, ClusterState newState) } public void createNewModelAssignment( - StartTrainedModelDeploymentAction.TaskParams params, + CreateTrainedModelAssignmentAction.Request request, ActionListener listener ) { if (clusterService.state().getMinTransportVersion().before(DISTRIBUTED_MODEL_ALLOCATION_TRANSPORT_VERSION)) { @@ -401,8 +406,8 @@ public void createNewModelAssignment( new ElasticsearchStatusException( "cannot create new assignment [{}] for model [{}] while cluster upgrade is in progress", RestStatus.CONFLICT, - params.getDeploymentId(), - params.getModelId() + request.getTaskParams().getDeploymentId(), + request.getTaskParams().getModelId() ) ); return; @@ -413,20 +418,20 @@ public void createNewModelAssignment( new ElasticsearchStatusException( "cannot create new assignment [{}] for model [{}] while feature reset is in progress.", RestStatus.CONFLICT, - params.getDeploymentId(), - params.getModelId() + request.getTaskParams().getDeploymentId(), + request.getTaskParams().getModelId() ) ); return; } - rebalanceAssignments(clusterService.state(), Optional.of(params), "model deployment started", ActionListener.wrap(newMetadata -> { - TrainedModelAssignment assignment = newMetadata.getDeploymentAssignment(params.getDeploymentId()); + rebalanceAssignments(clusterService.state(), Optional.of(request), "model deployment started", ActionListener.wrap(newMetadata -> { + TrainedModelAssignment assignment = newMetadata.getDeploymentAssignment(request.getTaskParams().getDeploymentId()); if (assignment == null) { // If we could not allocate the model anywhere then it is possible the assignment // here is null. We should notify the listener of an empty assignment as the // handling of this is done elsewhere with the wait-to-start predicate. - assignment = TrainedModelAssignment.Builder.empty(params).build(); + assignment = TrainedModelAssignment.Builder.empty(request).build(); } listener.onResponse(assignment); }, listener::onFailure)); @@ -528,13 +533,13 @@ private static ClusterState forceUpdate(ClusterState currentState, TrainedModelA return ClusterState.builder(currentState).metadata(metadata).build(); } - ClusterState createModelAssignment(ClusterState currentState, StartTrainedModelDeploymentAction.TaskParams params) throws Exception { - return update(currentState, rebalanceAssignments(currentState, Optional.of(params))); + ClusterState createModelAssignment(ClusterState currentState, CreateTrainedModelAssignmentAction.Request request) throws Exception { + return update(currentState, rebalanceAssignments(currentState, Optional.of(request))); } private void rebalanceAssignments( ClusterState clusterState, - Optional modelToAdd, + Optional createAssignmentRequest, String reason, ActionListener listener ) { @@ -544,7 +549,7 @@ private void rebalanceAssignments( TrainedModelAssignmentMetadata.Builder rebalancedMetadata; try { - rebalancedMetadata = rebalanceAssignments(clusterState, modelToAdd); + rebalancedMetadata = rebalanceAssignments(clusterState, createAssignmentRequest); } catch (Exception e) { listener.onFailure(e); return; @@ -561,7 +566,7 @@ public ClusterState execute(ClusterState currentState) { currentState = stopPlatformSpecificModelsInHeterogeneousClusters( currentState, mlNodesArchitectures, - modelToAdd, + createAssignmentRequest.map(CreateTrainedModelAssignmentAction.Request::getTaskParams), clusterState ); @@ -572,7 +577,7 @@ public ClusterState execute(ClusterState currentState) { return updatedState; } - rebalanceAssignments(currentState, modelToAdd, reason, listener); + rebalanceAssignments(currentState, createAssignmentRequest, reason, listener); return currentState; } @@ -639,7 +644,7 @@ && detectNodeLoads(sourceNodes, source).equals(detectNodeLoads(targetNodes, targ private TrainedModelAssignmentMetadata.Builder rebalanceAssignments( ClusterState currentState, - Optional modelToAdd + Optional createAssignmentRequest ) throws Exception { List nodes = getAssignableNodes(currentState); logger.debug(() -> format("assignable nodes are %s", nodes.stream().map(DiscoveryNode::getId).toList())); @@ -651,7 +656,7 @@ private TrainedModelAssignmentMetadata.Builder rebalanceAssignments( currentMetadata, nodeLoads, nodeAvailabilityZoneMapper.buildMlNodesByAvailabilityZone(currentState), - modelToAdd, + createAssignmentRequest, allocatedProcessorsScale, useNewMemoryFields ); @@ -668,8 +673,12 @@ private TrainedModelAssignmentMetadata.Builder rebalanceAssignments( rebalancer.rebalance() ); - if (modelToAdd.isPresent()) { - checkModelIsFullyAllocatedIfScalingIsNotPossible(modelToAdd.get().getDeploymentId(), rebalanced, nodes); + if (createAssignmentRequest.isPresent()) { + checkModelIsFullyAllocatedIfScalingIsNotPossible( + createAssignmentRequest.get().getTaskParams().getDeploymentId(), + rebalanced, + nodes + ); } return rebalanced; @@ -795,14 +804,22 @@ private boolean isScalingPossible(List nodes) { || (smallestMLNode.isPresent() && smallestMLNode.getAsLong() < maxMLNodeSize); } - public void updateNumberOfAllocations(String deploymentId, int numberOfAllocations, ActionListener listener) { - updateNumberOfAllocations(clusterService.state(), deploymentId, numberOfAllocations, listener); + public void updateDeployment( + String deploymentId, + Integer numberOfAllocations, + AdaptiveAllocationsSettings adaptiveAllocationsSettings, + boolean isInternal, + ActionListener listener + ) { + updateDeployment(clusterService.state(), deploymentId, numberOfAllocations, adaptiveAllocationsSettings, isInternal, listener); } - private void updateNumberOfAllocations( + private void updateDeployment( ClusterState clusterState, String deploymentId, - int numberOfAllocations, + Integer numberOfAllocations, + AdaptiveAllocationsSettings adaptiveAllocationsSettingsUpdates, + boolean isInternal, ActionListener listener ) { TrainedModelAssignmentMetadata metadata = TrainedModelAssignmentMetadata.fromState(clusterState); @@ -811,7 +828,27 @@ private void updateNumberOfAllocations( listener.onFailure(ExceptionsHelper.missingModelDeployment(deploymentId)); return; } - if (existingAssignment.getTaskParams().getNumberOfAllocations() == numberOfAllocations) { + AdaptiveAllocationsSettings adaptiveAllocationsSettings = getAdaptiveAllocationsSettings( + existingAssignment.getAdaptiveAllocationsSettings(), + adaptiveAllocationsSettingsUpdates + ); + if (adaptiveAllocationsSettings != null) { + if (isInternal == false && adaptiveAllocationsSettings.getEnabled() && numberOfAllocations != null) { + ValidationException validationException = new ValidationException(); + validationException.addValidationError("[" + NUMBER_OF_ALLOCATIONS + "] cannot be set if adaptive allocations is enabled"); + listener.onFailure(validationException); + return; + } + ActionRequestValidationException validationException = adaptiveAllocationsSettings.validate(); + if (validationException != null) { + listener.onFailure(validationException); + return; + } + } + boolean hasUpdates = (numberOfAllocations != null + && Objects.equals(numberOfAllocations, existingAssignment.getTaskParams().getNumberOfAllocations()) == false) + || Objects.equals(adaptiveAllocationsSettings, existingAssignment.getAdaptiveAllocationsSettings()) == false; + if (hasUpdates == false) { listener.onResponse(existingAssignment); return; } @@ -828,7 +865,7 @@ private void updateNumberOfAllocations( if (clusterState.getMinTransportVersion().before(DISTRIBUTED_MODEL_ALLOCATION_TRANSPORT_VERSION)) { listener.onFailure( new ElasticsearchStatusException( - "cannot update number_of_allocations for deployment with model id [{}] while cluster upgrade is in progress.", + "cannot update deployment with model id [{}] while cluster upgrade is in progress.", RestStatus.CONFLICT, deploymentId ) @@ -837,7 +874,7 @@ private void updateNumberOfAllocations( } ActionListener updatedStateListener = ActionListener.wrap( - updatedState -> submitUnbatchedTask("update model deployment number_of_allocations", new ClusterStateUpdateTask() { + updatedState -> submitUnbatchedTask("update model deployment", new ClusterStateUpdateTask() { private volatile boolean isUpdated; @@ -848,7 +885,7 @@ public ClusterState execute(ClusterState currentState) { return updatedState; } logger.debug(() -> format("[%s] Retrying update as cluster state has been modified", deploymentId)); - updateNumberOfAllocations(currentState, deploymentId, numberOfAllocations, listener); + updateDeployment(currentState, deploymentId, numberOfAllocations, adaptiveAllocationsSettings, isInternal, listener); return currentState; } @@ -877,38 +914,69 @@ public void clusterStateProcessed(ClusterState oldState, ClusterState newState) listener::onFailure ); - adjustNumberOfAllocations(clusterState, existingAssignment, numberOfAllocations, updatedStateListener); + updateAssignment(clusterState, existingAssignment, numberOfAllocations, adaptiveAllocationsSettings, updatedStateListener); + } + + private AdaptiveAllocationsSettings getAdaptiveAllocationsSettings( + AdaptiveAllocationsSettings original, + AdaptiveAllocationsSettings updates + ) { + if (updates == null) { + return original; + } else if (updates == AdaptiveAllocationsSettings.RESET_PLACEHOLDER) { + return null; + } else if (original == null) { + return updates; + } else { + return original.merge(updates); + } } - private void adjustNumberOfAllocations( + private void updateAssignment( ClusterState clusterState, TrainedModelAssignment assignment, - int numberOfAllocations, + Integer numberOfAllocations, + AdaptiveAllocationsSettings adaptiveAllocationsSettings, ActionListener listener ) { threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME).execute(() -> { - if (numberOfAllocations > assignment.getTaskParams().getNumberOfAllocations()) { - increaseNumberOfAllocations(clusterState, assignment, numberOfAllocations, listener); + if (numberOfAllocations == null || numberOfAllocations == assignment.getTaskParams().getNumberOfAllocations()) { + updateAndKeepNumberOfAllocations(clusterState, assignment, adaptiveAllocationsSettings, listener); + } else if (numberOfAllocations > assignment.getTaskParams().getNumberOfAllocations()) { + increaseNumberOfAllocations(clusterState, assignment, numberOfAllocations, adaptiveAllocationsSettings, listener); } else { - decreaseNumberOfAllocations(clusterState, assignment, numberOfAllocations, listener); + decreaseNumberOfAllocations(clusterState, assignment, numberOfAllocations, adaptiveAllocationsSettings, listener); } }); } + private void updateAndKeepNumberOfAllocations( + ClusterState clusterState, + TrainedModelAssignment assignment, + AdaptiveAllocationsSettings adaptiveAllocationsSettings, + ActionListener listener + ) { + TrainedModelAssignment.Builder updatedAssignment = TrainedModelAssignment.Builder.fromAssignment(assignment) + .setAdaptiveAllocationsSettings(adaptiveAllocationsSettings); + TrainedModelAssignmentMetadata.Builder builder = TrainedModelAssignmentMetadata.builder(clusterState); + builder.updateAssignment(assignment.getDeploymentId(), updatedAssignment); + listener.onResponse(update(clusterState, builder)); + } + private void increaseNumberOfAllocations( ClusterState clusterState, TrainedModelAssignment assignment, int numberOfAllocations, + AdaptiveAllocationsSettings adaptiveAllocationsSettings, ActionListener listener ) { try { + TrainedModelAssignment.Builder updatedAssignment = TrainedModelAssignment.Builder.fromAssignment(assignment) + .setNumberOfAllocations(numberOfAllocations) + .setAdaptiveAllocationsSettings(adaptiveAllocationsSettings); final ClusterState updatedClusterState = update( clusterState, - TrainedModelAssignmentMetadata.builder(clusterState) - .updateAssignment( - assignment.getDeploymentId(), - TrainedModelAssignment.Builder.fromAssignment(assignment).setNumberOfAllocations(numberOfAllocations) - ) + TrainedModelAssignmentMetadata.builder(clusterState).updateAssignment(assignment.getDeploymentId(), updatedAssignment) ); TrainedModelAssignmentMetadata.Builder rebalancedMetadata = rebalanceAssignments(updatedClusterState, Optional.empty()); if (isScalingPossible(getAssignableNodes(clusterState)) == false @@ -931,6 +999,7 @@ private void decreaseNumberOfAllocations( ClusterState clusterState, TrainedModelAssignment assignment, int numberOfAllocations, + AdaptiveAllocationsSettings adaptiveAllocationsSettings, ActionListener listener ) { TrainedModelAssignment.Builder updatedAssignment = numberOfAllocations < assignment.totalTargetAllocations() @@ -938,7 +1007,7 @@ private void decreaseNumberOfAllocations( numberOfAllocations ) : TrainedModelAssignment.Builder.fromAssignment(assignment).setNumberOfAllocations(numberOfAllocations); - + updatedAssignment.setAdaptiveAllocationsSettings(adaptiveAllocationsSettings); // We have now reduced allocations to a number we can be sure it is satisfied // and thus we should clear the assignment reason. if (numberOfAllocations <= assignment.totalTargetAllocations()) { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeService.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeService.java index 7052e6f147b36..afd17b803cdcb 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeService.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeService.java @@ -12,8 +12,7 @@ import org.elasticsearch.ResourceNotFoundException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.search.SearchPhaseExecutionException; -import org.elasticsearch.action.support.PlainActionFuture; -import org.elasticsearch.action.support.UnsafePlainActionFuture; +import org.elasticsearch.action.support.SubscribableListener; import org.elasticsearch.action.support.master.AcknowledgedResponse; import org.elasticsearch.cluster.ClusterChangedEvent; import org.elasticsearch.cluster.ClusterState; @@ -53,7 +52,6 @@ import org.elasticsearch.xpack.ml.inference.deployment.TrainedModelDeploymentTask; import org.elasticsearch.xpack.ml.task.AbstractJobPersistentTasksExecutor; -import java.util.ArrayDeque; import java.util.ArrayList; import java.util.Collections; import java.util.Deque; @@ -154,16 +152,29 @@ public void beforeStop() { this.expressionResolver = expressionResolver; } - public void start() { + void start() { stopped = false; - scheduledFuture = threadPool.scheduleWithFixedDelay( - this::loadQueuedModels, - MODEL_LOADING_CHECK_INTERVAL, - threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME) - ); + schedule(false); } - public void stop() { + private void schedule(boolean runImmediately) { + if (stopped) { + // do not schedule when stopped + return; + } + + var rescheduleListener = ActionListener.wrap(this::schedule, e -> this.schedule(false)); + Runnable loadQueuedModels = () -> loadQueuedModels(rescheduleListener); + var executor = threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME); + + if (runImmediately) { + executor.execute(loadQueuedModels); + } else { + scheduledFuture = threadPool.schedule(loadQueuedModels, MODEL_LOADING_CHECK_INTERVAL, executor); + } + } + + void stop() { stopped = true; ThreadPool.Cancellable cancellable = this.scheduledFuture; if (cancellable != null) { @@ -171,9 +182,9 @@ public void stop() { } } - void loadQueuedModels() { - TrainedModelDeploymentTask loadingTask; - if (loadingModels.isEmpty()) { + void loadQueuedModels(ActionListener rescheduleImmediately) { + if (stopped) { + rescheduleImmediately.onResponse(false); return; } if (latestState != null) { @@ -188,39 +199,49 @@ void loadQueuedModels() { ); if (unassignedIndices.size() > 0) { logger.trace("not loading models as indices {} primary shards are unassigned", unassignedIndices); + rescheduleImmediately.onResponse(false); return; } } - logger.trace("attempting to load all currently queued models"); - // NOTE: As soon as this method exits, the timer for the scheduler starts ticking - Deque loadingToRetry = new ArrayDeque<>(); - while ((loadingTask = loadingModels.poll()) != null) { - final String deploymentId = loadingTask.getDeploymentId(); - if (loadingTask.isStopped()) { - if (logger.isTraceEnabled()) { - String reason = loadingTask.stoppedReason().orElse("_unknown_"); - logger.trace("[{}] attempted to load stopped task with reason [{}]", deploymentId, reason); - } - continue; + + var loadingTask = loadingModels.poll(); + if (loadingTask == null) { + rescheduleImmediately.onResponse(false); + return; + } + + loadModel(loadingTask, ActionListener.wrap(retry -> { + if (retry != null && retry) { + loadingModels.offer(loadingTask); + // don't reschedule immediately if the next task is the one we just queued, instead wait a bit to retry + rescheduleImmediately.onResponse(loadingModels.peek() != loadingTask); + } else { + rescheduleImmediately.onResponse(loadingModels.isEmpty() == false); } - if (stopped) { - return; + }, e -> rescheduleImmediately.onResponse(loadingModels.isEmpty() == false))); + } + + void loadModel(TrainedModelDeploymentTask loadingTask, ActionListener retryListener) { + if (loadingTask.isStopped()) { + if (logger.isTraceEnabled()) { + logger.trace( + "[{}] attempted to load stopped task with reason [{}]", + loadingTask.getDeploymentId(), + loadingTask.stoppedReason().orElse("_unknown_") + ); } - final PlainActionFuture listener = new UnsafePlainActionFuture<>( - MachineLearning.UTILITY_THREAD_POOL_NAME - ); - try { - deploymentManager.startDeployment(loadingTask, listener); - // This needs to be synchronous here in the utility thread to keep queueing order - TrainedModelDeploymentTask deployedTask = listener.actionGet(); - // kicks off asynchronous cluster state update - handleLoadSuccess(deployedTask); - } catch (Exception ex) { + retryListener.onResponse(false); + return; + } + SubscribableListener.newForked(l -> deploymentManager.startDeployment(loadingTask, l)) + .andThen(threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME), threadPool.getThreadContext(), this::handleLoadSuccess) + .addListener(retryListener.delegateResponse((retryL, ex) -> { + var deploymentId = loadingTask.getDeploymentId(); logger.warn(() -> "[" + deploymentId + "] Start deployment failed", ex); if (ExceptionsHelper.unwrapCause(ex) instanceof ResourceNotFoundException) { - String modelId = loadingTask.getParams().getModelId(); + var modelId = loadingTask.getParams().getModelId(); logger.debug(() -> "[" + deploymentId + "] Start deployment failed as model [" + modelId + "] was not found", ex); - handleLoadFailure(loadingTask, ExceptionsHelper.missingTrainedModel(modelId, ex)); + handleLoadFailure(loadingTask, ExceptionsHelper.missingTrainedModel(modelId, ex), retryL); } else if (ExceptionsHelper.unwrapCause(ex) instanceof SearchPhaseExecutionException) { /* * This case will not catch the ElasticsearchException generated from the ChunkedTrainedModelRestorer in a scenario @@ -232,13 +253,11 @@ void loadQueuedModels() { // A search phase execution failure should be retried, push task back to the queue // This will cause the entire model to be reloaded (all the chunks) - loadingToRetry.add(loadingTask); + retryL.onResponse(true); } else { - handleLoadFailure(loadingTask, ex); + handleLoadFailure(loadingTask, ex, retryL); } - } - } - loadingModels.addAll(loadingToRetry); + }), threadPool.executor(MachineLearning.UTILITY_THREAD_POOL_NAME), threadPool.getThreadContext()); } public void gracefullyStopDeploymentAndNotify( @@ -680,14 +699,14 @@ void prepareModelToLoad(StartTrainedModelDeploymentAction.TaskParams taskParams) ); // threadsafe check to verify we are not loading/loaded the model if (deploymentIdToTask.putIfAbsent(taskParams.getDeploymentId(), task) == null) { - loadingModels.add(task); + loadingModels.offer(task); } else { // If there is already a task for the deployment, unregister the new task taskManager.unregister(task); } } - private void handleLoadSuccess(TrainedModelDeploymentTask task) { + private void handleLoadSuccess(ActionListener retryListener, TrainedModelDeploymentTask task) { logger.debug( () -> "[" + task.getParams().getDeploymentId() @@ -704,13 +723,16 @@ private void handleLoadSuccess(TrainedModelDeploymentTask task) { task.stoppedReason().orElse("_unknown_") ) ); + retryListener.onResponse(false); return; } updateStoredState( task.getDeploymentId(), RoutingInfoUpdate.updateStateAndReason(new RoutingStateAndReason(RoutingState.STARTED, "")), - ActionListener.wrap(r -> logger.debug(() -> "[" + task.getDeploymentId() + "] model loaded and accepting routes"), e -> { + ActionListener.runAfter(ActionListener.wrap(r -> { + logger.debug(() -> "[" + task.getDeploymentId() + "] model loaded and accepting routes"); + }, e -> { // This means that either the assignment has been deleted, or this node's particular route has been removed if (ExceptionsHelper.unwrapCause(e) instanceof ResourceNotFoundException) { logger.debug( @@ -732,7 +754,7 @@ private void handleLoadSuccess(TrainedModelDeploymentTask task) { e ); } - }) + }), () -> retryListener.onResponse(false)) ); } @@ -752,7 +774,7 @@ private void updateStoredState(String deploymentId, RoutingInfoUpdate update, Ac ); } - private void handleLoadFailure(TrainedModelDeploymentTask task, Exception ex) { + private void handleLoadFailure(TrainedModelDeploymentTask task, Exception ex, ActionListener retryListener) { logger.error(() -> "[" + task.getDeploymentId() + "] model [" + task.getParams().getModelId() + "] failed to load", ex); if (task.isStopped()) { logger.debug( @@ -769,14 +791,14 @@ private void handleLoadFailure(TrainedModelDeploymentTask task, Exception ex) { Runnable stopTask = () -> stopDeploymentAsync( task, "model failed to load; reason [" + ex.getMessage() + "]", - ActionListener.noop() + ActionListener.running(() -> retryListener.onResponse(false)) ); updateStoredState( task.getDeploymentId(), RoutingInfoUpdate.updateStateAndReason( new RoutingStateAndReason(RoutingState.FAILED, ExceptionsHelper.unwrapCause(ex).getMessage()) ), - ActionListener.wrap(r -> stopTask.run(), e -> stopTask.run()) + ActionListener.running(stopTask) ); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentRebalancer.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentRebalancer.java index ef8af6af445fb..624ef5434e2a0 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentRebalancer.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentRebalancer.java @@ -14,6 +14,7 @@ import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.common.Strings; import org.elasticsearch.common.unit.ByteSizeValue; +import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction; import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction; import org.elasticsearch.xpack.core.ml.inference.assignment.Priority; import org.elasticsearch.xpack.core.ml.inference.assignment.RoutingInfo; @@ -50,7 +51,7 @@ class TrainedModelAssignmentRebalancer { private final TrainedModelAssignmentMetadata currentMetadata; private final Map nodeLoads; private final Map, Collection> mlNodesByZone; - private final Optional deploymentToAdd; + private final Optional createAssignmentRequest; private final int allocatedProcessorsScale; private final boolean useNewMemoryFields; @@ -59,28 +60,29 @@ class TrainedModelAssignmentRebalancer { TrainedModelAssignmentMetadata currentMetadata, Map nodeLoads, Map, Collection> mlNodesByZone, - Optional deploymentToAdd, + Optional createAssignmentRequest, int allocatedProcessorsScale, boolean useNewMemoryFields ) { this.currentMetadata = Objects.requireNonNull(currentMetadata); this.nodeLoads = Objects.requireNonNull(nodeLoads); this.mlNodesByZone = Objects.requireNonNull(mlNodesByZone); - this.deploymentToAdd = Objects.requireNonNull(deploymentToAdd); + this.createAssignmentRequest = Objects.requireNonNull(createAssignmentRequest); this.allocatedProcessorsScale = allocatedProcessorsScale; this.useNewMemoryFields = useNewMemoryFields; } TrainedModelAssignmentMetadata.Builder rebalance() { - if (deploymentToAdd.isPresent() && currentMetadata.hasDeployment(deploymentToAdd.get().getDeploymentId())) { + if (createAssignmentRequest.isPresent() + && currentMetadata.hasDeployment(createAssignmentRequest.get().getTaskParams().getDeploymentId())) { throw new ResourceAlreadyExistsException( "[{}] assignment for deployment with model [{}] already exists", - deploymentToAdd.get().getDeploymentId(), - deploymentToAdd.get().getModelId() + createAssignmentRequest.get().getTaskParams().getDeploymentId(), + createAssignmentRequest.get().getTaskParams().getModelId() ); } - if (deploymentToAdd.isEmpty() && areAllModelsSatisfiedAndNoOutdatedRoutingEntries()) { + if (createAssignmentRequest.isEmpty() && areAllModelsSatisfiedAndNoOutdatedRoutingEntries()) { logger.trace(() -> "No need to rebalance as all model deployments are satisfied"); return TrainedModelAssignmentMetadata.Builder.fromMetadata(currentMetadata); } @@ -176,14 +178,15 @@ private AssignmentPlan computePlanForNormalPriorityModels( assignment.getTaskParams().getThreadsPerAllocation(), currentAssignments, assignment.getMaxAssignedAllocations(), + assignment.getAdaptiveAllocationsSettings(), // in the mixed cluster state use old memory fields to avoid unstable assignment plans useNewMemoryFields ? assignment.getTaskParams().getPerDeploymentMemoryBytes() : 0, useNewMemoryFields ? assignment.getTaskParams().getPerAllocationMemoryBytes() : 0 ); }) .forEach(planDeployments::add); - if (deploymentToAdd.isPresent() && deploymentToAdd.get().getPriority() != Priority.LOW) { - StartTrainedModelDeploymentAction.TaskParams taskParams = deploymentToAdd.get(); + if (createAssignmentRequest.isPresent() && createAssignmentRequest.get().getTaskParams().getPriority() != Priority.LOW) { + StartTrainedModelDeploymentAction.TaskParams taskParams = createAssignmentRequest.get().getTaskParams(); planDeployments.add( new AssignmentPlan.Deployment( taskParams.getDeploymentId(), @@ -192,6 +195,7 @@ private AssignmentPlan computePlanForNormalPriorityModels( taskParams.getThreadsPerAllocation(), Map.of(), 0, + createAssignmentRequest.get().getAdaptiveAllocationsSettings(), // in the mixed cluster state use old memory fields to avoid unstable assignment plans useNewMemoryFields ? taskParams.getPerDeploymentMemoryBytes() : 0, useNewMemoryFields ? taskParams.getPerAllocationMemoryBytes() : 0 @@ -231,14 +235,15 @@ private AssignmentPlan computePlanForLowPriorityModels(Set assignableNod assignment.getTaskParams().getThreadsPerAllocation(), findFittingAssignments(assignment, assignableNodeIds, remainingNodeMemory), assignment.getMaxAssignedAllocations(), + assignment.getAdaptiveAllocationsSettings(), Priority.LOW, (useNewMemoryFields == false) ? assignment.getTaskParams().getPerDeploymentMemoryBytes() : 0, (useNewMemoryFields == false) ? assignment.getTaskParams().getPerAllocationMemoryBytes() : 0 ) ) .forEach(planDeployments::add); - if (deploymentToAdd.isPresent() && deploymentToAdd.get().getPriority() == Priority.LOW) { - StartTrainedModelDeploymentAction.TaskParams taskParams = deploymentToAdd.get(); + if (createAssignmentRequest.isPresent() && createAssignmentRequest.get().getTaskParams().getPriority() == Priority.LOW) { + StartTrainedModelDeploymentAction.TaskParams taskParams = createAssignmentRequest.get().getTaskParams(); planDeployments.add( new AssignmentPlan.Deployment( taskParams.getDeploymentId(), @@ -247,6 +252,7 @@ private AssignmentPlan computePlanForLowPriorityModels(Set assignableNod taskParams.getThreadsPerAllocation(), Map.of(), 0, + createAssignmentRequest.get().getAdaptiveAllocationsSettings(), Priority.LOW, (useNewMemoryFields == false) ? taskParams.getPerDeploymentMemoryBytes() : 0, (useNewMemoryFields == false) ? taskParams.getPerAllocationMemoryBytes() : 0 @@ -325,11 +331,12 @@ private TrainedModelAssignmentMetadata.Builder buildAssignmentsFromPlan(Assignme for (AssignmentPlan.Deployment deployment : assignmentPlan.models()) { TrainedModelAssignment existingAssignment = currentMetadata.getDeploymentAssignment(deployment.id()); - TrainedModelAssignment.Builder assignmentBuilder = TrainedModelAssignment.Builder.empty( - existingAssignment == null && deploymentToAdd.isPresent() - ? deploymentToAdd.get() - : currentMetadata.getDeploymentAssignment(deployment.id()).getTaskParams() - ); + TrainedModelAssignment.Builder assignmentBuilder = existingAssignment == null && createAssignmentRequest.isPresent() + ? TrainedModelAssignment.Builder.empty(createAssignmentRequest.get()) + : TrainedModelAssignment.Builder.empty( + currentMetadata.getDeploymentAssignment(deployment.id()).getTaskParams(), + currentMetadata.getDeploymentAssignment(deployment.id()).getAdaptiveAllocationsSettings() + ); if (existingAssignment != null) { assignmentBuilder.setStartTime(existingAssignment.getStartTime()); assignmentBuilder.setMaxAssignedAllocations(existingAssignment.getMaxAssignedAllocations()); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentService.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentService.java index 0609e0e6ff916..bf19b505e5cfe 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentService.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentService.java @@ -30,7 +30,6 @@ import org.elasticsearch.transport.ConnectTransportException; import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction; import org.elasticsearch.xpack.core.ml.action.DeleteTrainedModelAssignmentAction; -import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction; import org.elasticsearch.xpack.core.ml.action.UpdateTrainedModelAssignmentRoutingInfoAction; import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment; import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignmentMetadata; @@ -85,10 +84,10 @@ public void updateModelAssignmentState( } public void createNewModelAssignment( - StartTrainedModelDeploymentAction.TaskParams taskParams, + CreateTrainedModelAssignmentAction.Request request, ActionListener listener ) { - client.execute(CreateTrainedModelAssignmentAction.INSTANCE, new CreateTrainedModelAssignmentAction.Request(taskParams), listener); + client.execute(CreateTrainedModelAssignmentAction.INSTANCE, request, listener); } public void deleteModelAssignment(String modelId, ActionListener listener) { diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AbstractPreserveAllocations.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AbstractPreserveAllocations.java index 98988ffa11055..0151c8f5ee9c8 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AbstractPreserveAllocations.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AbstractPreserveAllocations.java @@ -60,6 +60,7 @@ Deployment modifyModelPreservingPreviousAssignments(Deployment m) { m.threadsPerAllocation(), calculateAllocationsPerNodeToPreserve(m), m.maxAssignedAllocations(), + m.getAdaptiveAllocationsSettings(), m.perDeploymentMemoryBytes(), m.perAllocationMemoryBytes() ); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlan.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlan.java index 123c728587604..7fc16394ed85c 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlan.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlan.java @@ -11,6 +11,7 @@ import org.elasticsearch.common.util.Maps; import org.elasticsearch.core.Tuple; import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction; +import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings; import org.elasticsearch.xpack.core.ml.inference.assignment.Priority; import java.util.ArrayList; @@ -37,11 +38,11 @@ public record Deployment( int threadsPerAllocation, Map currentAllocationsByNodeId, int maxAssignedAllocations, + AdaptiveAllocationsSettings adaptiveAllocationsSettings, Priority priority, long perDeploymentMemoryBytes, long perAllocationMemoryBytes ) { - public Deployment( String id, long modelBytes, @@ -49,6 +50,7 @@ public Deployment( int threadsPerAllocation, Map currentAllocationsByNodeId, int maxAssignedAllocations, + AdaptiveAllocationsSettings adaptiveAllocationsSettings, long perDeploymentMemoryBytes, long perAllocationMemoryBytes ) { @@ -59,12 +61,17 @@ public Deployment( threadsPerAllocation, currentAllocationsByNodeId, maxAssignedAllocations, + adaptiveAllocationsSettings, Priority.NORMAL, perDeploymentMemoryBytes, perAllocationMemoryBytes ); } + public AdaptiveAllocationsSettings getAdaptiveAllocationsSettings() { + return adaptiveAllocationsSettings; + } + int getCurrentAssignedAllocations() { return currentAllocationsByNodeId.values().stream().mapToInt(Integer::intValue).sum(); } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlanner.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlanner.java index b1c017b1a784c..38279a2fd6c03 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlanner.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlanner.java @@ -118,6 +118,7 @@ private AssignmentPlan solveAllocatingAtLeastOnceModelsThatWerePreviouslyAllocat // don't rely on the current allocation new HashMap<>(), m.maxAssignedAllocations(), + m.getAdaptiveAllocationsSettings(), m.perDeploymentMemoryBytes(), m.perAllocationMemoryBytes() ) @@ -149,6 +150,7 @@ private AssignmentPlan solveAllocatingAtLeastOnceModelsThatWerePreviouslyAllocat m.threadsPerAllocation(), currentAllocationsByNodeId, m.maxAssignedAllocations(), + m.getAdaptiveAllocationsSettings(), m.perDeploymentMemoryBytes(), m.perAllocationMemoryBytes() ); diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/ZoneAwareAssignmentPlanner.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/ZoneAwareAssignmentPlanner.java index 9af2e4cd49b17..1f0857391598f 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/ZoneAwareAssignmentPlanner.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/assignment/planning/ZoneAwareAssignmentPlanner.java @@ -129,6 +129,7 @@ private AssignmentPlan computeZonePlan( (tryAssigningPreviouslyAssignedModels && modelIdToRemainingAllocations.get(m.id()) == m.allocations()) ? m.maxAssignedAllocations() : 0, + m.getAdaptiveAllocationsSettings(), // Only force assigning at least once previously assigned models that have not had any allocation yet m.perDeploymentMemoryBytes(), m.perAllocationMemoryBytes() @@ -154,6 +155,7 @@ private AssignmentPlan computePlanAcrossAllNodes(List plans) { m.threadsPerAllocation(), allocationsByNodeIdByModelId.get(m.id()), m.maxAssignedAllocations(), + m.getAdaptiveAllocationsSettings(), m.perDeploymentMemoryBytes(), m.perAllocationMemoryBytes() ) diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchResultProcessor.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchResultProcessor.java index 87fad19ab87fc..1bb2f1006822e 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchResultProcessor.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchResultProcessor.java @@ -160,11 +160,11 @@ void processInferenceResult(PyTorchResult result) { } logger.debug(() -> format("[%s] Parsed inference result with id [%s]", modelId, result.requestId())); - updateStats(timeMs, Boolean.TRUE.equals(result.isCacheHit())); PendingResult pendingResult = pendingResults.remove(result.requestId()); if (pendingResult == null) { logger.debug(() -> format("[%s] no pending result for inference [%s]", modelId, result.requestId())); } else { + updateStats(timeMs, Boolean.TRUE.equals(result.isCacheHit())); pendingResult.listener.onResponse(result); } } diff --git a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestStartTrainedModelDeploymentAction.java b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestStartTrainedModelDeploymentAction.java index 1a9fc6ce99823..e308eb6007973 100644 --- a/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestStartTrainedModelDeploymentAction.java +++ b/x-pack/plugin/ml/src/main/java/org/elasticsearch/xpack/ml/rest/inference/RestStartTrainedModelDeploymentAction.java @@ -94,7 +94,8 @@ protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient NUMBER_OF_ALLOCATIONS.getPreferredName(), RestApiVersion.V_8, restRequest, - (r, s) -> r.paramAsInt(s, request.getNumberOfAllocations()), + // This is to propagate a null value, which paramAsInt does not support. + (r, s) -> r.hasParam(s) ? (Integer) r.paramAsInt(s, 0) : request.getNumberOfAllocations(), request::setNumberOfAllocations ); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MachineLearningInfoTransportActionTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MachineLearningInfoTransportActionTests.java index 084a9d95939c5..afa372fb94527 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MachineLearningInfoTransportActionTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MachineLearningInfoTransportActionTests.java @@ -1015,6 +1015,7 @@ private Map setupComplexMocks() { null, null, null, + null, Instant.now(), List.of( AssignmentStats.NodeStats.forStartedState( @@ -1064,6 +1065,7 @@ private Map setupComplexMocks() { "model_4", 2, 2, + null, 1000, ByteSizeValue.ofBytes(1000), Instant.now(), diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MlInitializationServiceTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MlInitializationServiceTests.java index 2f30d131021b4..2f251e3b0aee6 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MlInitializationServiceTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MlInitializationServiceTests.java @@ -13,11 +13,14 @@ import org.elasticsearch.client.internal.Client; import org.elasticsearch.client.internal.IndicesAdminClient; import org.elasticsearch.cluster.ClusterName; +import org.elasticsearch.cluster.ClusterState; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.concurrent.DeterministicTaskQueue; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.ml.inference.adaptiveallocations.AdaptiveAllocationsScalerService; +import org.elasticsearch.xpack.ml.notifications.InferenceAuditor; import org.junit.Before; import java.util.Map; @@ -36,6 +39,7 @@ public class MlInitializationServiceTests extends ESTestCase { private ThreadPool threadPool; private ClusterService clusterService; private Client client; + private InferenceAuditor inferenceAuditor; private MlAssignmentNotifier mlAssignmentNotifier; @Before @@ -44,9 +48,11 @@ public void setUpMocks() { threadPool = deterministicTaskQueue.getThreadPool(); clusterService = mock(ClusterService.class); client = mock(Client.class); + inferenceAuditor = mock(InferenceAuditor.class); mlAssignmentNotifier = mock(MlAssignmentNotifier.class); when(clusterService.getClusterName()).thenReturn(CLUSTER_NAME); + when(clusterService.state()).thenReturn(ClusterState.EMPTY_STATE); @SuppressWarnings("unchecked") ActionFuture getSettingsResponseActionFuture = mock(ActionFuture.class); @@ -68,6 +74,7 @@ public void testInitialize() { threadPool, clusterService, client, + inferenceAuditor, mlAssignmentNotifier, true, true, @@ -83,6 +90,7 @@ public void testInitialize_noMasterNode() { threadPool, clusterService, client, + inferenceAuditor, mlAssignmentNotifier, true, true, @@ -94,11 +102,13 @@ public void testInitialize_noMasterNode() { public void testNodeGoesFromMasterToNonMasterAndBack() { MlDailyMaintenanceService initialDailyMaintenanceService = mock(MlDailyMaintenanceService.class); + AdaptiveAllocationsScalerService adaptiveAllocationsScalerService = mock(AdaptiveAllocationsScalerService.class); MlInitializationService initializationService = new MlInitializationService( client, threadPool, initialDailyMaintenanceService, + adaptiveAllocationsScalerService, clusterService ); initializationService.offMaster(); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MlLifeCycleServiceTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MlLifeCycleServiceTests.java index 2b206de4cf42f..bdabb42ecd467 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MlLifeCycleServiceTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MlLifeCycleServiceTests.java @@ -191,7 +191,7 @@ public void testIsNodeSafeToShutdownReturnsFalseWhenStartingDeploymentExists() { TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( "1", - TrainedModelAssignment.Builder.empty(StartTrainedModelDeploymentTaskParamsTests.createRandom()) + TrainedModelAssignment.Builder.empty(StartTrainedModelDeploymentTaskParamsTests.createRandom(), null) .addRoutingEntry(nodeId, new RoutingInfo(1, 1, RoutingState.STARTING, "")) ) .build() @@ -215,12 +215,12 @@ public void testIsNodeSafeToShutdownReturnsFalseWhenStoppingAndStoppedDeployment TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( "1", - TrainedModelAssignment.Builder.empty(StartTrainedModelDeploymentTaskParamsTests.createRandom()) + TrainedModelAssignment.Builder.empty(StartTrainedModelDeploymentTaskParamsTests.createRandom(), null) .addRoutingEntry(nodeId, new RoutingInfo(1, 1, RoutingState.STOPPED, "")) ) .addNewAssignment( "2", - TrainedModelAssignment.Builder.empty(StartTrainedModelDeploymentTaskParamsTests.createRandom()) + TrainedModelAssignment.Builder.empty(StartTrainedModelDeploymentTaskParamsTests.createRandom(), null) .addRoutingEntry(nodeId, new RoutingInfo(1, 1, RoutingState.STOPPING, "")) ) .build() @@ -244,12 +244,12 @@ public void testIsNodeSafeToShutdownReturnsTrueWhenStoppedDeploymentsExist() { TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( "1", - TrainedModelAssignment.Builder.empty(StartTrainedModelDeploymentTaskParamsTests.createRandom()) + TrainedModelAssignment.Builder.empty(StartTrainedModelDeploymentTaskParamsTests.createRandom(), null) .addRoutingEntry(nodeId, new RoutingInfo(1, 1, RoutingState.STOPPED, "")) ) .addNewAssignment( "2", - TrainedModelAssignment.Builder.empty(StartTrainedModelDeploymentTaskParamsTests.createRandom()) + TrainedModelAssignment.Builder.empty(StartTrainedModelDeploymentTaskParamsTests.createRandom(), null) .addRoutingEntry(nodeId, new RoutingInfo(1, 1, RoutingState.STOPPED, "")) ) .build() diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MlMetricsTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MlMetricsTests.java index 2262c21070e75..5fb1381b881ea 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MlMetricsTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/MlMetricsTests.java @@ -132,18 +132,18 @@ public void testFindTrainedModelAllocationCounts() { TrainedModelAssignmentMetadata.Builder metadataBuilder = TrainedModelAssignmentMetadata.Builder.empty(); metadataBuilder.addNewAssignment( "model1", - TrainedModelAssignment.Builder.empty(mock(StartTrainedModelDeploymentAction.TaskParams.class)) + TrainedModelAssignment.Builder.empty(mock(StartTrainedModelDeploymentAction.TaskParams.class), null) .addRoutingEntry("node1", new RoutingInfo(1, 1, RoutingState.STARTED, "")) .addRoutingEntry("node2", new RoutingInfo(0, 1, RoutingState.FAILED, "")) ); metadataBuilder.addNewAssignment( "model2", - TrainedModelAssignment.Builder.empty(mock(StartTrainedModelDeploymentAction.TaskParams.class)) + TrainedModelAssignment.Builder.empty(mock(StartTrainedModelDeploymentAction.TaskParams.class), null) .addRoutingEntry("node1", new RoutingInfo(2, 2, RoutingState.STARTED, "")) ); metadataBuilder.addNewAssignment( "model3", - TrainedModelAssignment.Builder.empty(mock(StartTrainedModelDeploymentAction.TaskParams.class)) + TrainedModelAssignment.Builder.empty(mock(StartTrainedModelDeploymentAction.TaskParams.class), null) .addRoutingEntry("node2", new RoutingInfo(0, 1, RoutingState.STARTING, "")) ); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportGetDeploymentStatsActionTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportGetDeploymentStatsActionTests.java index b8dd3559253ee..4a66be4a773f5 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportGetDeploymentStatsActionTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/action/TransportGetDeploymentStatsActionTests.java @@ -83,6 +83,7 @@ public void testAddFailedRoutes_GivenMixedResponses() throws UnknownHostExceptio "deployment1", randomBoolean() ? null : randomIntBetween(1, 8), randomBoolean() ? null : randomIntBetween(1, 8), + null, randomBoolean() ? null : randomIntBetween(1, 10000), randomBoolean() ? null : ByteSizeValue.ofBytes(randomLongBetween(1, 1000000)), Instant.now(), @@ -121,6 +122,7 @@ public void testAddFailedRoutes_TaskResultIsOverwritten() throws UnknownHostExce "deployment1", randomBoolean() ? null : randomIntBetween(1, 8), randomBoolean() ? null : randomIntBetween(1, 8), + null, randomBoolean() ? null : randomIntBetween(1, 10000), randomBoolean() ? null : ByteSizeValue.ofBytes(randomLongBetween(1, 1000000)), Instant.now(), @@ -169,7 +171,8 @@ private static TrainedModelAssignment createAssignment(String modelId) { Priority.NORMAL, 0L, 0L - ) + ), + null ).build(); } } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/autoscaling/MlAutoscalingResourceTrackerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/autoscaling/MlAutoscalingResourceTrackerTests.java index 0d91ce45c46ba..41a86e436f468 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/autoscaling/MlAutoscalingResourceTrackerTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/autoscaling/MlAutoscalingResourceTrackerTests.java @@ -1143,7 +1143,8 @@ public void testGetMemoryAndProcessorsScaleDown() throws InterruptedException { Priority.NORMAL, 0L, 0L - ) + ), + null ).addRoutingEntry("ml-node-1", new RoutingInfo(1, 1, RoutingState.STARTED, "")).build(), "model-2", TrainedModelAssignment.Builder.empty( @@ -1158,7 +1159,8 @@ public void testGetMemoryAndProcessorsScaleDown() throws InterruptedException { Priority.NORMAL, 0L, 0L - ) + ), + null ).addRoutingEntry("ml-node-3", new RoutingInfo(1, 1, RoutingState.STARTED, "")).build() ), List.of( @@ -1242,7 +1244,8 @@ public void testGetMemoryAndProcessorsScaleDownPreventedByMinNodes() throws Inte Priority.NORMAL, 0L, 0L - ) + ), + null ) .addRoutingEntry("ml-node-1", new RoutingInfo(2, 2, RoutingState.STARTED, "")) .addRoutingEntry("ml-node-2", new RoutingInfo(2, 2, RoutingState.STARTED, "")) @@ -1260,7 +1263,8 @@ public void testGetMemoryAndProcessorsScaleDownPreventedByMinNodes() throws Inte Priority.NORMAL, 0L, 0L - ) + ), + null ).addRoutingEntry("ml-node-3", new RoutingInfo(1, 1, RoutingState.STARTED, "")).build() ), List.of( @@ -1334,7 +1338,8 @@ public void testGetMemoryAndProcessorsScaleDownPreventedByDummyEntityMemory() th Priority.NORMAL, 0L, 0L - ) + ), + null ).addRoutingEntry("ml-node-1", new RoutingInfo(1, 1, RoutingState.STARTED, "")).build(), "model-2", TrainedModelAssignment.Builder.empty( @@ -1349,7 +1354,8 @@ public void testGetMemoryAndProcessorsScaleDownPreventedByDummyEntityMemory() th Priority.NORMAL, 0L, 0L - ) + ), + null ).addRoutingEntry("ml-node-3", new RoutingInfo(1, 1, RoutingState.STARTED, "")).build() ), List.of( @@ -1432,7 +1438,8 @@ public void testGetMemoryAndProcessorsScaleDownNotPreventedByDummyEntityProcesso Priority.NORMAL, 0L, 0L - ) + ), + null ).addRoutingEntry("ml-node-1", new RoutingInfo(1, 1, RoutingState.STARTED, "")).build(), "model-2", TrainedModelAssignment.Builder.empty( @@ -1447,7 +1454,8 @@ public void testGetMemoryAndProcessorsScaleDownNotPreventedByDummyEntityProcesso Priority.NORMAL, 0L, 0L - ) + ), + null ).addRoutingEntry("ml-node-3", new RoutingInfo(1, 1, RoutingState.STARTED, "")).build() ), List.of( @@ -1525,7 +1533,8 @@ public void testGetMemoryAndProcessorsScaleDownNotPreventedByDummyEntityAsMemory Priority.NORMAL, 0L, 0L - ) + ), + null ).addRoutingEntry("ml-node-1", new RoutingInfo(1, 1, RoutingState.STARTED, "")).build(), "model-2", TrainedModelAssignment.Builder.empty( @@ -1540,7 +1549,8 @@ public void testGetMemoryAndProcessorsScaleDownNotPreventedByDummyEntityAsMemory Priority.NORMAL, 0L, 0L - ) + ), + null ).addRoutingEntry("ml-node-3", new RoutingInfo(1, 1, RoutingState.STARTED, "")).build() ), List.of( diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/autoscaling/MlMemoryAutoscalingDeciderTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/autoscaling/MlMemoryAutoscalingDeciderTests.java index a916900b199ce..970044c188849 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/autoscaling/MlMemoryAutoscalingDeciderTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/autoscaling/MlMemoryAutoscalingDeciderTests.java @@ -1069,7 +1069,8 @@ public void testCpuModelAssignmentRequirements() { Priority.NORMAL, 0L, 0L - ) + ), + null ).build(), TrainedModelAssignment.Builder.empty( new StartTrainedModelDeploymentAction.TaskParams( @@ -1083,7 +1084,8 @@ public void testCpuModelAssignmentRequirements() { Priority.NORMAL, 0L, 0L - ) + ), + null ).build() ), withMlNodes("ml_node_1", "ml_node_2"), @@ -1105,7 +1107,8 @@ public void testCpuModelAssignmentRequirements() { Priority.NORMAL, 0L, 0L - ) + ), + null ).build(), TrainedModelAssignment.Builder.empty( new StartTrainedModelDeploymentAction.TaskParams( @@ -1119,7 +1122,8 @@ public void testCpuModelAssignmentRequirements() { Priority.NORMAL, 0L, 0L - ) + ), + null ).build() ), withMlNodes("ml_node_1", "ml_node_2"), @@ -1141,7 +1145,8 @@ public void testCpuModelAssignmentRequirements() { Priority.NORMAL, 0L, 0L - ) + ), + null ).build(), TrainedModelAssignment.Builder.empty( new StartTrainedModelDeploymentAction.TaskParams( @@ -1155,7 +1160,8 @@ public void testCpuModelAssignmentRequirements() { Priority.NORMAL, 0L, 0L - ) + ), + null ).build() ), withMlNodes("ml_node_1", "ml_node_2", "ml_node_3", "ml_node_4"), diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/autoscaling/MlProcessorAutoscalingDeciderTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/autoscaling/MlProcessorAutoscalingDeciderTests.java index 97fd66e284010..ba40dc0bfdda7 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/autoscaling/MlProcessorAutoscalingDeciderTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/autoscaling/MlProcessorAutoscalingDeciderTests.java @@ -79,7 +79,8 @@ public void testScale_GivenCurrentCapacityIsUsedExactly() { Priority.NORMAL, 0L, 0L - ) + ), + null ).addRoutingEntry(mlNodeId1, new RoutingInfo(2, 2, RoutingState.STARTED, "")) ) .addNewAssignment( @@ -96,7 +97,8 @@ public void testScale_GivenCurrentCapacityIsUsedExactly() { Priority.NORMAL, 0L, 0L - ) + ), + null ) .addRoutingEntry(mlNodeId1, new RoutingInfo(2, 2, RoutingState.STARTED, "")) .addRoutingEntry(mlNodeId2, new RoutingInfo(8, 8, RoutingState.STARTED, "")) @@ -153,7 +155,8 @@ public void testScale_GivenUnsatisfiedDeployments() { Priority.NORMAL, 0L, 0L - ) + ), + null ) ) .addNewAssignment( @@ -170,7 +173,8 @@ public void testScale_GivenUnsatisfiedDeployments() { Priority.NORMAL, 0L, 0L - ) + ), + null ) .addRoutingEntry(mlNodeId1, new RoutingInfo(1, 1, RoutingState.STARTED, "")) .addRoutingEntry(mlNodeId2, new RoutingInfo(1, 1, RoutingState.STARTED, "")) @@ -227,7 +231,8 @@ public void testScale_GivenUnsatisfiedDeploymentIsLowPriority_ShouldNotScaleUp() Priority.LOW, 0L, 0L - ) + ), + null ) ) .addNewAssignment( @@ -244,7 +249,8 @@ public void testScale_GivenUnsatisfiedDeploymentIsLowPriority_ShouldNotScaleUp() Priority.NORMAL, 0L, 0L - ) + ), + null ) .addRoutingEntry(mlNodeId1, new RoutingInfo(1, 1, RoutingState.STARTED, "")) .addRoutingEntry(mlNodeId2, new RoutingInfo(1, 1, RoutingState.STARTED, "")) @@ -301,7 +307,8 @@ public void testScale_GivenMoreThanHalfProcessorsAreUsed() { Priority.NORMAL, 0L, 0L - ) + ), + null ).addRoutingEntry(mlNodeId1, new RoutingInfo(2, 2, RoutingState.STARTED, "")) ) .addNewAssignment( @@ -318,7 +325,8 @@ public void testScale_GivenMoreThanHalfProcessorsAreUsed() { Priority.NORMAL, 0L, 0L - ) + ), + null ).addRoutingEntry(mlNodeId2, new RoutingInfo(1, 1, RoutingState.STARTED, "")) ) .build() @@ -386,7 +394,8 @@ public void testScale_GivenDownScalePossible_DelayNotSatisfied() { Priority.NORMAL, 0L, 0L - ) + ), + null ).addRoutingEntry(mlNodeId1, new RoutingInfo(2, 2, RoutingState.STARTED, "")) ) .addNewAssignment( @@ -403,7 +412,8 @@ public void testScale_GivenDownScalePossible_DelayNotSatisfied() { Priority.NORMAL, 0L, 0L - ) + ), + null ).addRoutingEntry(mlNodeId2, new RoutingInfo(1, 1, RoutingState.STARTED, "")) ) .build() @@ -459,7 +469,8 @@ public void testScale_GivenDownScalePossible_DelaySatisfied() { Priority.NORMAL, 0L, 0L - ) + ), + null ).addRoutingEntry(mlNodeId1, new RoutingInfo(2, 2, RoutingState.STARTED, "")) ) .addNewAssignment( @@ -476,7 +487,8 @@ public void testScale_GivenDownScalePossible_DelaySatisfied() { Priority.NORMAL, 0L, 0L - ) + ), + null ).addRoutingEntry(mlNodeId2, new RoutingInfo(1, 1, RoutingState.STARTED, "")) ) .build() @@ -536,7 +548,8 @@ public void testScale_GivenLowPriorityDeploymentsOnly() { Priority.LOW, 0L, 0L - ) + ), + null ).addRoutingEntry(mlNodeId1, new RoutingInfo(1, 1, RoutingState.STARTED, "")) ) .addNewAssignment( @@ -553,7 +566,8 @@ public void testScale_GivenLowPriorityDeploymentsOnly() { Priority.LOW, 0L, 0L - ) + ), + null ).addRoutingEntry(mlNodeId1, new RoutingInfo(1, 1, RoutingState.STARTED, "")) ) .build() diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/adaptiveallocations/AdaptiveAllocationsScalerServiceTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/adaptiveallocations/AdaptiveAllocationsScalerServiceTests.java new file mode 100644 index 0000000000000..3ad44f256dc66 --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/adaptiveallocations/AdaptiveAllocationsScalerServiceTests.java @@ -0,0 +1,239 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.inference.adaptiveallocations; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.client.internal.Client; +import org.elasticsearch.cluster.ClusterChangedEvent; +import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.metadata.Metadata; +import org.elasticsearch.cluster.node.DiscoveryNodeUtils; +import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.unit.ByteSizeValue; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.threadpool.ScalingExecutorBuilder; +import org.elasticsearch.threadpool.TestThreadPool; +import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction; +import org.elasticsearch.xpack.core.ml.action.GetDeploymentStatsAction; +import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction; +import org.elasticsearch.xpack.core.ml.action.UpdateTrainedModelDeploymentAction; +import org.elasticsearch.xpack.core.ml.inference.assignment.AdaptiveAllocationsSettings; +import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentStats; +import org.elasticsearch.xpack.core.ml.inference.assignment.Priority; +import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignment; +import org.elasticsearch.xpack.core.ml.inference.assignment.TrainedModelAssignmentMetadata; +import org.elasticsearch.xpack.ml.MachineLearning; +import org.elasticsearch.xpack.ml.notifications.InferenceAuditor; +import org.junit.After; +import org.junit.Before; + +import java.io.IOException; +import java.time.Instant; +import java.util.List; +import java.util.Map; + +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.ArgumentMatchers.same; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.reset; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; + +public class AdaptiveAllocationsScalerServiceTests extends ESTestCase { + + private TestThreadPool threadPool; + private ClusterService clusterService; + private Client client; + private InferenceAuditor inferenceAuditor; + + @Override + @Before + public void setUp() throws Exception { + super.setUp(); + threadPool = createThreadPool( + new ScalingExecutorBuilder(MachineLearning.UTILITY_THREAD_POOL_NAME, 0, 1, TimeValue.timeValueMinutes(10), false) + ); + clusterService = mock(ClusterService.class); + client = mock(Client.class); + inferenceAuditor = mock(InferenceAuditor.class); + } + + @Override + @After + public void tearDown() throws Exception { + this.threadPool.close(); + super.tearDown(); + } + + private ClusterState getClusterState(int numAllocations) { + ClusterState clusterState = mock(ClusterState.class); + Metadata metadata = mock(Metadata.class); + when(clusterState.getMetadata()).thenReturn(metadata); + when(metadata.custom("trained_model_assignment")).thenReturn( + new TrainedModelAssignmentMetadata( + Map.of( + "test-deployment", + TrainedModelAssignment.Builder.empty( + new StartTrainedModelDeploymentAction.TaskParams( + "model-id", + "test-deployment", + 100_000_000, + numAllocations, + 1, + 1024, + ByteSizeValue.ZERO, + Priority.NORMAL, + 100_000_000, + 100_000_000 + ), + new AdaptiveAllocationsSettings(true, null, null) + ).build() + ) + ) + ); + return clusterState; + } + + private GetDeploymentStatsAction.Response getDeploymentStatsResponse(int numAllocations, int inferenceCount, double latency) { + return new GetDeploymentStatsAction.Response( + List.of(), + List.of(), + List.of( + new AssignmentStats( + "test-deployment", + "model-id", + 1, + numAllocations, + new AdaptiveAllocationsSettings(true, null, null), + 1024, + ByteSizeValue.ZERO, + Instant.now(), + List.of( + AssignmentStats.NodeStats.forStartedState( + DiscoveryNodeUtils.create("node_1"), + inferenceCount, + latency, + latency, + 0, + 0, + 0, + 0, + 0, + Instant.now(), + Instant.now(), + 1, + numAllocations, + inferenceCount, + inferenceCount, + latency, + 0 + ) + ), + Priority.NORMAL + ) + ), + 0 + ); + } + + public void test() throws IOException { + // Initialize the cluster with a deployment with 1 allocation. + ClusterState clusterState = getClusterState(1); + when(clusterService.state()).thenReturn(clusterState); + + AdaptiveAllocationsScalerService service = new AdaptiveAllocationsScalerService( + threadPool, + clusterService, + client, + inferenceAuditor, + true, + 1 + ); + service.start(); + + verify(clusterService).state(); + verify(clusterService).addListener(same(service)); + verifyNoMoreInteractions(client, clusterService); + reset(client, clusterService); + + // First cycle: 1 inference request, so no need for scaling. + when(client.threadPool()).thenReturn(threadPool); + doAnswer(invocationOnMock -> { + @SuppressWarnings("unchecked") + var listener = (ActionListener) invocationOnMock.getArguments()[2]; + listener.onResponse(getDeploymentStatsResponse(1, 1, 11.0)); + return Void.TYPE; + }).when(client).execute(eq(GetDeploymentStatsAction.INSTANCE), eq(new GetDeploymentStatsAction.Request("test-deployment")), any()); + + safeSleep(1200); + + verify(client, times(1)).threadPool(); + verify(client, times(1)).execute(eq(GetDeploymentStatsAction.INSTANCE), any(), any()); + verifyNoMoreInteractions(client, clusterService); + reset(client, clusterService); + + // Second cycle: 150 inference request with a latency of 10ms, so scale up to 2 allocations. + when(client.threadPool()).thenReturn(threadPool); + doAnswer(invocationOnMock -> { + @SuppressWarnings("unchecked") + var listener = (ActionListener) invocationOnMock.getArguments()[2]; + listener.onResponse(getDeploymentStatsResponse(1, 150, 10.0)); + return Void.TYPE; + }).when(client).execute(eq(GetDeploymentStatsAction.INSTANCE), eq(new GetDeploymentStatsAction.Request("test-deployment")), any()); + doAnswer(invocationOnMock -> { + @SuppressWarnings("unchecked") + var listener = (ActionListener) invocationOnMock.getArguments()[2]; + listener.onResponse(null); + return Void.TYPE; + }).when(client).execute(eq(UpdateTrainedModelDeploymentAction.INSTANCE), any(), any()); + + safeSleep(1000); + + verify(client, times(2)).threadPool(); + verify(client, times(1)).execute(eq(GetDeploymentStatsAction.INSTANCE), any(), any()); + var updateRequest = new UpdateTrainedModelDeploymentAction.Request("test-deployment"); + updateRequest.setNumberOfAllocations(2); + updateRequest.setIsInternal(true); + verify(client, times(1)).execute(eq(UpdateTrainedModelDeploymentAction.INSTANCE), eq(updateRequest), any()); + verifyNoMoreInteractions(client, clusterService); + reset(client, clusterService); + + clusterState = getClusterState(2); + ClusterChangedEvent clusterChangedEvent = mock(ClusterChangedEvent.class); + when(clusterChangedEvent.state()).thenReturn(clusterState); + service.clusterChanged(clusterChangedEvent); + + // Third cycle: 0 inference requests, but keep 2 allocations, because of cooldown. + when(client.threadPool()).thenReturn(threadPool); + doAnswer(invocationOnMock -> { + @SuppressWarnings("unchecked") + var listener = (ActionListener) invocationOnMock.getArguments()[2]; + listener.onResponse(getDeploymentStatsResponse(2, 0, 9.0)); + return Void.TYPE; + }).when(client).execute(eq(GetDeploymentStatsAction.INSTANCE), eq(new GetDeploymentStatsAction.Request("test-deployment")), any()); + doAnswer(invocationOnMock -> { + @SuppressWarnings("unchecked") + var listener = (ActionListener) invocationOnMock.getArguments()[2]; + listener.onResponse(null); + return Void.TYPE; + }).when(client).execute(eq(UpdateTrainedModelDeploymentAction.INSTANCE), any(), any()); + + safeSleep(1000); + + verify(client, times(1)).threadPool(); + verify(client, times(1)).execute(eq(GetDeploymentStatsAction.INSTANCE), any(), any()); + verifyNoMoreInteractions(client, clusterService); + + service.stop(); + } +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/adaptiveallocations/AdaptiveAllocationsScalerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/adaptiveallocations/AdaptiveAllocationsScalerTests.java new file mode 100644 index 0000000000000..9758d00627efe --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/adaptiveallocations/AdaptiveAllocationsScalerTests.java @@ -0,0 +1,141 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.inference.adaptiveallocations; + +import org.elasticsearch.test.ESTestCase; + +import java.util.Random; + +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.greaterThan; +import static org.hamcrest.Matchers.lessThan; +import static org.hamcrest.Matchers.nullValue; + +public class AdaptiveAllocationsScalerTests extends ESTestCase { + + public void testAutoscaling_scaleUpAndDown() { + AdaptiveAllocationsScaler adaptiveAllocationsScaler = new AdaptiveAllocationsScaler("test-deployment", 1); + + // With 1 allocation the system can handle 500 requests * 0.020 sec/request. + // To handle remaining requests the system should scale to 2 allocations. + adaptiveAllocationsScaler.process(new AdaptiveAllocationsScalerService.Stats(500, 100, 100, 0.020), 10, 1); + assertThat(adaptiveAllocationsScaler.scale(), equalTo(2)); + + // With 2 allocation the system can handle 800 requests * 0.025 sec/request. + // To handle remaining requests the system should scale to 3 allocations. + adaptiveAllocationsScaler.process(new AdaptiveAllocationsScalerService.Stats(800, 100, 50, 0.025), 10, 2); + assertThat(adaptiveAllocationsScaler.scale(), equalTo(3)); + + // With 3 allocations the system can handle the load. + adaptiveAllocationsScaler.process(new AdaptiveAllocationsScalerService.Stats(1000, 0, 0, 0.025), 10, 3); + assertThat(adaptiveAllocationsScaler.scale(), nullValue()); + + // No load anymore, so the system should gradually scale down to 1 allocation. + adaptiveAllocationsScaler.process(new AdaptiveAllocationsScalerService.Stats(0, 0, 0, Double.NaN), 10, 3); + assertThat(adaptiveAllocationsScaler.scale(), nullValue()); + adaptiveAllocationsScaler.process(new AdaptiveAllocationsScalerService.Stats(0, 0, 0, Double.NaN), 10, 3); + assertThat(adaptiveAllocationsScaler.scale(), equalTo(2)); + adaptiveAllocationsScaler.process(new AdaptiveAllocationsScalerService.Stats(0, 0, 0, Double.NaN), 10, 2); + assertThat(adaptiveAllocationsScaler.scale(), nullValue()); + adaptiveAllocationsScaler.process(new AdaptiveAllocationsScalerService.Stats(0, 0, 0, Double.NaN), 10, 2); + assertThat(adaptiveAllocationsScaler.scale(), equalTo(1)); + } + + public void testAutoscaling_noOscillating() { + AdaptiveAllocationsScaler adaptiveAllocationsScaler = new AdaptiveAllocationsScaler("test-deployment", 1); + + // With 1 allocation the system can handle 880 requests * 0.010 sec/request. + adaptiveAllocationsScaler.process(new AdaptiveAllocationsScalerService.Stats(880, 0, 0, 0.010), 10, 1); + assertThat(adaptiveAllocationsScaler.scale(), nullValue()); + adaptiveAllocationsScaler.process(new AdaptiveAllocationsScalerService.Stats(880, 0, 0, 0.010), 10, 1); + assertThat(adaptiveAllocationsScaler.scale(), nullValue()); + + // Increase the load to 980 requests * 0.010 sec/request, and the system + // should scale to 2 allocations to have some spare capacity. + adaptiveAllocationsScaler.process(new AdaptiveAllocationsScalerService.Stats(920, 0, 0, 0.010), 10, 1); + assertThat(adaptiveAllocationsScaler.scale(), nullValue()); + adaptiveAllocationsScaler.process(new AdaptiveAllocationsScalerService.Stats(950, 0, 0, 0.010), 10, 1); + assertThat(adaptiveAllocationsScaler.scale(), nullValue()); + adaptiveAllocationsScaler.process(new AdaptiveAllocationsScalerService.Stats(980, 0, 0, 0.010), 10, 1); + assertThat(adaptiveAllocationsScaler.scale(), equalTo(2)); + adaptiveAllocationsScaler.process(new AdaptiveAllocationsScalerService.Stats(980, 0, 0, 0.010), 10, 2); + assertThat(adaptiveAllocationsScaler.scale(), nullValue()); + + // Reducing the load to just 880 requests * 0.010 sec/request should not + // trigger scaling down again, to prevent oscillating. + adaptiveAllocationsScaler.process(new AdaptiveAllocationsScalerService.Stats(880, 0, 0, 0.010), 10, 2); + assertThat(adaptiveAllocationsScaler.scale(), nullValue()); + adaptiveAllocationsScaler.process(new AdaptiveAllocationsScalerService.Stats(880, 0, 0, 0.010), 10, 2); + assertThat(adaptiveAllocationsScaler.scale(), nullValue()); + } + + public void testAutoscaling_respectMinMaxAllocations() { + AdaptiveAllocationsScaler adaptiveAllocationsScaler = new AdaptiveAllocationsScaler("test-deployment", 1); + adaptiveAllocationsScaler.setMinMaxNumberOfAllocations(2, 5); + + // Even though there are no requests, scale to the minimum of 2 allocations. + adaptiveAllocationsScaler.process(new AdaptiveAllocationsScalerService.Stats(0, 0, 0, 0.010), 10, 1); + assertThat(adaptiveAllocationsScaler.scale(), equalTo(2)); + adaptiveAllocationsScaler.process(new AdaptiveAllocationsScalerService.Stats(0, 0, 0, 0.010), 10, 2); + assertThat(adaptiveAllocationsScaler.scale(), nullValue()); + + // Even though there are many requests, the scale to the maximum of 5 allocations. + adaptiveAllocationsScaler.process(new AdaptiveAllocationsScalerService.Stats(100, 10000, 1000, 0.010), 10, 2); + assertThat(adaptiveAllocationsScaler.scale(), equalTo(5)); + adaptiveAllocationsScaler.process(new AdaptiveAllocationsScalerService.Stats(500, 10000, 1000, 0.010), 10, 5); + assertThat(adaptiveAllocationsScaler.scale(), nullValue()); + + // After a while of no requests, scale to the minimum of 2 allocations. + adaptiveAllocationsScaler.process(new AdaptiveAllocationsScalerService.Stats(0, 0, 0, 0.010), 10, 5); + adaptiveAllocationsScaler.process(new AdaptiveAllocationsScalerService.Stats(0, 0, 0, 0.010), 10, 5); + adaptiveAllocationsScaler.process(new AdaptiveAllocationsScalerService.Stats(0, 0, 0, 0.010), 10, 5); + assertThat(adaptiveAllocationsScaler.scale(), equalTo(2)); + } + + public void testEstimation_highVariance() { + AdaptiveAllocationsScaler adaptiveAllocationsScaler = new AdaptiveAllocationsScaler("test-deployment", 1); + + Random random = new Random(42); + + double averageLoadMean = 0.0; + double averageLoadError = 0.0; + + double time = 0.0; + for (int nextMeasurementTime = 1; nextMeasurementTime <= 100; nextMeasurementTime++) { + // Sample one second of data (until the next measurement time). + // This contains approximately 100 requests with high-variance inference times. + AdaptiveAllocationsScalerService.Stats stats = new AdaptiveAllocationsScalerService.Stats(0, 0, 0, 0); + while (time < nextMeasurementTime) { + // Draw inference times from a log-normal distribution, which has high variance. + // This distribution approximately has: mean=3.40, variance=98.4. + double inferenceTime = Math.exp(random.nextGaussian(0.1, 1.5)); + stats = stats.add(new AdaptiveAllocationsScalerService.Stats(1, 0, 0, inferenceTime)); + + // The requests are Poisson distributed, which means the time inbetween + // requests follows an exponential distribution. + // This distribution has on average 100 requests per second. + double dt = 0.01 * random.nextExponential(); + time += dt; + } + + adaptiveAllocationsScaler.process(stats, 1, 1); + double lower = adaptiveAllocationsScaler.getLoadLower(); + double upper = adaptiveAllocationsScaler.getLoadUpper(); + averageLoadMean += (upper + lower) / 2.0; + averageLoadError += (upper - lower) / 2.0; + } + + averageLoadMean /= 100; + averageLoadError /= 100; + + double expectedLoad = 100 * 3.40; + assertThat(averageLoadMean - averageLoadError, lessThan(expectedLoad)); + assertThat(averageLoadMean + averageLoadError, greaterThan(expectedLoad)); + assertThat(averageLoadError / averageLoadMean, lessThan(1 - AdaptiveAllocationsScaler.SCALE_UP_THRESHOLD)); + } +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/adaptiveallocations/KalmanFilter1dTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/adaptiveallocations/KalmanFilter1dTests.java new file mode 100644 index 0000000000000..f9b3a8966b627 --- /dev/null +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/adaptiveallocations/KalmanFilter1dTests.java @@ -0,0 +1,122 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.ml.inference.adaptiveallocations; + +import org.elasticsearch.test.ESTestCase; + +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.greaterThan; +import static org.hamcrest.Matchers.lessThan; + +public class KalmanFilter1dTests extends ESTestCase { + + public void testEstimation_equalValues() { + KalmanFilter1d filter = new KalmanFilter1d("test-filter", 100, false); + assertThat(filter.hasValue(), equalTo(false)); + + filter.add(42.0, 9.0, false); + assertThat(filter.hasValue(), equalTo(true)); + assertThat(filter.estimate(), equalTo(42.0)); + assertThat(filter.error(), equalTo(3.0)); + assertThat(filter.lower(), equalTo(39.0)); + assertThat(filter.upper(), equalTo(45.0)); + + // With more data the estimation error should go down. + double previousError = filter.error(); + for (int i = 0; i < 20; i++) { + filter.add(42.0, 9.0, false); + assertThat(filter.estimate(), equalTo(42.0)); + assertThat(filter.error(), lessThan(previousError)); + previousError = filter.error(); + } + } + + public void testEstimation_increasingValues() { + KalmanFilter1d filter = new KalmanFilter1d("test-filter", 100, false); + filter.add(10.0, 1.0, false); + assertThat(filter.estimate(), equalTo(10.0)); + + // As the measured values increase, the estimated value should increase too, + // but it should lag behind. + double previousEstimate = filter.estimate(); + for (double value = 11.0; value < 20.0; value += 1.0) { + filter.add(value, 1.0, false); + assertThat(filter.estimate(), greaterThan(previousEstimate)); + assertThat(filter.estimate(), lessThan(value)); + previousEstimate = filter.estimate(); + } + + // More final values should bring the estimate close to it. + for (int i = 0; i < 20; i++) { + filter.add(20.0, 1.0, false); + } + assertThat(filter.estimate(), greaterThan(19.0)); + assertThat(filter.estimate(), lessThan(20.0)); + } + + public void testEstimation_bigJumpNoAutoDetectDynamicsChanges() { + KalmanFilter1d filter = new KalmanFilter1d("test-filter", 100, false); + filter.add(0.0, 100.0, false); + filter.add(0.0, 1.0, false); + assertThat(filter.estimate(), equalTo(0.0)); + + // Without dynamics change autodetection the estimated value should be + // inbetween the old and the new value. + filter.add(100.0, 1.0, false); + assertThat(filter.estimate(), greaterThan(49.0)); + assertThat(filter.estimate(), lessThan(51.0)); + } + + public void testEstimation_bigJumpWithAutoDetectDynamicsChanges() { + KalmanFilter1d filter = new KalmanFilter1d("test-filter", 100, true); + filter.add(0.0, 100.0, false); + filter.add(0.0, 1.0, false); + assertThat(filter.estimate(), equalTo(0.0)); + + // With dynamics change autodetection the estimated value should jump + // instantly to almost the new value. + filter.add(100.0, 1.0, false); + assertThat(filter.estimate(), greaterThan(99.0)); + assertThat(filter.estimate(), lessThan(100.0)); + } + + public void testEstimation_bigJumpWithExternalDetectDynamicsChange() { + KalmanFilter1d filter = new KalmanFilter1d("test-filter", 100, false); + filter.add(0.0, 100.0, false); + filter.add(0.0, 1.0, false); + assertThat(filter.estimate(), equalTo(0.0)); + + // For external dynamics changes the estimated value should jump + // instantly to almost the new value. + filter.add(100.0, 1.0, true); + assertThat(filter.estimate(), greaterThan(99.0)); + assertThat(filter.estimate(), lessThan(100.0)); + } + + public void testEstimation_differentSmoothing() { + KalmanFilter1d quickFilter = new KalmanFilter1d("test-filter", 1e-3, false); + for (int i = 0; i < 100; i++) { + quickFilter.add(42.0, 1.0, false); + } + assertThat(quickFilter.estimate(), equalTo(42.0)); + // With low smoothing, the value should be close to the new value. + quickFilter.add(77.0, 1.0, false); + assertThat(quickFilter.estimate(), greaterThan(75.0)); + assertThat(quickFilter.estimate(), lessThan(77.0)); + + KalmanFilter1d slowFilter = new KalmanFilter1d("test-filter", 1e3, false); + for (int i = 0; i < 100; i++) { + slowFilter.add(42.0, 1.0, false); + } + assertThat(slowFilter.estimate(), equalTo(42.0)); + // With high smoothing, the value should be close to the old value. + slowFilter.add(77.0, 1.0, false); + assertThat(slowFilter.estimate(), greaterThan(42.0)); + assertThat(slowFilter.estimate(), lessThan(44.0)); + } +} diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentClusterServiceTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentClusterServiceTests.java index f08d2735be8a5..1dc44582492aa 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentClusterServiceTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentClusterServiceTests.java @@ -48,6 +48,7 @@ import org.elasticsearch.xpack.core.ml.MlConfigVersion; import org.elasticsearch.xpack.core.ml.MlMetadata; import org.elasticsearch.xpack.core.ml.MlTasks; +import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction; import org.elasticsearch.xpack.core.ml.action.StartDataFrameAnalyticsAction; import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction; import org.elasticsearch.xpack.core.ml.action.UpdateTrainedModelAssignmentRoutingInfoAction; @@ -277,7 +278,7 @@ public void testUpdateModelRoutingTable() { TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( modelId, - TrainedModelAssignment.Builder.empty(newParams(modelId, 10_000L)) + TrainedModelAssignment.Builder.empty(newParams(modelId, 10_000L), null) .addRoutingEntry(nodeId, new RoutingInfo(1, 1, RoutingState.STARTING, "")) .addRoutingEntry(startedNode, new RoutingInfo(1, 1, RoutingState.STARTING, "")) ) @@ -389,7 +390,10 @@ public void testRemoveAssignment() { .putCustom( TrainedModelAssignmentMetadata.NAME, TrainedModelAssignmentMetadata.Builder.empty() - .addNewAssignment(modelId, TrainedModelAssignment.Builder.empty(newParams(modelId, randomNonNegativeLong()))) + .addNewAssignment( + modelId, + TrainedModelAssignment.Builder.empty(newParams(modelId, randomNonNegativeLong()), null) + ) .build() ) .build() @@ -450,7 +454,10 @@ public void testCreateAssignment_GivenModelCannotByFullyAllocated_AndScalingIsPo .build(); TrainedModelAssignmentClusterService trainedModelAssignmentClusterService = createClusterService(5); - ClusterState newState = trainedModelAssignmentClusterService.createModelAssignment(currentState, newParams("new-model", 150, 4, 1)); + ClusterState newState = trainedModelAssignmentClusterService.createModelAssignment( + currentState, + new CreateTrainedModelAssignmentAction.Request(newParams("new-model", 150, 4, 1), null) + ); TrainedModelAssignment createdAssignment = TrainedModelAssignmentMetadata.fromState(newState).getDeploymentAssignment("new-model"); assertThat(createdAssignment, is(not(nullValue()))); @@ -466,7 +473,10 @@ public void testCreateAssignment_GivenModelCannotByFullyAllocated_AndScalingIsPo expectThrows( ResourceAlreadyExistsException.class, - () -> trainedModelAssignmentClusterService.createModelAssignment(newState, newParams("new-model", 150)) + () -> trainedModelAssignmentClusterService.createModelAssignment( + newState, + new CreateTrainedModelAssignmentAction.Request(newParams("new-model", 150), null) + ) ); } @@ -495,7 +505,10 @@ public void testCreateAssignment_GivenModelCannotByFullyAllocated_AndScalingIsNo TrainedModelAssignmentClusterService trainedModelAssignmentClusterService = createClusterService(0); ElasticsearchStatusException e = expectThrows( ElasticsearchStatusException.class, - () -> trainedModelAssignmentClusterService.createModelAssignment(currentState, newParams("new-model", 150, 4, 1)) + () -> trainedModelAssignmentClusterService.createModelAssignment( + currentState, + new CreateTrainedModelAssignmentAction.Request(newParams("new-model", 150, 4, 1), null) + ) ); assertThat( @@ -528,7 +541,7 @@ public void testCreateAssignmentWhileResetModeIsTrue() throws InterruptedExcepti CountDownLatch latch = new CountDownLatch(1); trainedModelAssignmentClusterService.createNewModelAssignment( - newParams("new-model", 150), + new CreateTrainedModelAssignmentAction.Request(newParams("new-model", 150), null), new LatchedActionListener<>( ActionListener.wrap( trainedModelAssignment -> fail("assignment should have failed to be created because reset mode is set"), @@ -560,7 +573,7 @@ public void testHaveMlNodesChanged_ReturnsFalseWhenPreviouslyShuttingDownNode_Is TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( model1, - TrainedModelAssignment.Builder.empty(newParams(model1, 100)) + TrainedModelAssignment.Builder.empty(newParams(model1, 100), null) .addRoutingEntry(mlNode1, new RoutingInfo(1, 1, RoutingState.STARTING, "")) ) .build() @@ -597,7 +610,7 @@ public void testHaveMlNodesChanged_ReturnsTrueWhenNodeShutsDownAndWasRoutedTo() TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( model1, - TrainedModelAssignment.Builder.empty(newParams(model1, 100)) + TrainedModelAssignment.Builder.empty(newParams(model1, 100), null) .addRoutingEntry(mlNode1, new RoutingInfo(1, 1, RoutingState.STARTING, "")) ) .build() @@ -614,7 +627,7 @@ public void testHaveMlNodesChanged_ReturnsTrueWhenNodeShutsDownAndWasRoutedTo() TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( model1, - TrainedModelAssignment.Builder.empty(newParams(model1, 100)) + TrainedModelAssignment.Builder.empty(newParams(model1, 100), null) .addRoutingEntry(mlNode1, new RoutingInfo(1, 1, RoutingState.STARTING, "")) ) .build() @@ -641,7 +654,7 @@ public void testHaveMlNodesChanged_ReturnsFalseWhenNodeShutsDownAndWasRoutedTo_B TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( model1, - TrainedModelAssignment.Builder.empty(newParams(model1, 100)) + TrainedModelAssignment.Builder.empty(newParams(model1, 100), null) .addRoutingEntry(mlNode1, new RoutingInfo(1, 1, RoutingState.STOPPING, "")) ) .build() @@ -658,7 +671,7 @@ public void testHaveMlNodesChanged_ReturnsFalseWhenNodeShutsDownAndWasRoutedTo_B TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( model1, - TrainedModelAssignment.Builder.empty(newParams(model1, 100)) + TrainedModelAssignment.Builder.empty(newParams(model1, 100), null) .addRoutingEntry(mlNode1, new RoutingInfo(1, 1, RoutingState.STARTING, "")) ) .build() @@ -700,7 +713,7 @@ public void testDetectReasonToRebalanceModels() { .putCustom( TrainedModelAssignmentMetadata.NAME, TrainedModelAssignmentMetadata.Builder.empty() - .addNewAssignment(model1, TrainedModelAssignment.Builder.empty(newParams(model1, 100))) + .addNewAssignment(model1, TrainedModelAssignment.Builder.empty(newParams(model1, 100), null)) .build() ) .build() @@ -747,7 +760,7 @@ public void testDetectReasonToRebalanceModels() { .putCustom( TrainedModelAssignmentMetadata.NAME, TrainedModelAssignmentMetadata.Builder.empty() - .addNewAssignment(model1, TrainedModelAssignment.Builder.empty(newParams(model1, 100))) + .addNewAssignment(model1, TrainedModelAssignment.Builder.empty(newParams(model1, 100), null)) .build() ) .build() @@ -759,7 +772,7 @@ public void testDetectReasonToRebalanceModels() { .putCustom( TrainedModelAssignmentMetadata.NAME, TrainedModelAssignmentMetadata.Builder.empty() - .addNewAssignment(model1, TrainedModelAssignment.Builder.empty(newParams(model1, 100))) + .addNewAssignment(model1, TrainedModelAssignment.Builder.empty(newParams(model1, 100), null)) .build() ) .build() @@ -781,7 +794,7 @@ public void testDetectReasonToRebalanceModels() { .putCustom( TrainedModelAssignmentMetadata.NAME, TrainedModelAssignmentMetadata.Builder.empty() - .addNewAssignment(model1, TrainedModelAssignment.Builder.empty(newParams(model1, 100))) + .addNewAssignment(model1, TrainedModelAssignment.Builder.empty(newParams(model1, 100), null)) .build() ) .build() @@ -793,7 +806,7 @@ public void testDetectReasonToRebalanceModels() { .putCustom( TrainedModelAssignmentMetadata.NAME, TrainedModelAssignmentMetadata.Builder.empty() - .addNewAssignment(model1, TrainedModelAssignment.Builder.empty(newParams(model1, 100))) + .addNewAssignment(model1, TrainedModelAssignment.Builder.empty(newParams(model1, 100), null)) .build() ) .build() @@ -815,7 +828,7 @@ public void testDetectReasonToRebalanceModels() { .putCustom( TrainedModelAssignmentMetadata.NAME, TrainedModelAssignmentMetadata.Builder.empty() - .addNewAssignment(model1, TrainedModelAssignment.Builder.empty(newParams(model1, 100))) + .addNewAssignment(model1, TrainedModelAssignment.Builder.empty(newParams(model1, 100), null)) .build() ) .build() @@ -827,7 +840,7 @@ public void testDetectReasonToRebalanceModels() { .putCustom( TrainedModelAssignmentMetadata.NAME, TrainedModelAssignmentMetadata.Builder.empty() - .addNewAssignment(model1, TrainedModelAssignment.Builder.empty(newParams(model1, 100))) + .addNewAssignment(model1, TrainedModelAssignment.Builder.empty(newParams(model1, 100), null)) .build() ) .build() @@ -851,7 +864,7 @@ public void testDetectReasonToRebalanceModels() { TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( model1, - TrainedModelAssignment.Builder.empty(newParams(model1, 100)).stopAssignment("test") + TrainedModelAssignment.Builder.empty(newParams(model1, 100), null).stopAssignment("test") ) .build() ) @@ -864,7 +877,7 @@ public void testDetectReasonToRebalanceModels() { .putCustom( TrainedModelAssignmentMetadata.NAME, TrainedModelAssignmentMetadata.Builder.empty() - .addNewAssignment(model1, TrainedModelAssignment.Builder.empty(newParams(model1, 100))) + .addNewAssignment(model1, TrainedModelAssignment.Builder.empty(newParams(model1, 100), null)) .build() ) .build() @@ -886,7 +899,7 @@ public void testDetectReasonToRebalanceModels() { .putCustom( TrainedModelAssignmentMetadata.NAME, TrainedModelAssignmentMetadata.Builder.empty() - .addNewAssignment(model1, TrainedModelAssignment.Builder.empty(newParams(model1, 100))) + .addNewAssignment(model1, TrainedModelAssignment.Builder.empty(newParams(model1, 100), null)) .build() ) .putCustom(NodesShutdownMetadata.TYPE, shutdownMetadata(mlNode2)) @@ -899,7 +912,7 @@ public void testDetectReasonToRebalanceModels() { .putCustom( TrainedModelAssignmentMetadata.NAME, TrainedModelAssignmentMetadata.Builder.empty() - .addNewAssignment(model1, TrainedModelAssignment.Builder.empty(newParams(model1, 100))) + .addNewAssignment(model1, TrainedModelAssignment.Builder.empty(newParams(model1, 100), null)) .build() ) .build() @@ -923,12 +936,12 @@ public void testDetectReasonToRebalanceModels() { TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( model1, - TrainedModelAssignment.Builder.empty(newParams(model1, 100)) + TrainedModelAssignment.Builder.empty(newParams(model1, 100), null) .addRoutingEntry(mlNode1, new RoutingInfo(1, 1, RoutingState.STARTING, "")) ) .addNewAssignment( model2, - TrainedModelAssignment.Builder.empty(newParams("model-2", 100)) + TrainedModelAssignment.Builder.empty(newParams("model-2", 100), null) .addRoutingEntry(mlNode1, new RoutingInfo(1, 1, RoutingState.STARTING, "")) .addRoutingEntry(mlNode2, new RoutingInfo(1, 1, RoutingState.STARTING, "")) ) @@ -945,12 +958,12 @@ public void testDetectReasonToRebalanceModels() { TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( model1, - TrainedModelAssignment.Builder.empty(newParams(model1, 100)) + TrainedModelAssignment.Builder.empty(newParams(model1, 100), null) .addRoutingEntry(mlNode1, new RoutingInfo(1, 1, RoutingState.STARTING, "")) ) .addNewAssignment( model2, - TrainedModelAssignment.Builder.empty(newParams("model-2", 100)) + TrainedModelAssignment.Builder.empty(newParams("model-2", 100), null) .addRoutingEntry(mlNode1, new RoutingInfo(1, 1, RoutingState.STARTING, "")) .addRoutingEntry(mlNode2, new RoutingInfo(1, 1, RoutingState.STARTING, "")) ) @@ -977,12 +990,12 @@ public void testDetectReasonToRebalanceModels() { TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( model1, - TrainedModelAssignment.Builder.empty(newParams(model1, 100)) + TrainedModelAssignment.Builder.empty(newParams(model1, 100), null) .addRoutingEntry(mlNode1, new RoutingInfo(1, 1, RoutingState.STARTING, "")) ) .addNewAssignment( model2, - TrainedModelAssignment.Builder.empty(newParams("model-2", 100)) + TrainedModelAssignment.Builder.empty(newParams("model-2", 100), null) .addRoutingEntry(mlNode1, new RoutingInfo(1, 1, RoutingState.STARTING, "")) .addRoutingEntry(mlNode2, new RoutingInfo(1, 1, RoutingState.STARTING, "")) .stopAssignment("test") @@ -1000,12 +1013,12 @@ public void testDetectReasonToRebalanceModels() { TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( model1, - TrainedModelAssignment.Builder.empty(newParams(model1, 100)) + TrainedModelAssignment.Builder.empty(newParams(model1, 100), null) .addRoutingEntry(mlNode1, new RoutingInfo(1, 1, RoutingState.STARTING, "")) ) .addNewAssignment( model2, - TrainedModelAssignment.Builder.empty(newParams("model-2", 100)) + TrainedModelAssignment.Builder.empty(newParams("model-2", 100), null) .addRoutingEntry(mlNode1, new RoutingInfo(1, 1, RoutingState.STARTING, "")) .addRoutingEntry(mlNode2, new RoutingInfo(1, 1, RoutingState.STARTING, "")) ) @@ -1032,7 +1045,7 @@ public void testDetectReasonToRebalanceModels_WithNodeShutdowns() { TrainedModelAssignmentMetadata fullModelAllocation = TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( model1, - TrainedModelAssignment.Builder.empty(newParams(model1, 100)) + TrainedModelAssignment.Builder.empty(newParams(model1, 100), null) .addRoutingEntry(mlNode1.getId(), new RoutingInfo(1, 1, RoutingState.STARTED, "")) .addRoutingEntry(mlNode2.getId(), new RoutingInfo(1, 1, RoutingState.STARTED, "")) ) @@ -1227,7 +1240,7 @@ public void testDetectReasonToRebalanceModels_GivenSingleMlJobStopped() { .putCustom( TrainedModelAssignmentMetadata.NAME, TrainedModelAssignmentMetadata.Builder.empty() - .addNewAssignment(modelId, TrainedModelAssignment.Builder.empty(newParams(modelId, 100))) + .addNewAssignment(modelId, TrainedModelAssignment.Builder.empty(newParams(modelId, 100), null)) .build() ) .build() @@ -1242,7 +1255,7 @@ public void testDetectReasonToRebalanceModels_GivenSingleMlJobStopped() { .putCustom( TrainedModelAssignmentMetadata.NAME, TrainedModelAssignmentMetadata.Builder.empty() - .addNewAssignment(modelId, TrainedModelAssignment.Builder.empty(newParams(modelId, 100))) + .addNewAssignment(modelId, TrainedModelAssignment.Builder.empty(newParams(modelId, 100), null)) .build() ) .build() @@ -1265,7 +1278,7 @@ public void testDetectReasonToRebalanceModels_GivenOutdatedAssignments() { TrainedModelAssignmentMetadata modelMetadata = TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( modelId, - TrainedModelAssignment.Builder.empty(newParams(modelId, 100)) + TrainedModelAssignment.Builder.empty(newParams(modelId, 100), null) .addRoutingEntry(mlNodeId, new RoutingInfo(0, 0, RoutingState.STARTED, "")) ) .build(); @@ -1342,7 +1355,7 @@ public void testDetectReasonToRebalanceModels_GivenMultipleMlJobsStopped() { .putCustom( TrainedModelAssignmentMetadata.NAME, TrainedModelAssignmentMetadata.Builder.empty() - .addNewAssignment(modelId, TrainedModelAssignment.Builder.empty(newParams(modelId, 100))) + .addNewAssignment(modelId, TrainedModelAssignment.Builder.empty(newParams(modelId, 100), null)) .build() ) .build() @@ -1357,7 +1370,7 @@ public void testDetectReasonToRebalanceModels_GivenMultipleMlJobsStopped() { .putCustom( TrainedModelAssignmentMetadata.NAME, TrainedModelAssignmentMetadata.Builder.empty() - .addNewAssignment(modelId, TrainedModelAssignment.Builder.empty(newParams(modelId, 100))) + .addNewAssignment(modelId, TrainedModelAssignment.Builder.empty(newParams(modelId, 100), null)) .build() ) .build() @@ -1419,7 +1432,7 @@ public void testDetectReasonToRebalanceModels_GivenMlJobsStarted() { .putCustom( TrainedModelAssignmentMetadata.NAME, TrainedModelAssignmentMetadata.Builder.empty() - .addNewAssignment(modelId, TrainedModelAssignment.Builder.empty(newParams(modelId, 100))) + .addNewAssignment(modelId, TrainedModelAssignment.Builder.empty(newParams(modelId, 100), null)) .build() ) .build() @@ -1434,7 +1447,7 @@ public void testDetectReasonToRebalanceModels_GivenMlJobsStarted() { .putCustom( TrainedModelAssignmentMetadata.NAME, TrainedModelAssignmentMetadata.Builder.empty() - .addNewAssignment(modelId, TrainedModelAssignment.Builder.empty(newParams(modelId, 100))) + .addNewAssignment(modelId, TrainedModelAssignment.Builder.empty(newParams(modelId, 100), null)) .build() ) .build() @@ -1459,7 +1472,7 @@ public void testAreAssignedNodesRemoved_GivenRemovedNodeThatIsRouted() { TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( modelId, - TrainedModelAssignment.Builder.empty(newParams(modelId, 10_000L)) + TrainedModelAssignment.Builder.empty(newParams(modelId, 10_000L), null) .addRoutingEntry(nodeId1, new RoutingInfo(1, 1, RoutingState.STARTED, "")) .addRoutingEntry(nodeId2, new RoutingInfo(1, 1, RoutingState.STARTED, "")) ) @@ -1491,7 +1504,7 @@ public void testAreAssignedNodesRemoved_GivenRemovedNodeThatIsNotRouted() { TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( modelId, - TrainedModelAssignment.Builder.empty(newParams(modelId, 10_000L)) + TrainedModelAssignment.Builder.empty(newParams(modelId, 10_000L), null) .addRoutingEntry(nodeId1, new RoutingInfo(1, 1, RoutingState.STARTED, "")) ) .build() @@ -1519,7 +1532,7 @@ public void testAreAssignedNodesRemoved_GivenShuttingDownNodeThatIsRouted() { TrainedModelAssignmentMetadata trainedModelAssignmentMetadata = TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( modelId, - TrainedModelAssignment.Builder.empty(newParams(modelId, 10_000L)) + TrainedModelAssignment.Builder.empty(newParams(modelId, 10_000L), null) .addRoutingEntry(nodeId1, new RoutingInfo(1, 1, RoutingState.STARTED, "")) .addRoutingEntry(nodeId2, new RoutingInfo(1, 1, RoutingState.STARTED, "")) ) @@ -1563,7 +1576,7 @@ public void testAreAssignedNodesRemoved_GivenShuttingDownNodeThatIsNotRouted() { TrainedModelAssignmentMetadata trainedModelAssignmentMetadata = TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( modelId, - TrainedModelAssignment.Builder.empty(newParams(modelId, 10_000L)) + TrainedModelAssignment.Builder.empty(newParams(modelId, 10_000L), null) .addRoutingEntry(nodeId2, new RoutingInfo(1, 1, RoutingState.STARTED, "")) ) .build(); @@ -1611,13 +1624,13 @@ public void testRemoveRoutingToUnassignableNodes_RemovesRouteForRemovedNodes() { TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( modelId1, - TrainedModelAssignment.Builder.empty(newParams(modelId1, 10_000L)) + TrainedModelAssignment.Builder.empty(newParams(modelId1, 10_000L), null) .addRoutingEntry(nodeId1, new RoutingInfo(1, 1, RoutingState.STARTED, "")) .addRoutingEntry(nodeId2, new RoutingInfo(1, 1, RoutingState.STARTED, "")) ) .addNewAssignment( modelId2, - TrainedModelAssignment.Builder.empty(newParams(modelId2, 10_000L)) + TrainedModelAssignment.Builder.empty(newParams(modelId2, 10_000L), null) .addRoutingEntry(nodeId1, new RoutingInfo(1, 1, RoutingState.STARTED, "")) .addRoutingEntry(nodeId2, new RoutingInfo(1, 1, RoutingState.STARTED, "")) ) @@ -1668,14 +1681,14 @@ public void testRemoveRoutingToUnassignableNodes_AddsAStoppingRouteForShuttingDo TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( modelId1, - TrainedModelAssignment.Builder.empty(newParams(modelId1, 10_000L)) + TrainedModelAssignment.Builder.empty(newParams(modelId1, 10_000L), null) .addRoutingEntry(nodeId1, new RoutingInfo(1, 1, RoutingState.STARTED, "")) .addRoutingEntry(nodeId2, new RoutingInfo(1, 1, RoutingState.STARTED, "")) .addRoutingEntry(nodeId3, new RoutingInfo(1, 1, RoutingState.STARTED, "")) ) .addNewAssignment( modelId2, - TrainedModelAssignment.Builder.empty(newParams(modelId2, 10_000L)) + TrainedModelAssignment.Builder.empty(newParams(modelId2, 10_000L), null) .addRoutingEntry(nodeId1, new RoutingInfo(1, 1, RoutingState.STARTED, "")) .addRoutingEntry(nodeId2, new RoutingInfo(1, 1, RoutingState.STARTED, "")) .addRoutingEntry(nodeId3, new RoutingInfo(1, 1, RoutingState.STARTED, "")) @@ -1728,14 +1741,14 @@ public void testRemoveRoutingToUnassignableNodes_IgnoresARouteThatIsStoppedForSh TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( modelId1, - TrainedModelAssignment.Builder.empty(newParams(modelId1, 10_000L)) + TrainedModelAssignment.Builder.empty(newParams(modelId1, 10_000L), null) .addRoutingEntry(nodeId1, new RoutingInfo(1, 1, RoutingState.STARTED, "")) .addRoutingEntry(nodeId2, new RoutingInfo(1, 1, RoutingState.STARTED, "")) .addRoutingEntry(nodeId3, new RoutingInfo(1, 1, RoutingState.STOPPED, "")) ) .addNewAssignment( modelId2, - TrainedModelAssignment.Builder.empty(newParams(modelId2, 10_000L)) + TrainedModelAssignment.Builder.empty(newParams(modelId2, 10_000L), null) .addRoutingEntry(nodeId1, new RoutingInfo(1, 1, RoutingState.STARTED, "")) .addRoutingEntry(nodeId2, new RoutingInfo(1, 1, RoutingState.STARTED, "")) .addRoutingEntry(nodeId3, new RoutingInfo(1, 1, RoutingState.STOPPED, "")) @@ -1789,12 +1802,12 @@ public void testSetShuttingDownNodeRoutesToStopping_GivenAnAssignmentRoutedToShu TrainedModelAssignmentMetadata currentMetadata = TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( availableNodeModelId, - TrainedModelAssignment.Builder.empty(taskParamsRunning) + TrainedModelAssignment.Builder.empty(taskParamsRunning, null) .addRoutingEntry(availableNode, new RoutingInfo(1, 1, RoutingState.STARTED, "")) ) .addNewAssignment( shuttingDownModelId, - TrainedModelAssignment.Builder.empty(taskParamsShuttingDown) + TrainedModelAssignment.Builder.empty(taskParamsShuttingDown, null) .addRoutingEntry(shuttingDownNodeId, new RoutingInfo(1, 1, RoutingState.STARTED, "")) ) .build(); @@ -1802,12 +1815,12 @@ public void testSetShuttingDownNodeRoutesToStopping_GivenAnAssignmentRoutedToShu TrainedModelAssignmentMetadata.Builder rebalanced = TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( availableNodeModelId, - TrainedModelAssignment.Builder.empty(taskParamsRunning) + TrainedModelAssignment.Builder.empty(taskParamsRunning, null) .addRoutingEntry(availableNode, new RoutingInfo(1, 1, RoutingState.STARTED, "")) ) .addNewAssignment( shuttingDownModelId, - TrainedModelAssignment.Builder.empty(taskParamsRunning) + TrainedModelAssignment.Builder.empty(taskParamsRunning, null) .addRoutingEntry(availableNode, new RoutingInfo(1, 1, RoutingState.STARTING, "")) ); @@ -1840,12 +1853,12 @@ public void testSetShuttingDownNodeRoutesToStopping_GivenAnAssignmentRoutedToShu TrainedModelAssignmentMetadata currentMetadata = TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( shuttingDownModelId, - TrainedModelAssignment.Builder.empty(taskParamsShuttingDown) + TrainedModelAssignment.Builder.empty(taskParamsShuttingDown, null) .addRoutingEntry(shuttingDownNodeId, new RoutingInfo(1, 1, RoutingState.STARTED, "")) ) .addNewAssignment( notShuttingDownModelId, - TrainedModelAssignment.Builder.empty(taskParamsNotShuttingDown) + TrainedModelAssignment.Builder.empty(taskParamsNotShuttingDown, null) .addRoutingEntry(availableNode, new RoutingInfo(1, 1, RoutingState.STARTED, "")) ) .build(); @@ -1853,12 +1866,12 @@ public void testSetShuttingDownNodeRoutesToStopping_GivenAnAssignmentRoutedToShu TrainedModelAssignmentMetadata.Builder rebalanced = TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( shuttingDownModelId, - TrainedModelAssignment.Builder.empty(taskParamsShuttingDown) + TrainedModelAssignment.Builder.empty(taskParamsShuttingDown, null) .addRoutingEntry(availableNode, new RoutingInfo(1, 1, RoutingState.STARTING, "")) ) .addNewAssignment( notShuttingDownModelId, - TrainedModelAssignment.Builder.empty(taskParamsNotShuttingDown) + TrainedModelAssignment.Builder.empty(taskParamsNotShuttingDown, null) .addRoutingEntry(availableNode, new RoutingInfo(1, 1, RoutingState.STARTED, "")) ); @@ -1897,7 +1910,7 @@ public void testSetShuttingDownNodeRoutesToStopping_GivenAnAssignmentRoutedToShu TrainedModelAssignmentMetadata currentMetadata = TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( modelId, - TrainedModelAssignment.Builder.empty(taskParamsShuttingDown) + TrainedModelAssignment.Builder.empty(taskParamsShuttingDown, null) .addRoutingEntry(disappearingNodeId, new RoutingInfo(1, 1, RoutingState.STARTED, "")) ) .build(); @@ -1905,7 +1918,7 @@ public void testSetShuttingDownNodeRoutesToStopping_GivenAnAssignmentRoutedToShu TrainedModelAssignmentMetadata.Builder rebalanced = TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( modelId, - TrainedModelAssignment.Builder.empty(taskParamsShuttingDown) + TrainedModelAssignment.Builder.empty(taskParamsShuttingDown, null) .addRoutingEntry(availableNode, new RoutingInfo(1, 1, RoutingState.STARTED, "")) ); @@ -1933,7 +1946,7 @@ public void testSetShuttingDownNodeRoutesToStopping_GivenAssignmentDoesNotExist_ TrainedModelAssignmentMetadata currentMetadata = TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( modelId, - TrainedModelAssignment.Builder.empty(taskParamsShuttingDown) + TrainedModelAssignment.Builder.empty(taskParamsShuttingDown, null) .addRoutingEntry(shuttingDownNodeId, new RoutingInfo(1, 1, RoutingState.STARTED, "")) ) .build(); @@ -2006,7 +2019,10 @@ public void testSetAllocationToStopping() { .putCustom( TrainedModelAssignmentMetadata.NAME, TrainedModelAssignmentMetadata.Builder.empty() - .addNewAssignment(modelId, TrainedModelAssignment.Builder.empty(newParams(modelId, randomNonNegativeLong()))) + .addNewAssignment( + modelId, + TrainedModelAssignment.Builder.empty(newParams(modelId, randomNonNegativeLong()), null) + ) .build() ) .build() diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentMetadataTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentMetadataTests.java index 6c5223eae4d99..dec85bff87d67 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentMetadataTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentMetadataTests.java @@ -64,7 +64,7 @@ public void testIsAssigned() { TrainedModelAssignmentMetadata metadata = TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( allocatedDeploymentId, - TrainedModelAssignment.Builder.empty(randomParams(allocatedDeploymentId, allocatedModelId)) + TrainedModelAssignment.Builder.empty(randomParams(allocatedDeploymentId, allocatedModelId), null) ) .build(); assertThat(metadata.isAssigned(allocatedDeploymentId), is(true)); @@ -78,7 +78,7 @@ public void testModelIsDeployed() { TrainedModelAssignmentMetadata metadata = TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( allocatedDeploymentId, - TrainedModelAssignment.Builder.empty(randomParams(allocatedDeploymentId, allocatedModelId)) + TrainedModelAssignment.Builder.empty(randomParams(allocatedDeploymentId, allocatedModelId), null) ) .build(); assertThat(metadata.modelIsDeployed(allocatedDeploymentId), is(false)); @@ -92,9 +92,9 @@ public void testGetDeploymentsUsingModel() { String deployment2 = "test_deployment_2"; String deployment3 = "test_deployment_3"; TrainedModelAssignmentMetadata metadata = TrainedModelAssignmentMetadata.Builder.empty() - .addNewAssignment(deployment1, TrainedModelAssignment.Builder.empty(randomParams(deployment1, modelId1))) - .addNewAssignment(deployment2, TrainedModelAssignment.Builder.empty(randomParams(deployment2, modelId1))) - .addNewAssignment(deployment3, TrainedModelAssignment.Builder.empty(randomParams(deployment3, "different_model"))) + .addNewAssignment(deployment1, TrainedModelAssignment.Builder.empty(randomParams(deployment1, modelId1), null)) + .addNewAssignment(deployment2, TrainedModelAssignment.Builder.empty(randomParams(deployment2, modelId1), null)) + .addNewAssignment(deployment3, TrainedModelAssignment.Builder.empty(randomParams(deployment3, "different_model"), null)) .build(); var assignments = metadata.getDeploymentsUsingModel(modelId1); assertThat(assignments, hasSize(2)); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeServiceTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeServiceTests.java index 2444134ce2920..9fbc2b43f1137 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeServiceTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentNodeServiceTests.java @@ -49,10 +49,12 @@ import java.util.List; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; import static org.elasticsearch.xpack.ml.MachineLearning.UTILITY_THREAD_POOL_NAME; import static org.elasticsearch.xpack.ml.inference.assignment.TrainedModelAssignmentClusterServiceTests.shutdownMetadata; import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.notNullValue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; @@ -91,19 +93,13 @@ public void setupObjects() { taskManager = new TaskManager(Settings.EMPTY, threadPool, Collections.emptySet()); deploymentManager = mock(DeploymentManager.class); doAnswer(invocationOnMock -> { - ActionListener listener = (ActionListener) invocationOnMock.getArguments()[1]; - listener.onResponse(invocationOnMock.getArguments()[0]); + ActionListener listener = invocationOnMock.getArgument(1); + listener.onResponse(invocationOnMock.getArgument(0)); return null; }).when(deploymentManager).startDeployment(any(), any()); doAnswer(invocationOnMock -> { - ActionListener listener = (ActionListener) invocationOnMock.getArguments()[1]; - listener.onResponse(null); - return null; - }).when(deploymentManager).stopAfterCompletingPendingWork(any()); - - doAnswer(invocationOnMock -> { - ActionListener listener = (ActionListener) invocationOnMock.getArguments()[1]; + ActionListener listener = invocationOnMock.getArgument(1); listener.onResponse(AcknowledgedResponse.TRUE); return null; }).when(trainedModelAssignmentService).updateModelAssignmentState(any(), any()); @@ -114,15 +110,33 @@ public void shutdown() throws InterruptedException { terminate(threadPool); } - public void testLoadQueuedModels_GivenNoQueuedModels() { - TrainedModelAssignmentNodeService trainedModelAssignmentNodeService = createService(); - + public void testLoadQueuedModels_GivenNoQueuedModels() throws InterruptedException { // When there are no queued models - trainedModelAssignmentNodeService.loadQueuedModels(); + loadQueuedModels(createService()); verify(deploymentManager, never()).startDeployment(any(), any()); } - public void testLoadQueuedModels() { + private void loadQueuedModels(TrainedModelAssignmentNodeService trainedModelAssignmentNodeService) throws InterruptedException { + loadQueuedModels(trainedModelAssignmentNodeService, false); + } + + private void loadQueuedModels(TrainedModelAssignmentNodeService trainedModelAssignmentNodeService, boolean expectedRunImmediately) + throws InterruptedException { + var latch = new CountDownLatch(1); + var actual = new AtomicReference(); // AtomicReference for nullable + trainedModelAssignmentNodeService.loadQueuedModels( + ActionListener.runAfter(ActionListener.wrap(actual::set, e -> {}), latch::countDown) + ); + assertTrue("Timed out waiting for loadQueuedModels to finish.", latch.await(10, TimeUnit.SECONDS)); + assertThat("Test failed to call the onResponse handler.", actual.get(), notNullValue()); + assertThat( + "We should rerun immediately if there are still model loading tasks to process.", + actual.get(), + equalTo(expectedRunImmediately) + ); + } + + public void testLoadQueuedModels() throws InterruptedException { TrainedModelAssignmentNodeService trainedModelAssignmentNodeService = createService(); String modelToLoad = "loading-model"; @@ -136,7 +150,8 @@ public void testLoadQueuedModels() { trainedModelAssignmentNodeService.prepareModelToLoad(newParams(deploymentId, modelToLoad)); trainedModelAssignmentNodeService.prepareModelToLoad(newParams(anotherDeployment, anotherModel)); - trainedModelAssignmentNodeService.loadQueuedModels(); + loadQueuedModels(trainedModelAssignmentNodeService, true); + loadQueuedModels(trainedModelAssignmentNodeService, false); ArgumentCaptor taskCapture = ArgumentCaptor.forClass(TrainedModelDeploymentTask.class); ArgumentCaptor requestCapture = ArgumentCaptor.forClass( @@ -157,11 +172,11 @@ public void testLoadQueuedModels() { // Since models are loaded, there shouldn't be any more loadings to occur trainedModelAssignmentNodeService.prepareModelToLoad(newParams(anotherDeployment, anotherModel)); - trainedModelAssignmentNodeService.loadQueuedModels(); + loadQueuedModels(trainedModelAssignmentNodeService); verifyNoMoreInteractions(deploymentManager, trainedModelAssignmentService); } - public void testLoadQueuedModelsWhenFailureIsRetried() { + public void testLoadQueuedModelsWhenFailureIsRetried() throws InterruptedException { String modelToLoad = "loading-model"; String failedModelToLoad = "failed-search-loading-model"; String deploymentId = "foo"; @@ -174,9 +189,9 @@ public void testLoadQueuedModelsWhenFailureIsRetried() { trainedModelAssignmentNodeService.prepareModelToLoad(newParams(deploymentId, modelToLoad)); trainedModelAssignmentNodeService.prepareModelToLoad(newParams(failedDeploymentId, failedModelToLoad)); - trainedModelAssignmentNodeService.loadQueuedModels(); - - trainedModelAssignmentNodeService.loadQueuedModels(); + loadQueuedModels(trainedModelAssignmentNodeService, true); + loadQueuedModels(trainedModelAssignmentNodeService, false); + loadQueuedModels(trainedModelAssignmentNodeService, false); ArgumentCaptor startTaskCapture = ArgumentCaptor.forClass(TrainedModelDeploymentTask.class); ArgumentCaptor requestCapture = ArgumentCaptor.forClass( @@ -199,7 +214,7 @@ public void testLoadQueuedModelsWhenFailureIsRetried() { verifyNoMoreInteractions(deploymentManager, trainedModelAssignmentService); } - public void testLoadQueuedModelsWhenStopped() { + public void testLoadQueuedModelsWhenStopped() throws InterruptedException { TrainedModelAssignmentNodeService trainedModelAssignmentNodeService = createService(); // When there are no queued models @@ -209,7 +224,12 @@ public void testLoadQueuedModelsWhenStopped() { trainedModelAssignmentNodeService.prepareModelToLoad(newParams(modelToLoad, modelToLoad)); trainedModelAssignmentNodeService.stop(); - trainedModelAssignmentNodeService.loadQueuedModels(); + var latch = new CountDownLatch(1); + trainedModelAssignmentNodeService.loadQueuedModels(ActionListener.running(latch::countDown)); + assertTrue( + "loadQueuedModels should immediately call the listener without forking to another thread.", + latch.await(0, TimeUnit.SECONDS) + ); verifyNoMoreInteractions(deploymentManager, trainedModelAssignmentService); } @@ -231,7 +251,8 @@ public void testLoadQueuedModelsWhenTaskIsStopped() throws Exception { trainedModelAssignmentNodeService.prepareModelToLoad(newParams(loadingDeploymentId, modelToLoad)); trainedModelAssignmentNodeService.prepareModelToLoad(newParams(stoppedLoadingDeploymentId, stoppedModelToLoad)); trainedModelAssignmentNodeService.getTask(stoppedLoadingDeploymentId).stop("testing", false, ActionListener.noop()); - trainedModelAssignmentNodeService.loadQueuedModels(); + loadQueuedModels(trainedModelAssignmentNodeService, true); + loadQueuedModels(trainedModelAssignmentNodeService, false); assertBusy(() -> { ArgumentCaptor stoppedTaskCapture = ArgumentCaptor.forClass(TrainedModelDeploymentTask.class); @@ -283,15 +304,8 @@ public void testLoadQueuedModelsWhenOneFails() throws InterruptedException { trainedModelAssignmentNodeService.prepareModelToLoad(newParams(loadingDeploymentId, modelToLoad)); trainedModelAssignmentNodeService.prepareModelToLoad(newParams(failedLoadingDeploymentId, failedModelToLoad)); - CountDownLatch latch = new CountDownLatch(1); - doAnswer(invocationOnMock -> { - latch.countDown(); - return null; - }).when(deploymentManager).stopDeployment(any()); - - trainedModelAssignmentNodeService.loadQueuedModels(); - - latch.await(5, TimeUnit.SECONDS); + loadQueuedModels(trainedModelAssignmentNodeService, true); + loadQueuedModels(trainedModelAssignmentNodeService, false); ArgumentCaptor startTaskCapture = ArgumentCaptor.forClass(TrainedModelDeploymentTask.class); ArgumentCaptor requestCapture = ArgumentCaptor.forClass( @@ -318,7 +332,7 @@ public void testLoadQueuedModelsWhenOneFails() throws InterruptedException { verifyNoMoreInteractions(deploymentManager, trainedModelAssignmentService); } - public void testClusterChangedWithResetMode() { + public void testClusterChangedWithResetMode() throws InterruptedException { final TrainedModelAssignmentNodeService trainedModelAssignmentNodeService = createService(); final DiscoveryNodes nodes = DiscoveryNodes.builder().localNodeId(NODE_ID).add(DiscoveryNodeUtils.create(NODE_ID, NODE_ID)).build(); String modelOne = "model-1"; @@ -339,17 +353,17 @@ public void testClusterChangedWithResetMode() { TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( modelOne, - TrainedModelAssignment.Builder.empty(newParams(deploymentOne, modelOne)) + TrainedModelAssignment.Builder.empty(newParams(deploymentOne, modelOne), null) .addRoutingEntry(NODE_ID, new RoutingInfo(1, 1, RoutingState.STARTING, "")) ) .addNewAssignment( modelTwo, - TrainedModelAssignment.Builder.empty(newParams(deploymentTwo, modelTwo)) + TrainedModelAssignment.Builder.empty(newParams(deploymentTwo, modelTwo), null) .addRoutingEntry(NODE_ID, new RoutingInfo(1, 1, RoutingState.STARTING, "")) ) .addNewAssignment( notUsedModel, - TrainedModelAssignment.Builder.empty(newParams(notUsedDeployment, notUsedModel)) + TrainedModelAssignment.Builder.empty(newParams(notUsedDeployment, notUsedModel), null) .addRoutingEntry("some-other-node", new RoutingInfo(1, 1, RoutingState.STARTING, "")) ) .build() @@ -362,7 +376,7 @@ public void testClusterChangedWithResetMode() { ); trainedModelAssignmentNodeService.clusterChanged(event); - trainedModelAssignmentNodeService.loadQueuedModels(); + loadQueuedModels(trainedModelAssignmentNodeService); verifyNoMoreInteractions(deploymentManager, trainedModelAssignmentService); } @@ -397,7 +411,7 @@ public void testClusterChanged_WhenAssigmentIsRoutedToShuttingDownNode_CallsStop TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( deploymentOne, - TrainedModelAssignment.Builder.empty(taskParams) + TrainedModelAssignment.Builder.empty(taskParams, null) .addRoutingEntry(NODE_ID, new RoutingInfo(1, 1, RoutingState.STOPPING, "")) ) .build() @@ -450,7 +464,7 @@ public void testClusterChanged_WhenAssigmentIsRoutedToShuttingDownNode_ButOtherA TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( deploymentOne, - TrainedModelAssignment.Builder.empty(taskParams) + TrainedModelAssignment.Builder.empty(taskParams, null) .addRoutingEntry(NODE_ID, new RoutingInfo(1, 1, RoutingState.STOPPING, "")) .addRoutingEntry(node2, new RoutingInfo(1, 1, RoutingState.STARTING, "")) ) @@ -480,7 +494,6 @@ public void testClusterChanged_WhenAssigmentIsRoutedToShuttingDownNodeButAlready String modelOne = "model-1"; String deploymentOne = "deployment-1"; - ArgumentCaptor stopParamsCapture = ArgumentCaptor.forClass(TrainedModelDeploymentTask.class); var taskParams = newParams(deploymentOne, modelOne); ClusterChangedEvent event = new ClusterChangedEvent( @@ -494,7 +507,7 @@ public void testClusterChanged_WhenAssigmentIsRoutedToShuttingDownNodeButAlready TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( deploymentOne, - TrainedModelAssignment.Builder.empty(taskParams) + TrainedModelAssignment.Builder.empty(taskParams, null) .addRoutingEntry(NODE_ID, new RoutingInfo(1, 1, RoutingState.STOPPING, "")) ) .build() @@ -535,7 +548,7 @@ public void testClusterChanged_WhenAssigmentIsRoutedToShuttingDownNodeWithStarti TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( deploymentOne, - TrainedModelAssignment.Builder.empty(taskParams) + TrainedModelAssignment.Builder.empty(taskParams, null) .addRoutingEntry(NODE_ID, new RoutingInfo(1, 1, RoutingState.STARTING, "")) ) .build() @@ -558,7 +571,7 @@ public void testClusterChanged_WhenAssigmentIsRoutedToShuttingDownNodeWithStarti verifyNoMoreInteractions(deploymentManager, trainedModelAssignmentService); } - public void testClusterChanged_WhenAssigmentIsStopping_DoesNotAddModelToBeLoaded() { + public void testClusterChanged_WhenAssigmentIsStopping_DoesNotAddModelToBeLoaded() throws InterruptedException { final TrainedModelAssignmentNodeService trainedModelAssignmentNodeService = createService(); final DiscoveryNodes nodes = DiscoveryNodes.builder().localNodeId(NODE_ID).add(DiscoveryNodeUtils.create(NODE_ID, NODE_ID)).build(); String modelOne = "model-1"; @@ -577,7 +590,7 @@ public void testClusterChanged_WhenAssigmentIsStopping_DoesNotAddModelToBeLoaded TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( deploymentOne, - TrainedModelAssignment.Builder.empty(taskParams) + TrainedModelAssignment.Builder.empty(taskParams, null) .addRoutingEntry(NODE_ID, new RoutingInfo(1, 1, RoutingState.STARTING, "")) .stopAssignment("stopping") ) @@ -592,7 +605,7 @@ public void testClusterChanged_WhenAssigmentIsStopping_DoesNotAddModelToBeLoaded // trainedModelAssignmentNodeService.prepareModelToLoad(taskParams); trainedModelAssignmentNodeService.clusterChanged(event); - trainedModelAssignmentNodeService.loadQueuedModels(); + loadQueuedModels(trainedModelAssignmentNodeService); verify(deploymentManager, never()).startDeployment(any(), any()); verifyNoMoreInteractions(deploymentManager, trainedModelAssignmentService); @@ -626,12 +639,12 @@ public void testClusterChanged() throws Exception { TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( deploymentOne, - TrainedModelAssignment.Builder.empty(newParams(deploymentOne, modelOne)) + TrainedModelAssignment.Builder.empty(newParams(deploymentOne, modelOne), null) .addRoutingEntry(NODE_ID, new RoutingInfo(1, 1, RoutingState.STARTING, "")) ) .addNewAssignment( deploymentTwo, - TrainedModelAssignment.Builder.empty(newParams(deploymentTwo, modelTwo)) + TrainedModelAssignment.Builder.empty(newParams(deploymentTwo, modelTwo), null) .addRoutingEntry(NODE_ID, new RoutingInfo(1, 1, RoutingState.STARTING, "")) .updateExistingRoutingEntry( NODE_ID, @@ -645,7 +658,7 @@ public void testClusterChanged() throws Exception { ) .addNewAssignment( previouslyUsedDeployment, - TrainedModelAssignment.Builder.empty(newParams(previouslyUsedDeployment, previouslyUsedModel)) + TrainedModelAssignment.Builder.empty(newParams(previouslyUsedDeployment, previouslyUsedModel), null) .addRoutingEntry(NODE_ID, new RoutingInfo(1, 1, RoutingState.STARTING, "")) .updateExistingRoutingEntry( NODE_ID, @@ -659,7 +672,7 @@ public void testClusterChanged() throws Exception { ) .addNewAssignment( notUsedDeployment, - TrainedModelAssignment.Builder.empty(newParams(notUsedDeployment, notUsedModel)) + TrainedModelAssignment.Builder.empty(newParams(notUsedDeployment, notUsedModel), null) .addRoutingEntry("some-other-node", new RoutingInfo(1, 1, RoutingState.STARTING, "")) ) .build() @@ -684,17 +697,17 @@ public void testClusterChanged() throws Exception { TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( deploymentOne, - TrainedModelAssignment.Builder.empty(newParams(deploymentOne, modelOne)) + TrainedModelAssignment.Builder.empty(newParams(deploymentOne, modelOne), null) .addRoutingEntry(NODE_ID, new RoutingInfo(1, 1, RoutingState.STARTING, "")) ) .addNewAssignment( deploymentTwo, - TrainedModelAssignment.Builder.empty(newParams(deploymentTwo, modelTwo)) + TrainedModelAssignment.Builder.empty(newParams(deploymentTwo, modelTwo), null) .addRoutingEntry("some-other-node", new RoutingInfo(1, 1, RoutingState.STARTING, "")) ) .addNewAssignment( notUsedDeployment, - TrainedModelAssignment.Builder.empty(newParams(notUsedDeployment, notUsedModel)) + TrainedModelAssignment.Builder.empty(newParams(notUsedDeployment, notUsedModel), null) .addRoutingEntry("some-other-node", new RoutingInfo(1, 1, RoutingState.STARTING, "")) ) .build() @@ -706,7 +719,8 @@ public void testClusterChanged() throws Exception { ); trainedModelAssignmentNodeService.clusterChanged(event); - trainedModelAssignmentNodeService.loadQueuedModels(); + loadQueuedModels(trainedModelAssignmentNodeService, true); + loadQueuedModels(trainedModelAssignmentNodeService, false); assertBusy(() -> { ArgumentCaptor stoppedTaskCapture = ArgumentCaptor.forClass(TrainedModelDeploymentTask.class); @@ -737,7 +751,7 @@ public void testClusterChanged() throws Exception { TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( deploymentOne, - TrainedModelAssignment.Builder.empty(newParams(deploymentOne, modelOne)) + TrainedModelAssignment.Builder.empty(newParams(deploymentOne, modelOne), null) .addRoutingEntry(NODE_ID, new RoutingInfo(1, 1, RoutingState.STARTING, "")) ) .build() @@ -749,7 +763,7 @@ public void testClusterChanged() throws Exception { ); trainedModelAssignmentNodeService.clusterChanged(event); - trainedModelAssignmentNodeService.loadQueuedModels(); + loadQueuedModels(trainedModelAssignmentNodeService); verifyNoMoreInteractions(deploymentManager, trainedModelAssignmentService); } @@ -764,7 +778,8 @@ public void testClusterChanged_GivenAllStartedAssignments_AndNonMatchingTargetAl givenAssignmentsInClusterStateForModels(List.of(deploymentOne, deploymentTwo), List.of(modelOne, modelTwo)); trainedModelAssignmentNodeService.prepareModelToLoad(newParams(deploymentOne, modelOne)); trainedModelAssignmentNodeService.prepareModelToLoad(newParams(deploymentTwo, modelTwo)); - trainedModelAssignmentNodeService.loadQueuedModels(); + loadQueuedModels(trainedModelAssignmentNodeService, true); + loadQueuedModels(trainedModelAssignmentNodeService, false); ClusterChangedEvent event = new ClusterChangedEvent( "shouldUpdateAllocations", @@ -778,12 +793,12 @@ public void testClusterChanged_GivenAllStartedAssignments_AndNonMatchingTargetAl TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( deploymentOne, - TrainedModelAssignment.Builder.empty(newParams(deploymentOne, modelOne)) + TrainedModelAssignment.Builder.empty(newParams(deploymentOne, modelOne), null) .addRoutingEntry(NODE_ID, new RoutingInfo(1, 3, RoutingState.STARTED, "")) ) .addNewAssignment( deploymentTwo, - TrainedModelAssignment.Builder.empty(newParams(deploymentTwo, modelTwo)) + TrainedModelAssignment.Builder.empty(newParams(deploymentTwo, modelTwo), null) .addRoutingEntry(NODE_ID, new RoutingInfo(2, 1, RoutingState.STARTED, "")) ) .build() @@ -830,7 +845,7 @@ private void givenAssignmentsInClusterStateForModels(List deploymentIds, for (int i = 0; i < modelIds.size(); i++) { builder.addNewAssignment( deploymentIds.get(i), - TrainedModelAssignment.Builder.empty(newParams(deploymentIds.get(i), modelIds.get(i))) + TrainedModelAssignment.Builder.empty(newParams(deploymentIds.get(i), modelIds.get(i)), null) .addRoutingEntry("test-node", new RoutingInfo(1, 1, RoutingState.STARTING, "")) ); } diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentRebalancerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentRebalancerTests.java index 53b737b38c284..65a974e04045e 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentRebalancerTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/TrainedModelAssignmentRebalancerTests.java @@ -12,6 +12,7 @@ import org.elasticsearch.cluster.node.DiscoveryNodeUtils; import org.elasticsearch.common.unit.ByteSizeValue; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.ml.action.CreateTrainedModelAssignmentAction; import org.elasticsearch.xpack.core.ml.action.StartTrainedModelDeploymentAction; import org.elasticsearch.xpack.core.ml.inference.assignment.AssignmentState; import org.elasticsearch.xpack.core.ml.inference.assignment.Priority; @@ -61,11 +62,12 @@ public void testRebalance_GivenAllAssignmentsAreSatisfied_ShouldMakeNoChanges() TrainedModelAssignmentMetadata currentMetadata = TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( deploymentId1, - TrainedModelAssignment.Builder.empty(taskParams1).addRoutingEntry("node-1", new RoutingInfo(1, 1, RoutingState.STARTED, "")) + TrainedModelAssignment.Builder.empty(taskParams1, null) + .addRoutingEntry("node-1", new RoutingInfo(1, 1, RoutingState.STARTED, "")) ) .addNewAssignment( deploymentId2, - TrainedModelAssignment.Builder.empty(taskParams2) + TrainedModelAssignment.Builder.empty(taskParams2, null) .addRoutingEntry("node-1", new RoutingInfo(1, 1, RoutingState.STARTED, "")) .addRoutingEntry("node-2", new RoutingInfo(3, 3, RoutingState.STARTED, "")) ) @@ -101,11 +103,12 @@ public void testRebalance_GivenAllAssignmentsAreSatisfied_GivenOutdatedRoutingEn TrainedModelAssignmentMetadata currentMetadata = TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( deploymentId1, - TrainedModelAssignment.Builder.empty(taskParams1).addRoutingEntry("node-1", new RoutingInfo(0, 0, RoutingState.STARTED, "")) + TrainedModelAssignment.Builder.empty(taskParams1, null) + .addRoutingEntry("node-1", new RoutingInfo(0, 0, RoutingState.STARTED, "")) ) .addNewAssignment( deploymentId2, - TrainedModelAssignment.Builder.empty(taskParams2) + TrainedModelAssignment.Builder.empty(taskParams2, null) .addRoutingEntry("node-1", new RoutingInfo(1, 1, RoutingState.STARTED, "")) .addRoutingEntry("node-2", new RoutingInfo(3, 3, RoutingState.STARTED, "")) ) @@ -140,11 +143,18 @@ public void testRebalance_GivenModelToAddAlreadyExists() { String modelId = "model-to-add"; StartTrainedModelDeploymentAction.TaskParams taskParams = normalPriorityParams(modelId, modelId, 1024L, 1, 1); TrainedModelAssignmentMetadata currentMetadata = TrainedModelAssignmentMetadata.Builder.empty() - .addNewAssignment(modelId, TrainedModelAssignment.Builder.empty(taskParams)) + .addNewAssignment(modelId, TrainedModelAssignment.Builder.empty(taskParams, null)) .build(); expectThrows( ResourceAlreadyExistsException.class, - () -> new TrainedModelAssignmentRebalancer(currentMetadata, Map.of(), Map.of(), Optional.of(taskParams), 1, false).rebalance() + () -> new TrainedModelAssignmentRebalancer( + currentMetadata, + Map.of(), + Map.of(), + Optional.of(new CreateTrainedModelAssignmentAction.Request(taskParams, null)), + 1, + false + ).rebalance() ); } @@ -157,7 +167,7 @@ public void testRebalance_GivenFirstModelToAdd_NoMLNodes() throws Exception { currentMetadata, Map.of(), Map.of(), - Optional.of(taskParams), + Optional.of(new CreateTrainedModelAssignmentAction.Request(taskParams, null)), 1, false ).rebalance().build(); @@ -185,7 +195,7 @@ public void testRebalance_GivenFirstModelToAdd_NotEnoughProcessors() throws Exce currentMetadata, nodeLoads, Map.of(List.of(), List.of(node)), - Optional.of(taskParams), + Optional.of(new CreateTrainedModelAssignmentAction.Request(taskParams, null)), 1, false ).rebalance().build(); @@ -222,7 +232,7 @@ public void testRebalance_GivenFirstModelToAdd_NotEnoughMemory() throws Exceptio currentMetadata, nodeLoads, Map.of(), - Optional.of(taskParams), + Optional.of(new CreateTrainedModelAssignmentAction.Request(taskParams, null)), 1, false ).rebalance().build(); @@ -259,7 +269,7 @@ public void testRebalance_GivenFirstModelToAdd_ErrorDetectingNodeLoad() throws E currentMetadata, nodeLoads, Map.of(), - Optional.of(taskParams), + Optional.of(new CreateTrainedModelAssignmentAction.Request(taskParams, null)), 1, false ).rebalance().build(); @@ -296,7 +306,7 @@ public void testRebalance_GivenProblemsOnMultipleNodes() throws Exception { currentMetadata, nodeLoads, Map.of(List.of(), List.of(node1, node2)), - Optional.of(taskParams), + Optional.of(new CreateTrainedModelAssignmentAction.Request(taskParams, null)), 1, false ).rebalance().build(); @@ -330,7 +340,7 @@ public void testRebalance_GivenFirstModelToAdd_FitsFully() throws Exception { currentMetadata, nodeLoads, Map.of(List.of(), List.of(node1)), - Optional.of(taskParams), + Optional.of(new CreateTrainedModelAssignmentAction.Request(taskParams, null)), 1, false ).rebalance().build(); @@ -357,7 +367,7 @@ public void testRebalance_GivenModelToAdd_AndPreviousAssignments_AndTwoNodes_All TrainedModelAssignmentMetadata currentMetadata = TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( previousDeploymentId, - TrainedModelAssignment.Builder.empty(normalPriorityParams(previousDeploymentId, previousDeploymentId, 1024L, 3, 2)) + TrainedModelAssignment.Builder.empty(normalPriorityParams(previousDeploymentId, previousDeploymentId, 1024L, 3, 2), null) .addRoutingEntry("node-1", new RoutingInfo(2, 2, RoutingState.STARTED, "")) .addRoutingEntry("node-2", new RoutingInfo(1, 1, RoutingState.STARTED, "")) ) @@ -370,7 +380,7 @@ public void testRebalance_GivenModelToAdd_AndPreviousAssignments_AndTwoNodes_All currentMetadata, nodeLoads, Map.of(List.of(), List.of(node1, node2)), - Optional.of(taskParams), + Optional.of(new CreateTrainedModelAssignmentAction.Request(taskParams, null)), 1, false ).rebalance().build(); @@ -416,13 +426,13 @@ public void testRebalance_GivenPreviousAssignments_AndNewNode() throws Exception TrainedModelAssignmentMetadata currentMetadata = TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( previousDeployment1Id, - TrainedModelAssignment.Builder.empty(normalPriorityParams(previousDeployment1Id, 1024L, 3, 2)) + TrainedModelAssignment.Builder.empty(normalPriorityParams(previousDeployment1Id, 1024L, 3, 2), null) .addRoutingEntry("node-1", new RoutingInfo(2, 2, RoutingState.STARTED, "")) .addRoutingEntry("node-2", new RoutingInfo(1, 1, RoutingState.STARTED, "")) ) .addNewAssignment( previousDeployment2Id, - TrainedModelAssignment.Builder.empty(normalPriorityParams(previousDeployment2Id, 1024L, 4, 1)) + TrainedModelAssignment.Builder.empty(normalPriorityParams(previousDeployment2Id, 1024L, 4, 1), null) .addRoutingEntry("node-2", new RoutingInfo(1, 1, RoutingState.STARTED, "")) ) .build(); @@ -483,13 +493,13 @@ public void testRebalance_GivenPreviousAssignments_AndRemovedNode_AndRemainingNo TrainedModelAssignmentMetadata currentMetadata = TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( previousDeployment1Id, - TrainedModelAssignment.Builder.empty(normalPriorityParams(previousDeployment1Id, 1024L, 3, 2)) + TrainedModelAssignment.Builder.empty(normalPriorityParams(previousDeployment1Id, 1024L, 3, 2), null) .addRoutingEntry("node-1", new RoutingInfo(2, 2, RoutingState.STARTED, "")) .addRoutingEntry("node-2", new RoutingInfo(1, 1, RoutingState.STARTED, "")) ) .addNewAssignment( previousDeployment2Id, - TrainedModelAssignment.Builder.empty(normalPriorityParams(previousDeployment2Id, 1024L, 4, 1)) + TrainedModelAssignment.Builder.empty(normalPriorityParams(previousDeployment2Id, 1024L, 4, 1), null) .addRoutingEntry("node-2", new RoutingInfo(1, 1, RoutingState.STARTED, "")) ) .build(); @@ -554,13 +564,13 @@ public void testRebalance_GivenPreviousAssignments_AndRemovedNode_AndRemainingNo TrainedModelAssignmentMetadata currentMetadata = TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( previousDeployment1Id, - TrainedModelAssignment.Builder.empty(normalPriorityParams(previousDeployment1Id, 1024L, 3, 2)) + TrainedModelAssignment.Builder.empty(normalPriorityParams(previousDeployment1Id, 1024L, 3, 2), null) .addRoutingEntry("node-1", new RoutingInfo(2, 2, RoutingState.STARTED, "")) .addRoutingEntry("node-2", new RoutingInfo(1, 1, RoutingState.STARTED, "")) ) .addNewAssignment( previousDeployment2Id, - TrainedModelAssignment.Builder.empty(normalPriorityParams(previousDeployment2Id, 1024L, 1, 1)) + TrainedModelAssignment.Builder.empty(normalPriorityParams(previousDeployment2Id, 1024L, 1, 1), null) .addRoutingEntry("node-2", new RoutingInfo(1, 1, RoutingState.STARTED, "")) ) .build(); @@ -610,7 +620,7 @@ public void testRebalance_GivenFailedAssignment_RestartsAssignment() throws Exce TrainedModelAssignmentMetadata currentMetadata = TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( modelId, - TrainedModelAssignment.Builder.empty(normalPriorityParams(modelId, 1024L, 1, 1)) + TrainedModelAssignment.Builder.empty(normalPriorityParams(modelId, 1024L, 1, 1), null) .addRoutingEntry("node-1", new RoutingInfo(1, 1, RoutingState.FAILED, "some error")) ) .build(); @@ -656,7 +666,7 @@ public void testRebalance_GivenLowPriorityModelToAdd_OnlyModel_NotEnoughMemory() currentMetadata, nodeLoads, Map.of(), - Optional.of(taskParams), + Optional.of(new CreateTrainedModelAssignmentAction.Request(taskParams, null)), 1, false ).rebalance().build(); @@ -693,7 +703,7 @@ public void testRebalance_GivenLowPriorityModelToAdd_NotEnoughMemoryNorProcessor TrainedModelAssignmentMetadata currentMetadata = TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( deployment2, - TrainedModelAssignment.Builder.empty(taskParams2) + TrainedModelAssignment.Builder.empty(taskParams2, null) .addRoutingEntry("node-1", new RoutingInfo(1, 1, RoutingState.STARTED, "")) .addRoutingEntry("node-2", new RoutingInfo(1, 1, RoutingState.STARTED, "")) ) @@ -703,7 +713,7 @@ public void testRebalance_GivenLowPriorityModelToAdd_NotEnoughMemoryNorProcessor currentMetadata, nodeLoads, Map.of(List.of("zone-1"), List.of(node1), List.of("zone-2"), List.of(node2)), - Optional.of(taskParams1), + Optional.of(new CreateTrainedModelAssignmentAction.Request(taskParams1, null)), 1, false ).rebalance().build(); @@ -735,8 +745,8 @@ public void testRebalance_GivenMixedPriorityModels_NotEnoughMemoryForLowPriority String modelId2 = "model-2"; StartTrainedModelDeploymentAction.TaskParams taskParams2 = normalPriorityParams(modelId2, ByteSizeValue.ofMb(300).getBytes(), 1, 1); TrainedModelAssignmentMetadata currentMetadata = TrainedModelAssignmentMetadata.Builder.empty() - .addNewAssignment(modelId1, TrainedModelAssignment.Builder.empty(taskParams1)) - .addNewAssignment(modelId2, TrainedModelAssignment.Builder.empty(taskParams2)) + .addNewAssignment(modelId1, TrainedModelAssignment.Builder.empty(taskParams1, null)) + .addNewAssignment(modelId2, TrainedModelAssignment.Builder.empty(taskParams2, null)) .build(); TrainedModelAssignmentMetadata result = new TrainedModelAssignmentRebalancer( @@ -786,10 +796,11 @@ public void testRebalance_GivenMixedPriorityModels_TwoZones_EachNodeCanHoldOneMo String modelId2 = "model-2"; StartTrainedModelDeploymentAction.TaskParams taskParams2 = normalPriorityParams(modelId2, ByteSizeValue.ofMb(300).getBytes(), 1, 1); TrainedModelAssignmentMetadata currentMetadata = TrainedModelAssignmentMetadata.Builder.empty() - .addNewAssignment(modelId1, TrainedModelAssignment.Builder.empty(taskParams1)) + .addNewAssignment(modelId1, TrainedModelAssignment.Builder.empty(taskParams1, null)) .addNewAssignment( modelId2, - TrainedModelAssignment.Builder.empty(taskParams2).addRoutingEntry("node-1", new RoutingInfo(1, 1, RoutingState.STARTED, "")) + TrainedModelAssignment.Builder.empty(taskParams2, null) + .addRoutingEntry("node-1", new RoutingInfo(1, 1, RoutingState.STARTED, "")) ) .build(); @@ -844,8 +855,8 @@ public void testRebalance_GivenModelUsingAllCpu_FittingLowPriorityModelCanStart( String modelId2 = "model-2"; StartTrainedModelDeploymentAction.TaskParams taskParams2 = normalPriorityParams(modelId2, ByteSizeValue.ofMb(300).getBytes(), 1, 1); TrainedModelAssignmentMetadata currentMetadata = TrainedModelAssignmentMetadata.Builder.empty() - .addNewAssignment(modelId1, TrainedModelAssignment.Builder.empty(taskParams1)) - .addNewAssignment(modelId2, TrainedModelAssignment.Builder.empty(taskParams2)) + .addNewAssignment(modelId1, TrainedModelAssignment.Builder.empty(taskParams1, null)) + .addNewAssignment(modelId2, TrainedModelAssignment.Builder.empty(taskParams2, null)) .build(); TrainedModelAssignmentMetadata result = new TrainedModelAssignmentRebalancer( @@ -895,8 +906,8 @@ public void testRebalance_GivenMultipleLowPriorityModels_AndMultipleNodes() thro String modelId2 = "model-2"; StartTrainedModelDeploymentAction.TaskParams taskParams2 = lowPriorityParams(modelId2, ByteSizeValue.ofMb(100).getBytes()); TrainedModelAssignmentMetadata currentMetadata = TrainedModelAssignmentMetadata.Builder.empty() - .addNewAssignment(modelId1, TrainedModelAssignment.Builder.empty(taskParams1)) - .addNewAssignment(modelId2, TrainedModelAssignment.Builder.empty(taskParams2)) + .addNewAssignment(modelId1, TrainedModelAssignment.Builder.empty(taskParams1, null)) + .addNewAssignment(modelId2, TrainedModelAssignment.Builder.empty(taskParams2, null)) .build(); TrainedModelAssignmentMetadata result = new TrainedModelAssignmentRebalancer( @@ -946,7 +957,8 @@ public void testRebalance_GivenNormalPriorityModelToLoad_EvictsLowPriorityModel( TrainedModelAssignmentMetadata currentMetadata = TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( modelId1, - TrainedModelAssignment.Builder.empty(taskParams1).addRoutingEntry("node-1", new RoutingInfo(1, 1, RoutingState.STARTED, "")) + TrainedModelAssignment.Builder.empty(taskParams1, null) + .addRoutingEntry("node-1", new RoutingInfo(1, 1, RoutingState.STARTED, "")) ) .build(); @@ -954,7 +966,7 @@ public void testRebalance_GivenNormalPriorityModelToLoad_EvictsLowPriorityModel( currentMetadata, nodeLoads, Map.of(List.of(), List.of(node1)), - Optional.of(taskParams2), + Optional.of(new CreateTrainedModelAssignmentAction.Request(taskParams2, null)), 1, false ).rebalance().build(); @@ -999,7 +1011,8 @@ public void testRebalance_GivenNormalPriorityModelToLoad_AndLowPriorityModelCanS TrainedModelAssignmentMetadata currentMetadata = TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( modelId1, - TrainedModelAssignment.Builder.empty(taskParams1).addRoutingEntry("node-1", new RoutingInfo(1, 1, RoutingState.STARTED, "")) + TrainedModelAssignment.Builder.empty(taskParams1, null) + .addRoutingEntry("node-1", new RoutingInfo(1, 1, RoutingState.STARTED, "")) ) .build(); @@ -1007,7 +1020,7 @@ public void testRebalance_GivenNormalPriorityModelToLoad_AndLowPriorityModelCanS currentMetadata, nodeLoads, Map.of(List.of(), List.of(node1, node2)), - Optional.of(taskParams2), + Optional.of(new CreateTrainedModelAssignmentAction.Request(taskParams2, null)), 1, false ).rebalance().build(); @@ -1052,7 +1065,8 @@ public void testRebalance_GivenNormalPriorityModelToLoad_AndLowPriorityModelMust TrainedModelAssignmentMetadata currentMetadata = TrainedModelAssignmentMetadata.Builder.empty() .addNewAssignment( modelId1, - TrainedModelAssignment.Builder.empty(taskParams1).addRoutingEntry("node-1", new RoutingInfo(1, 1, RoutingState.STARTED, "")) + TrainedModelAssignment.Builder.empty(taskParams1, null) + .addRoutingEntry("node-1", new RoutingInfo(1, 1, RoutingState.STARTED, "")) ) .build(); @@ -1060,7 +1074,7 @@ public void testRebalance_GivenNormalPriorityModelToLoad_AndLowPriorityModelMust currentMetadata, nodeLoads, Map.of(List.of(), List.of(node1, node2)), - Optional.of(taskParams2), + Optional.of(new CreateTrainedModelAssignmentAction.Request(taskParams2, null)), 1, false ).rebalance().build(); @@ -1107,7 +1121,7 @@ public void testRebalance_GivenFirstModelToAdd_GivenScalingProcessorSetting() { currentMetadata, nodeLoads, Map.of(List.of(), List.of(node)), - Optional.of(taskParams), + Optional.of(new CreateTrainedModelAssignmentAction.Request(taskParams, null)), 2, false ).rebalance().build(); @@ -1130,7 +1144,7 @@ public void testRebalance_GivenFirstModelToAdd_GivenScalingProcessorSetting() { currentMetadata, nodeLoads, Map.of(List.of(), List.of(node)), - Optional.of(taskParams), + Optional.of(new CreateTrainedModelAssignmentAction.Request(taskParams, null)), 1, false ).rebalance().build(); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AllocationReducerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AllocationReducerTests.java index 85fc83f775670..603eda65fbd51 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AllocationReducerTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AllocationReducerTests.java @@ -181,7 +181,8 @@ private static TrainedModelAssignment createAssignment( Priority.NORMAL, randomNonNegativeLong(), randomNonNegativeLong() - ) + ), + null ); allocationsByNode.entrySet() .stream() diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlanTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlanTests.java index cbbb38f1d1ddd..d84c04f0c41f1 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlanTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlanTests.java @@ -25,14 +25,14 @@ public class AssignmentPlanTests extends ESTestCase { public void testBuilderCtor_GivenDuplicateNode() { Node n = new Node("n_1", 100, 4); - AssignmentPlan.Deployment m = new AssignmentPlan.Deployment("m_1", 40, 1, 2, Map.of(), 0, 0, 0); + AssignmentPlan.Deployment m = new AssignmentPlan.Deployment("m_1", 40, 1, 2, Map.of(), 0, null, 0, 0); expectThrows(IllegalArgumentException.class, () -> AssignmentPlan.builder(List.of(n, n), List.of(m))); } public void testBuilderCtor_GivenDuplicateModel() { Node n = new Node("n_1", 100, 4); - Deployment m = new AssignmentPlan.Deployment("m_1", 40, 1, 2, Map.of(), 0, 0, 0); + Deployment m = new AssignmentPlan.Deployment("m_1", 40, 1, 2, Map.of(), 0, null, 0, 0); expectThrows(IllegalArgumentException.class, () -> AssignmentPlan.builder(List.of(n), List.of(m, m))); } @@ -41,7 +41,17 @@ public void testAssignModelToNode_GivenNoPreviousAssignment() { Node n = new Node("n_1", ByteSizeValue.ofMb(350).getBytes(), 4); { // old memory format - AssignmentPlan.Deployment m = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(40).getBytes(), 1, 2, Map.of(), 0, 0, 0); + AssignmentPlan.Deployment m = new AssignmentPlan.Deployment( + "m_1", + ByteSizeValue.ofMb(40).getBytes(), + 1, + 2, + Map.of(), + 0, + null, + 0, + 0 + ); AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m)); @@ -71,6 +81,7 @@ public void testAssignModelToNode_GivenNoPreviousAssignment() { 2, Map.of(), 0, + null, ByteSizeValue.ofMb(300).getBytes(), ByteSizeValue.ofMb(30).getBytes() ); @@ -107,6 +118,7 @@ public void testAssignModelToNode_GivenNewPlanSatisfiesCurrentAssignment() { 2, Map.of("n_1", 1), 0, + null, 0, 0 ); @@ -134,6 +146,7 @@ public void testAssignModelToNode_GivenNewPlanSatisfiesCurrentAssignment() { 2, Map.of("n_1", 1), 0, + null, ByteSizeValue.ofMb(300).getBytes(), ByteSizeValue.ofMb(25).getBytes() ); @@ -160,7 +173,7 @@ public void testAssignModelToNode_GivenNewPlanDoesNotSatisfyCurrentAssignment() Node n = new Node("n_1", ByteSizeValue.ofMb(300).getBytes(), 4); { // old memory format - Deployment m = new Deployment("m_1", ByteSizeValue.ofMb(30).getBytes(), 2, 2, Map.of("n_1", 2), 0, 0, 0); + Deployment m = new Deployment("m_1", ByteSizeValue.ofMb(30).getBytes(), 2, 2, Map.of("n_1", 2), 0, null, 0, 0); AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m)); @@ -186,6 +199,7 @@ public void testAssignModelToNode_GivenNewPlanDoesNotSatisfyCurrentAssignment() 2, Map.of("n_1", 2), 0, + null, ByteSizeValue.ofMb(250).getBytes(), ByteSizeValue.ofMb(25).getBytes() ); @@ -209,7 +223,7 @@ public void testAssignModelToNode_GivenNewPlanDoesNotSatisfyCurrentAssignment() public void testAssignModelToNode_GivenPreviouslyUnassignedModelDoesNotFit() { Node n = new Node("n_1", ByteSizeValue.ofMb(340 - 1).getBytes(), 4); - Deployment m = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(50).getBytes(), 2, 2, Map.of(), 0, 0, 0); + Deployment m = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(50).getBytes(), 2, 2, Map.of(), 0, null, 0, 0); AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m)); Exception e = expectThrows(IllegalArgumentException.class, () -> builder.assignModelToNode(m, n, 1)); @@ -227,6 +241,7 @@ public void testAssignModelToNode_GivenPreviouslyAssignedModelDoesNotFit() { 2, Map.of("n_1", 1), 0, + null, 0, 0 ); @@ -249,6 +264,7 @@ public void testAssignModelToNode_GivenPreviouslyAssignedModelDoesNotFit() { 2, Map.of("n_1", 1), 0, + null, ByteSizeValue.ofMb(300).getBytes(), ByteSizeValue.ofMb(5).getBytes() ); @@ -266,7 +282,7 @@ public void testAssignModelToNode_GivenPreviouslyAssignedModelDoesNotFit() { public void testAssignModelToNode_GivenNotEnoughCores_AndSingleThreadPerAllocation() { Node n = new Node("n_1", ByteSizeValue.ofMb(500).getBytes(), 4); - Deployment m = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(100).getBytes(), 5, 1, Map.of(), 0, 0, 0); + Deployment m = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(100).getBytes(), 5, 1, Map.of(), 0, null, 0, 0); AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m)); Exception e = expectThrows(IllegalArgumentException.class, () -> builder.assignModelToNode(m, n, 5)); @@ -279,7 +295,17 @@ public void testAssignModelToNode_GivenNotEnoughCores_AndSingleThreadPerAllocati public void testAssignModelToNode_GivenNotEnoughCores_AndMultipleThreadsPerAllocation() { Node n = new Node("n_1", ByteSizeValue.ofMb(500).getBytes(), 5); - AssignmentPlan.Deployment m = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(100).getBytes(), 3, 2, Map.of(), 0, 0, 0); + AssignmentPlan.Deployment m = new AssignmentPlan.Deployment( + "m_1", + ByteSizeValue.ofMb(100).getBytes(), + 3, + 2, + Map.of(), + 0, + null, + 0, + 0 + ); AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m)); Exception e = expectThrows(IllegalArgumentException.class, () -> builder.assignModelToNode(m, n, 3)); @@ -299,6 +325,7 @@ public void testAssignModelToNode_GivenSameModelAssignedTwice() { 2, Map.of(), 0, + null, ByteSizeValue.ofMb(300).getBytes(), ByteSizeValue.ofMb(50).getBytes() ); @@ -335,7 +362,7 @@ public void testAssignModelToNode_GivenSameModelAssignedTwice() { public void testCanAssign_GivenPreviouslyUnassignedModelDoesNotFit() { Node n = new Node("n_1", 100, 5); - AssignmentPlan.Deployment m = new AssignmentPlan.Deployment("m_1", 101, 1, 1, Map.of(), 0, 0, 0); + AssignmentPlan.Deployment m = new AssignmentPlan.Deployment("m_1", 101, 1, 1, Map.of(), 0, null, 0, 0); AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m)); @@ -346,7 +373,7 @@ public void testCanAssign_GivenPreviouslyAssignedModelDoesNotFit() { Node n = new Node("n_1", ByteSizeValue.ofMb(300).getBytes(), 5); { // old memory format - Deployment m = new Deployment("m_1", ByteSizeValue.ofMb(31).getBytes(), 1, 1, Map.of("n_1", 1), 0, 0, 0); + Deployment m = new Deployment("m_1", ByteSizeValue.ofMb(31).getBytes(), 1, 1, Map.of("n_1", 1), 0, null, 0, 0); AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m)); assertThat(builder.canAssign(m, n, 1), is(true)); } @@ -359,6 +386,7 @@ public void testCanAssign_GivenPreviouslyAssignedModelDoesNotFit() { 1, Map.of("n_1", 1), 0, + null, ByteSizeValue.ofMb(300).getBytes(), ByteSizeValue.ofMb(10).getBytes() ); @@ -369,7 +397,17 @@ public void testCanAssign_GivenPreviouslyAssignedModelDoesNotFit() { public void testCanAssign_GivenEnoughMemory() { Node n = new Node("n_1", ByteSizeValue.ofMb(440).getBytes(), 5); - AssignmentPlan.Deployment m = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(100).getBytes(), 3, 2, Map.of(), 0, 0, 0); + AssignmentPlan.Deployment m = new AssignmentPlan.Deployment( + "m_1", + ByteSizeValue.ofMb(100).getBytes(), + 3, + 2, + Map.of(), + 0, + null, + 0, + 0 + ); AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m)); @@ -384,7 +422,7 @@ public void testCompareTo_GivenDifferenceInPreviousAssignments() { Node n = new Node("n_1", ByteSizeValue.ofMb(300).getBytes(), 5); { - Deployment m = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(30).getBytes(), 3, 2, Map.of("n_1", 2), 0, 0, 0); + Deployment m = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(30).getBytes(), 3, 2, Map.of("n_1", 2), 0, null, 0, 0); AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m)); builder.assignModelToNode(m, n, 2); planSatisfyingPreviousAssignments = builder.build(); @@ -397,6 +435,7 @@ public void testCompareTo_GivenDifferenceInPreviousAssignments() { 2, Map.of("n_1", 3), 0, + null, 0, 0 ); @@ -420,6 +459,7 @@ public void testCompareTo_GivenDifferenceInAllocations() { 2, Map.of("n_1", 1), 0, + null, 0, 0 ); @@ -445,7 +485,7 @@ public void testCompareTo_GivenDifferenceInMemory() { Node n = new Node("n_1", ByteSizeValue.ofMb(300).getBytes(), 5); { - Deployment m = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(30).getBytes(), 3, 2, Map.of("n_1", 1), 0, 0, 0); + Deployment m = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(30).getBytes(), 3, 2, Map.of("n_1", 1), 0, null, 0, 0); AssignmentPlan.Builder builder = AssignmentPlan.builder(List.of(n), List.of(m)); builder.assignModelToNode(m, n, 2); planUsingMoreMemory = builder.build(); @@ -458,6 +498,7 @@ public void testCompareTo_GivenDifferenceInMemory() { 2, Map.of("n_1", 1), 0, + null, 0, 0 ); @@ -482,6 +523,7 @@ public void testSatisfiesAllModels_GivenAllModelsAreSatisfied() { 2, Map.of(), 0, + null, 0, 0 ); @@ -492,6 +534,7 @@ public void testSatisfiesAllModels_GivenAllModelsAreSatisfied() { 1, Map.of(), 0, + null, 0, 0 ); @@ -502,6 +545,7 @@ public void testSatisfiesAllModels_GivenAllModelsAreSatisfied() { 1, Map.of(), 0, + null, 0, 0 ); @@ -522,6 +566,7 @@ public void testSatisfiesAllModels_GivenAllModelsAreSatisfied() { 2, Map.of(), 0, + null, ByteSizeValue.ofMb(300).getBytes(), ByteSizeValue.ofMb(10).getBytes() ); @@ -532,6 +577,7 @@ public void testSatisfiesAllModels_GivenAllModelsAreSatisfied() { 1, Map.of(), 0, + null, ByteSizeValue.ofMb(300).getBytes(), ByteSizeValue.ofMb(10).getBytes() ); @@ -542,6 +588,7 @@ public void testSatisfiesAllModels_GivenAllModelsAreSatisfied() { 1, Map.of(), 0, + null, ByteSizeValue.ofMb(300).getBytes(), ByteSizeValue.ofMb(10).getBytes() ); @@ -558,9 +605,9 @@ public void testSatisfiesAllModels_GivenAllModelsAreSatisfied() { public void testSatisfiesAllModels_GivenOneModelHasOneAllocationLess() { Node node1 = new Node("n_1", ByteSizeValue.ofMb(1000).getBytes(), 4); Node node2 = new Node("n_2", ByteSizeValue.ofMb(1000).getBytes(), 4); - Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(50).getBytes(), 1, 2, Map.of(), 0, 0, 0); - Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(30).getBytes(), 2, 1, Map.of(), 0, 0, 0); - Deployment deployment3 = new Deployment("m_3", ByteSizeValue.ofMb(20).getBytes(), 4, 1, Map.of(), 0, 0, 0); + Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(50).getBytes(), 1, 2, Map.of(), 0, null, 0, 0); + Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(30).getBytes(), 2, 1, Map.of(), 0, null, 0, 0); + Deployment deployment3 = new Deployment("m_3", ByteSizeValue.ofMb(20).getBytes(), 4, 1, Map.of(), 0, null, 0, 0); AssignmentPlan plan = AssignmentPlan.builder(List.of(node1, node2), List.of(deployment1, deployment2, deployment3)) .assignModelToNode(deployment1, node1, 1) .assignModelToNode(deployment2, node2, 2) @@ -573,9 +620,9 @@ public void testSatisfiesAllModels_GivenOneModelHasOneAllocationLess() { public void testArePreviouslyAssignedModelsAssigned_GivenTrue() { Node node1 = new Node("n_1", ByteSizeValue.ofMb(1000).getBytes(), 4); Node node2 = new Node("n_2", ByteSizeValue.ofMb(1000).getBytes(), 4); - Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(50).getBytes(), 1, 2, Map.of(), 3, 0, 0); - Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(30).getBytes(), 2, 1, Map.of(), 4, 0, 0); - Deployment deployment3 = new Deployment("m_3", ByteSizeValue.ofMb(20).getBytes(), 4, 1, Map.of(), 0, 0, 0); + Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(50).getBytes(), 1, 2, Map.of(), 3, null, 0, 0); + Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(30).getBytes(), 2, 1, Map.of(), 4, null, 0, 0); + Deployment deployment3 = new Deployment("m_3", ByteSizeValue.ofMb(20).getBytes(), 4, 1, Map.of(), 0, null, 0, 0); AssignmentPlan plan = AssignmentPlan.builder(List.of(node1, node2), List.of(deployment1, deployment2, deployment3)) .assignModelToNode(deployment1, node1, 1) .assignModelToNode(deployment2, node2, 1) @@ -586,8 +633,8 @@ public void testArePreviouslyAssignedModelsAssigned_GivenTrue() { public void testArePreviouslyAssignedModelsAssigned_GivenFalse() { Node node1 = new Node("n_1", ByteSizeValue.ofMb(1000).getBytes(), 4); Node node2 = new Node("n_2", ByteSizeValue.ofMb(1000).getBytes(), 4); - Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(50).getBytes(), 1, 2, Map.of(), 3, 0, 0); - Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(30).getBytes(), 2, 1, Map.of(), 4, 0, 0); + Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(50).getBytes(), 1, 2, Map.of(), 3, null, 0, 0); + Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(30).getBytes(), 2, 1, Map.of(), 4, null, 0, 0); AssignmentPlan plan = AssignmentPlan.builder(List.of(node1, node2), List.of(deployment1, deployment2)) .assignModelToNode(deployment1, node1, 1) .build(); @@ -597,7 +644,7 @@ public void testArePreviouslyAssignedModelsAssigned_GivenFalse() { public void testCountPreviouslyAssignedThatAreStillAssigned() { Node node1 = new Node("n_1", ByteSizeValue.ofMb(1000).getBytes(), 4); Node node2 = new Node("n_2", ByteSizeValue.ofMb(1000).getBytes(), 4); - Deployment deployment1 = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(50).getBytes(), 1, 2, Map.of(), 3, 0, 0); + Deployment deployment1 = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(50).getBytes(), 1, 2, Map.of(), 3, null, 0, 0); AssignmentPlan.Deployment deployment2 = new AssignmentPlan.Deployment( "m_2", ByteSizeValue.ofMb(30).getBytes(), @@ -605,6 +652,7 @@ public void testCountPreviouslyAssignedThatAreStillAssigned() { 1, Map.of(), 4, + null, 0, 0 ); @@ -615,6 +663,7 @@ public void testCountPreviouslyAssignedThatAreStillAssigned() { 1, Map.of(), 1, + null, 0, 0 ); @@ -625,6 +674,7 @@ public void testCountPreviouslyAssignedThatAreStillAssigned() { 1, Map.of(), 0, + null, 0, 0 ); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlannerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlannerTests.java index bc94144bce1c5..ef76c388b81a1 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlannerTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/AssignmentPlannerTests.java @@ -42,7 +42,7 @@ private static long scaleNodeSize(long nodeMemory) { public void testModelThatDoesNotFitInMemory() { { // Without perDeploymentMemory and perAllocationMemory specified List nodes = List.of(new Node("n_1", scaleNodeSize(50), 4)); - Deployment deployment = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(51).getBytes(), 4, 1, Map.of(), 0, 0, 0); + Deployment deployment = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(51).getBytes(), 4, 1, Map.of(), 0, null, 0, 0); AssignmentPlan plan = new AssignmentPlanner(nodes, List.of(deployment)).computePlan(); assertThat(plan.assignments(deployment), isEmpty()); } @@ -55,6 +55,7 @@ public void testModelThatDoesNotFitInMemory() { 1, Map.of(), 0, + null, ByteSizeValue.ofMb(250).getBytes(), ByteSizeValue.ofMb(51).getBytes() ); @@ -65,7 +66,7 @@ public void testModelThatDoesNotFitInMemory() { public void testModelWithThreadsPerAllocationNotFittingOnAnyNode() { List nodes = List.of(new Node("n_1", scaleNodeSize(100), 4), new Node("n_2", scaleNodeSize(100), 5)); - Deployment deployment = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(1).getBytes(), 1, 6, Map.of(), 0, 0, 0); + Deployment deployment = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(1).getBytes(), 1, 6, Map.of(), 0, null, 0, 0); AssignmentPlan plan = new AssignmentPlanner(nodes, List.of(deployment)).computePlan(); assertThat(plan.assignments(deployment), isEmpty()); } @@ -73,13 +74,13 @@ public void testModelWithThreadsPerAllocationNotFittingOnAnyNode() { public void testSingleModelThatFitsFullyOnSingleNode() { { Node node = new Node("n_1", scaleNodeSize(100), 4); - Deployment deployment = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(100).getBytes(), 1, 1, Map.of(), 0, 0, 0); + Deployment deployment = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(100).getBytes(), 1, 1, Map.of(), 0, null, 0, 0); AssignmentPlan plan = new AssignmentPlanner(List.of(node), List.of(deployment)).computePlan(); assertModelFullyAssignedToNode(plan, deployment, node); } { Node node = new Node("n_1", scaleNodeSize(1000), 8); - Deployment deployment = new Deployment("m_1", ByteSizeValue.ofMb(1000).getBytes(), 8, 1, Map.of(), 0, 0, 0); + Deployment deployment = new Deployment("m_1", ByteSizeValue.ofMb(1000).getBytes(), 8, 1, Map.of(), 0, null, 0, 0); AssignmentPlan plan = new AssignmentPlanner(List.of(node), List.of(deployment)).computePlan(); assertModelFullyAssignedToNode(plan, deployment, node); } @@ -92,6 +93,7 @@ public void testSingleModelThatFitsFullyOnSingleNode() { 16, Map.of(), 0, + null, 0, 0 ); @@ -100,7 +102,7 @@ public void testSingleModelThatFitsFullyOnSingleNode() { } { Node node = new Node("n_1", scaleNodeSize(100), 4); - Deployment deployment = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(100).getBytes(), 1, 1, Map.of(), 0, 0, 0); + Deployment deployment = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(100).getBytes(), 1, 1, Map.of(), 0, null, 0, 0); AssignmentPlan plan = new AssignmentPlanner(List.of(node), List.of(deployment)).computePlan(); assertModelFullyAssignedToNode(plan, deployment, node); } @@ -116,6 +118,7 @@ public void testSingleModelThatFitsFullyOnSingleNode_NewMemoryFields() { 1, Map.of(), 0, + null, ByteSizeValue.ofMb(300).getBytes(), ByteSizeValue.ofMb(100).getBytes() ); @@ -131,6 +134,7 @@ public void testSingleModelThatFitsFullyOnSingleNode_NewMemoryFields() { 1, Map.of(), 0, + null, ByteSizeValue.ofMb(100).getBytes(), ByteSizeValue.ofMb(100).getBytes() ); @@ -142,7 +146,7 @@ public void testSingleModelThatFitsFullyOnSingleNode_NewMemoryFields() { public void testSingleModelThatFitsFullyOnSingleNode_GivenTwoNodes_ShouldBeFullyAssignedOnOneNode() { Node node1 = new Node("n_1", scaleNodeSize(100), 4); Node node2 = new Node("n_2", scaleNodeSize(100), 4); - AssignmentPlan.Deployment deployment = new Deployment("m_1", ByteSizeValue.ofMb(100).getBytes(), 4, 1, Map.of(), 0, 0, 0); + AssignmentPlan.Deployment deployment = new Deployment("m_1", ByteSizeValue.ofMb(100).getBytes(), 4, 1, Map.of(), 0, null, 0, 0); AssignmentPlan plan = new AssignmentPlanner(List.of(node1, node2), List.of(deployment)).computePlan(); @@ -164,6 +168,7 @@ public void testSingleModelThatFitsFullyOnSingleNode_GivenTwoNodes_ShouldBeFully 1, Map.of(), 0, + null, ByteSizeValue.ofMb(300).getBytes(), ByteSizeValue.ofMb(150).getBytes() ); @@ -179,7 +184,7 @@ public void testSingleModelThatFitsFullyOnSingleNode_GivenTwoNodes_ShouldBeFully } public void testModelWithMoreAllocationsThanAvailableCores_GivenSingleThreadPerAllocation() { - AssignmentPlan.Deployment deployment = new Deployment("m_1", ByteSizeValue.ofMb(30).getBytes(), 10, 1, Map.of(), 0, 0, 0); + AssignmentPlan.Deployment deployment = new Deployment("m_1", ByteSizeValue.ofMb(30).getBytes(), 10, 1, Map.of(), 0, null, 0, 0); // Single node { Node node = new Node("n_1", scaleNodeSize(100), 4); @@ -220,6 +225,7 @@ public void testModelWithMoreAllocationsThanAvailableCores_GivenSingleThreadPerA 1, Map.of(), 0, + null, ByteSizeValue.ofMb(300).getBytes(), ByteSizeValue.ofMb(100).getBytes() ); @@ -260,10 +266,10 @@ public void testMultipleModelsAndNodesWithSingleSolution() { Node node2 = new Node("n_2", 2 * scaleNodeSize(50), 7); Node node3 = new Node("n_3", 2 * scaleNodeSize(50), 2); Node node4 = new Node("n_4", 2 * scaleNodeSize(50), 2); - Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(50).getBytes(), 2, 4, Map.of(), 0, 0, 0); - Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(50).getBytes(), 2, 3, Map.of(), 0, 0, 0); - Deployment deployment3 = new Deployment("m_3", ByteSizeValue.ofMb(50).getBytes(), 1, 2, Map.of(), 0, 0, 0); - Deployment deployment4 = new Deployment("m_4", ByteSizeValue.ofMb(50).getBytes(), 2, 1, Map.of(), 0, 0, 0); + Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(50).getBytes(), 2, 4, Map.of(), 0, null, 0, 0); + Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(50).getBytes(), 2, 3, Map.of(), 0, null, 0, 0); + Deployment deployment3 = new Deployment("m_3", ByteSizeValue.ofMb(50).getBytes(), 1, 2, Map.of(), 0, null, 0, 0); + Deployment deployment4 = new Deployment("m_4", ByteSizeValue.ofMb(50).getBytes(), 2, 1, Map.of(), 0, null, 0, 0); AssignmentPlan plan = new AssignmentPlanner( List.of(node1, node2, node3, node4), @@ -322,6 +328,7 @@ public void testMultipleModelsAndNodesWithSingleSolution_NewMemoryFields() { 4, Map.of(), 0, + null, ByteSizeValue.ofMb(300).getBytes(), ByteSizeValue.ofMb(50).getBytes() ); @@ -332,6 +339,7 @@ public void testMultipleModelsAndNodesWithSingleSolution_NewMemoryFields() { 3, Map.of(), 0, + null, ByteSizeValue.ofMb(300).getBytes(), ByteSizeValue.ofMb(50).getBytes() ); @@ -342,6 +350,7 @@ public void testMultipleModelsAndNodesWithSingleSolution_NewMemoryFields() { 2, Map.of(), 0, + null, ByteSizeValue.ofMb(300).getBytes(), ByteSizeValue.ofMb(50).getBytes() ); @@ -352,6 +361,7 @@ public void testMultipleModelsAndNodesWithSingleSolution_NewMemoryFields() { 1, Map.of(), 0, + null, ByteSizeValue.ofMb(300).getBytes(), ByteSizeValue.ofMb(50).getBytes() ); @@ -402,7 +412,7 @@ public void testMultipleModelsAndNodesWithSingleSolution_NewMemoryFields() { } public void testModelWithMoreAllocationsThanAvailableCores_GivenThreeThreadsPerAllocation() { - Deployment deployment = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(30).getBytes(), 10, 3, Map.of(), 0, 0, 0); + Deployment deployment = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(30).getBytes(), 10, 3, Map.of(), 0, null, 0, 0); // Single node { Node node = new Node("n_1", scaleNodeSize(100), 4); @@ -443,6 +453,7 @@ public void testModelWithMoreAllocationsThanAvailableCores_GivenThreeThreadsPerA 3, Map.of(), 0, + null, ByteSizeValue.ofMb(300).getBytes(), ByteSizeValue.ofMb(50).getBytes() ); @@ -487,6 +498,7 @@ public void testModelWithPreviousAssignmentAndNoMoreCoresAvailable() { 1, Map.of("n_1", 4), 0, + null, 0, 0 ); @@ -506,18 +518,18 @@ public void testFullCoreUtilization_GivenModelsWithSingleThreadPerAllocation() { new Node("n_6", ByteSizeValue.ofGb(32).getBytes(), 16) ); List deployments = List.of( - new Deployment("m_1", ByteSizeValue.ofGb(4).getBytes(), 10, 1, Map.of("n_1", 5), 0, 0, 0), - new AssignmentPlan.Deployment("m_2", ByteSizeValue.ofGb(2).getBytes(), 3, 1, Map.of("n_3", 2), 0, 0, 0), - new AssignmentPlan.Deployment("m_3", ByteSizeValue.ofGb(3).getBytes(), 3, 1, Map.of(), 0, 0, 0), - new Deployment("m_4", ByteSizeValue.ofGb(1).getBytes(), 4, 1, Map.of("n_3", 2), 0, 0, 0), - new Deployment("m_5", ByteSizeValue.ofGb(6).getBytes(), 2, 1, Map.of(), 0, 0, 0), - new Deployment("m_6", ByteSizeValue.ofGb(1).getBytes(), 12, 1, Map.of(), 0, 0, 0), - new AssignmentPlan.Deployment("m_7", ByteSizeValue.ofGb(1).getBytes() / 2, 12, 1, Map.of("n_2", 6), 0, 0, 0), - new Deployment("m_8", ByteSizeValue.ofGb(2).getBytes(), 4, 1, Map.of(), 0, 0, 0), - new Deployment("m_9", ByteSizeValue.ofGb(1).getBytes(), 4, 1, Map.of(), 0, 0, 0), - new AssignmentPlan.Deployment("m_10", ByteSizeValue.ofGb(7).getBytes(), 7, 1, Map.of(), 0, 0, 0), - new Deployment("m_11", ByteSizeValue.ofGb(2).getBytes(), 3, 1, Map.of(), 0, 0, 0), - new Deployment("m_12", ByteSizeValue.ofGb(1).getBytes(), 10, 1, Map.of(), 0, 0, 0) + new Deployment("m_1", ByteSizeValue.ofGb(4).getBytes(), 10, 1, Map.of("n_1", 5), 0, null, 0, 0), + new AssignmentPlan.Deployment("m_2", ByteSizeValue.ofGb(2).getBytes(), 3, 1, Map.of("n_3", 2), 0, null, 0, 0), + new AssignmentPlan.Deployment("m_3", ByteSizeValue.ofGb(3).getBytes(), 3, 1, Map.of(), 0, null, 0, 0), + new Deployment("m_4", ByteSizeValue.ofGb(1).getBytes(), 4, 1, Map.of("n_3", 2), 0, null, 0, 0), + new Deployment("m_5", ByteSizeValue.ofGb(6).getBytes(), 2, 1, Map.of(), 0, null, 0, 0), + new Deployment("m_6", ByteSizeValue.ofGb(1).getBytes(), 12, 1, Map.of(), 0, null, 0, 0), + new AssignmentPlan.Deployment("m_7", ByteSizeValue.ofGb(1).getBytes() / 2, 12, 1, Map.of("n_2", 6), 0, null, 0, 0), + new Deployment("m_8", ByteSizeValue.ofGb(2).getBytes(), 4, 1, Map.of(), 0, null, 0, 0), + new Deployment("m_9", ByteSizeValue.ofGb(1).getBytes(), 4, 1, Map.of(), 0, null, 0, 0), + new AssignmentPlan.Deployment("m_10", ByteSizeValue.ofGb(7).getBytes(), 7, 1, Map.of(), 0, null, 0, 0), + new Deployment("m_11", ByteSizeValue.ofGb(2).getBytes(), 3, 1, Map.of(), 0, null, 0, 0), + new Deployment("m_12", ByteSizeValue.ofGb(1).getBytes(), 10, 1, Map.of(), 0, null, 0, 0) ); AssignmentPlan assignmentPlan = new AssignmentPlanner(nodes, deployments).computePlan(); @@ -550,10 +562,11 @@ public void testFullCoreUtilization_GivenModelsWithSingleThreadPerAllocation_New 1, Map.of("n_1", 5), 0, + null, ByteSizeValue.ofMb(400).getBytes(), ByteSizeValue.ofMb(100).getBytes() ), - new Deployment("m_2", ByteSizeValue.ofMb(100).getBytes(), 3, 1, Map.of("n_3", 2), 0, 0, 0), + new Deployment("m_2", ByteSizeValue.ofMb(100).getBytes(), 3, 1, Map.of("n_3", 2), 0, null, 0, 0), new Deployment( "m_3", ByteSizeValue.ofMb(50).getBytes(), @@ -561,6 +574,7 @@ public void testFullCoreUtilization_GivenModelsWithSingleThreadPerAllocation_New 1, Map.of(), 0, + null, ByteSizeValue.ofMb(300).getBytes(), ByteSizeValue.ofMb(50).getBytes() ), @@ -571,6 +585,7 @@ public void testFullCoreUtilization_GivenModelsWithSingleThreadPerAllocation_New 1, Map.of("n_3", 2), 0, + null, ByteSizeValue.ofMb(400).getBytes(), ByteSizeValue.ofMb(100).getBytes() ), @@ -581,6 +596,7 @@ public void testFullCoreUtilization_GivenModelsWithSingleThreadPerAllocation_New 1, Map.of(), 0, + null, ByteSizeValue.ofMb(800).getBytes(), ByteSizeValue.ofMb(100).getBytes() ), @@ -591,6 +607,7 @@ public void testFullCoreUtilization_GivenModelsWithSingleThreadPerAllocation_New 1, Map.of(), 0, + null, ByteSizeValue.ofMb(50).getBytes(), ByteSizeValue.ofMb(20).getBytes() ), @@ -601,14 +618,15 @@ public void testFullCoreUtilization_GivenModelsWithSingleThreadPerAllocation_New 1, Map.of("n_2", 6), 0, + null, ByteSizeValue.ofMb(300).getBytes(), ByteSizeValue.ofMb(50).getBytes() ), - new Deployment("m_8", ByteSizeValue.ofGb(2).getBytes(), 4, 1, Map.of(), 0, 0, 0), - new Deployment("m_9", ByteSizeValue.ofGb(1).getBytes(), 4, 1, Map.of(), 0, 0, 0), - new Deployment("m_10", ByteSizeValue.ofGb(7).getBytes(), 7, 1, Map.of(), 0, 0, 0), - new Deployment("m_11", ByteSizeValue.ofGb(2).getBytes(), 3, 1, Map.of(), 0, 0, 0), - new Deployment("m_12", ByteSizeValue.ofGb(1).getBytes(), 10, 1, Map.of(), 0, 0, 0) + new Deployment("m_8", ByteSizeValue.ofGb(2).getBytes(), 4, 1, Map.of(), 0, null, 0, 0), + new Deployment("m_9", ByteSizeValue.ofGb(1).getBytes(), 4, 1, Map.of(), 0, null, 0, 0), + new Deployment("m_10", ByteSizeValue.ofGb(7).getBytes(), 7, 1, Map.of(), 0, null, 0, 0), + new Deployment("m_11", ByteSizeValue.ofGb(2).getBytes(), 3, 1, Map.of(), 0, null, 0, 0), + new Deployment("m_12", ByteSizeValue.ofGb(1).getBytes(), 10, 1, Map.of(), 0, null, 0, 0) ); AssignmentPlan assignmentPlan = new AssignmentPlanner(nodes, deployments).computePlan(); @@ -718,6 +736,7 @@ public void testPreviousAssignmentsGetAtLeastAsManyAllocationsAfterAddingNewMode m.threadsPerAllocation(), previousAssignments, 0, + null, 0, 0 ) @@ -741,10 +760,11 @@ public void testGivenLargerModelWithPreviousAssignmentsAndSmallerModelWithoutAss 1, Map.of("n_1", 2, "n_2", 1), 0, + null, 0, 0 ); - Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(1100).getBytes(), 2, 1, Map.of(), 0, 0, 0); + Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(1100).getBytes(), 2, 1, Map.of(), 0, null, 0, 0); AssignmentPlan assignmentPlan = new AssignmentPlanner(List.of(node1, node2, node3), List.of(deployment1, deployment2)) .computePlan(); assertThat(assignmentPlan.getRemainingNodeMemory("n_1"), greaterThanOrEqualTo(0L)); @@ -776,6 +796,7 @@ public void testModelWithoutCurrentAllocationsGetsAssignedIfAllocatedPreviously( 1, Map.of("n_1", 2, "n_2", 1), 3, + null, 0, 0 ); @@ -786,6 +807,7 @@ public void testModelWithoutCurrentAllocationsGetsAssignedIfAllocatedPreviously( 2, Map.of(), 1, + null, 0, 0 ); @@ -807,8 +829,8 @@ public void testModelWithoutCurrentAllocationsGetsAssignedIfAllocatedPreviously( public void testGivenPreviouslyAssignedModels_CannotAllBeAllocated() { Node node1 = new Node("n_1", scaleNodeSize(ByteSizeValue.ofGb(2).getMb()), 2); - AssignmentPlan.Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(1200).getBytes(), 1, 1, Map.of(), 1, 0, 0); - AssignmentPlan.Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(1100).getBytes(), 1, 1, Map.of(), 1, 0, 0); + AssignmentPlan.Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(1200).getBytes(), 1, 1, Map.of(), 1, null, 0, 0); + AssignmentPlan.Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(1100).getBytes(), 1, 1, Map.of(), 1, null, 0, 0); AssignmentPlan assignmentPlan = new AssignmentPlanner(List.of(node1), List.of(deployment1, deployment2)).computePlan(); @@ -818,9 +840,9 @@ public void testGivenPreviouslyAssignedModels_CannotAllBeAllocated() { public void testGivenClusterResize_AllocationShouldNotExceedMemoryConstraints() { Node node1 = new Node("n_1", ByteSizeValue.ofMb(1840).getBytes(), 2); Node node2 = new Node("n_2", ByteSizeValue.ofMb(2580).getBytes(), 2); - Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(800).getBytes(), 2, 1, Map.of(), 0, 0, 0); - Deployment deployment2 = new AssignmentPlan.Deployment("m_2", ByteSizeValue.ofMb(800).getBytes(), 1, 1, Map.of(), 0, 0, 0); - Deployment deployment3 = new Deployment("m_3", ByteSizeValue.ofMb(250).getBytes(), 4, 1, Map.of(), 0, 0, 0); + Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(800).getBytes(), 2, 1, Map.of(), 0, null, 0, 0); + Deployment deployment2 = new AssignmentPlan.Deployment("m_2", ByteSizeValue.ofMb(800).getBytes(), 1, 1, Map.of(), 0, null, 0, 0); + Deployment deployment3 = new Deployment("m_3", ByteSizeValue.ofMb(250).getBytes(), 4, 1, Map.of(), 0, null, 0, 0); // First only start m_1 AssignmentPlan assignmentPlan = new AssignmentPlanner(List.of(node1, node2), List.of(deployment1)).computePlan(); @@ -860,9 +882,9 @@ public void testGivenClusterResize_AllocationShouldNotExceedMemoryConstraints() public void testGivenClusterResize_ShouldAllocateEachModelAtLeastOnce() { Node node1 = new Node("n_1", ByteSizeValue.ofMb(2600).getBytes(), 2); Node node2 = new Node("n_2", ByteSizeValue.ofMb(2600).getBytes(), 2); - Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(800).getBytes(), 2, 1, Map.of(), 0, 0, 0); - Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(800).getBytes(), 1, 1, Map.of(), 0, 0, 0); - Deployment deployment3 = new Deployment("m_3", ByteSizeValue.ofMb(250).getBytes(), 4, 1, Map.of(), 0, 0, 0); + Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(800).getBytes(), 2, 1, Map.of(), 0, null, 0, 0); + Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(800).getBytes(), 1, 1, Map.of(), 0, null, 0, 0); + Deployment deployment3 = new Deployment("m_3", ByteSizeValue.ofMb(250).getBytes(), 4, 1, Map.of(), 0, null, 0, 0); // First only start m_1 AssignmentPlan assignmentPlan = new AssignmentPlanner(List.of(node1, node2), List.of(deployment1)).computePlan(); @@ -931,9 +953,9 @@ public void testGivenClusterResize_ShouldRemoveAllocatedModels() { // Ensure that plan is removing previously allocated models if not enough memory is available Node node1 = new Node("n_1", ByteSizeValue.ofMb(1840).getBytes(), 2); Node node2 = new Node("n_2", ByteSizeValue.ofMb(2580).getBytes(), 2); - Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(800).getBytes(), 2, 1, Map.of(), 0, 0, 0); - Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(800).getBytes(), 1, 1, Map.of(), 0, 0, 0); - Deployment deployment3 = new Deployment("m_3", ByteSizeValue.ofMb(250).getBytes(), 1, 1, Map.of(), 0, 0, 0); + Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(800).getBytes(), 2, 1, Map.of(), 0, null, 0, 0); + Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(800).getBytes(), 1, 1, Map.of(), 0, null, 0, 0); + Deployment deployment3 = new Deployment("m_3", ByteSizeValue.ofMb(250).getBytes(), 1, 1, Map.of(), 0, null, 0, 0); // Create a plan where all deployments are assigned at least once AssignmentPlan assignmentPlan = new AssignmentPlanner(List.of(node1, node2), List.of(deployment1, deployment2, deployment3)) @@ -965,6 +987,7 @@ public void testGivenClusterResize_ShouldRemoveAllocatedModels_NewMemoryFields() 1, Map.of(), 0, + null, ByteSizeValue.ofMb(400).getBytes(), ByteSizeValue.ofMb(100).getBytes() ); @@ -975,6 +998,7 @@ public void testGivenClusterResize_ShouldRemoveAllocatedModels_NewMemoryFields() 1, Map.of(), 0, + null, ByteSizeValue.ofMb(400).getBytes(), ByteSizeValue.ofMb(150).getBytes() ); @@ -985,6 +1009,7 @@ public void testGivenClusterResize_ShouldRemoveAllocatedModels_NewMemoryFields() 1, Map.of(), 0, + null, ByteSizeValue.ofMb(250).getBytes(), ByteSizeValue.ofMb(50).getBytes() ); @@ -1028,6 +1053,7 @@ public static List createModelsFromPlan(AssignmentPlan plan) { m.threadsPerAllocation(), currentAllocations, Math.max(m.maxAssignedAllocations(), totalAllocations), + null, 0, 0 ) @@ -1096,6 +1122,7 @@ public static Deployment randomModel(String idSuffix) { randomIntBetween(1, 4), Map.of(), 0, + null, 0, 0 ); @@ -1107,6 +1134,7 @@ public static Deployment randomModel(String idSuffix) { randomIntBetween(1, 4), Map.of(), 0, + null, randomLongBetween(ByteSizeValue.ofMb(100).getBytes(), ByteSizeValue.ofGb(1).getBytes()), randomLongBetween(ByteSizeValue.ofMb(100).getBytes(), ByteSizeValue.ofGb(1).getBytes()) ); @@ -1137,7 +1165,7 @@ private void runTooManyNodesAndModels(int nodesSize, int modelsSize) { } List deployments = new ArrayList<>(); for (int i = 0; i < modelsSize; i++) { - deployments.add(new Deployment("m_" + i, ByteSizeValue.ofMb(200).getBytes(), 2, 1, Map.of(), 0, 0, 0)); + deployments.add(new Deployment("m_" + i, ByteSizeValue.ofMb(200).getBytes(), 2, 1, Map.of(), 0, null, 0, 0)); } // Check plan is computed without OOM exception diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/PreserveAllAllocationsTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/PreserveAllAllocationsTests.java index 7f83df5835494..9885c4d583198 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/PreserveAllAllocationsTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/PreserveAllAllocationsTests.java @@ -25,8 +25,8 @@ public class PreserveAllAllocationsTests extends ESTestCase { public void testGivenNoPreviousAssignments() { Node node1 = new Node("n_1", ByteSizeValue.ofMb(440).getBytes(), 4); Node node2 = new Node("n_2", ByteSizeValue.ofMb(440).getBytes(), 4); - Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(30).getBytes(), 2, 1, Map.of(), 0, 0, 0); - Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(30).getBytes(), 2, 4, Map.of(), 0, 0, 0); + Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(30).getBytes(), 2, 1, Map.of(), 0, null, 0, 0); + Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(30).getBytes(), 2, 4, Map.of(), 0, null, 0, 0); PreserveAllAllocations preserveAllAllocations = new PreserveAllAllocations( List.of(node1, node2), List.of(deployment1, deployment2) @@ -45,10 +45,21 @@ public void testGivenPreviousAssignments() { 1, Map.of("n_1", 1), 1, + null, + 0, + 0 + ); + Deployment deployment2 = new Deployment( + "m_2", + ByteSizeValue.ofMb(50).getBytes(), + 6, + 4, + Map.of("n_1", 1, "n_2", 2), + 3, + null, 0, 0 ); - Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(50).getBytes(), 6, 4, Map.of("n_1", 1, "n_2", 2), 3, 0, 0); PreserveAllAllocations preserveAllAllocations = new PreserveAllAllocations( List.of(node1, node2), List.of(deployment1, deployment2) @@ -117,6 +128,7 @@ public void testGivenPreviousAssignments() { 1, Map.of("n_1", 1), 1, + null, ByteSizeValue.ofMb(300).getBytes(), ByteSizeValue.ofMb(10).getBytes() ); @@ -127,6 +139,7 @@ public void testGivenPreviousAssignments() { 4, Map.of("n_1", 1, "n_2", 2), 3, + null, ByteSizeValue.ofMb(300).getBytes(), ByteSizeValue.ofMb(10).getBytes() ); @@ -195,7 +208,7 @@ public void testGivenPreviousAssignments() { public void testGivenModelWithPreviousAssignments_AndPlanToMergeHasNoAssignments() { Node node = new Node("n_1", ByteSizeValue.ofMb(400).getBytes(), 4); - Deployment deployment = new Deployment("m_1", ByteSizeValue.ofMb(30).getBytes(), 2, 2, Map.of("n_1", 2), 2, 0, 0); + Deployment deployment = new Deployment("m_1", ByteSizeValue.ofMb(30).getBytes(), 2, 2, Map.of("n_1", 2), 2, null, 0, 0); PreserveAllAllocations preserveAllAllocations = new PreserveAllAllocations(List.of(node), List.of(deployment)); AssignmentPlan plan = AssignmentPlan.builder(List.of(node), List.of(deployment)).build(); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/PreserveOneAllocationTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/PreserveOneAllocationTests.java index d2907eb31160b..50ba8763c690d 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/PreserveOneAllocationTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/PreserveOneAllocationTests.java @@ -26,8 +26,8 @@ public class PreserveOneAllocationTests extends ESTestCase { public void testGivenNoPreviousAssignments() { Node node1 = new Node("n_1", ByteSizeValue.ofMb(440).getBytes(), 4); Node node2 = new Node("n_2", ByteSizeValue.ofMb(440).getBytes(), 4); - Deployment deployment1 = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(30).getBytes(), 2, 1, Map.of(), 0, 0, 0); - AssignmentPlan.Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(30).getBytes(), 2, 4, Map.of(), 0, 0, 0); + Deployment deployment1 = new AssignmentPlan.Deployment("m_1", ByteSizeValue.ofMb(30).getBytes(), 2, 1, Map.of(), 0, null, 0, 0); + AssignmentPlan.Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(30).getBytes(), 2, 4, Map.of(), 0, null, 0, 0); PreserveOneAllocation preserveOneAllocation = new PreserveOneAllocation(List.of(node1, node2), List.of(deployment1, deployment2)); List nodesPreservingAllocations = preserveOneAllocation.nodesPreservingAllocations(); @@ -42,8 +42,18 @@ public void testGivenPreviousAssignments() { // old memory format Node node1 = new Node("n_1", ByteSizeValue.ofMb(640).getBytes(), 8); Node node2 = new Node("n_2", ByteSizeValue.ofMb(640).getBytes(), 8); - Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(30).getBytes(), 2, 1, Map.of("n_1", 1), 1, 0, 0); - Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(50).getBytes(), 6, 4, Map.of("n_1", 1, "n_2", 2), 3, 0, 0); + Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(30).getBytes(), 2, 1, Map.of("n_1", 1), 1, null, 0, 0); + Deployment deployment2 = new Deployment( + "m_2", + ByteSizeValue.ofMb(50).getBytes(), + 6, + 4, + Map.of("n_1", 1, "n_2", 2), + 3, + null, + 0, + 0 + ); PreserveOneAllocation preserveOneAllocation = new PreserveOneAllocation( List.of(node1, node2), List.of(deployment1, deployment2) @@ -117,6 +127,7 @@ public void testGivenPreviousAssignments() { 1, Map.of("n_1", 1), 1, + null, ByteSizeValue.ofMb(300).getBytes(), ByteSizeValue.ofMb(10).getBytes() ); @@ -127,6 +138,7 @@ public void testGivenPreviousAssignments() { 4, Map.of("n_1", 1, "n_2", 2), 3, + null, ByteSizeValue.ofMb(300).getBytes(), ByteSizeValue.ofMb(10).getBytes() ); @@ -199,7 +211,7 @@ public void testGivenModelWithPreviousAssignments_AndPlanToMergeHasNoAssignments { // old memory format Node node = new Node("n_1", ByteSizeValue.ofMb(400).getBytes(), 4); - Deployment deployment = new Deployment("m_1", ByteSizeValue.ofMb(30).getBytes(), 2, 2, Map.of("n_1", 2), 2, 0, 0); + Deployment deployment = new Deployment("m_1", ByteSizeValue.ofMb(30).getBytes(), 2, 2, Map.of("n_1", 2), 2, null, 0, 0); PreserveOneAllocation preserveOneAllocation = new PreserveOneAllocation(List.of(node), List.of(deployment)); AssignmentPlan plan = AssignmentPlan.builder(List.of(node), List.of(deployment)).build(); @@ -221,6 +233,7 @@ public void testGivenModelWithPreviousAssignments_AndPlanToMergeHasNoAssignments 2, Map.of("n_1", 2), 2, + null, ByteSizeValue.ofMb(300).getBytes(), ByteSizeValue.ofMb(10).getBytes() ); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/ZoneAwareAssignmentPlannerTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/ZoneAwareAssignmentPlannerTests.java index 651e4764cb894..4993600d0d3b3 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/ZoneAwareAssignmentPlannerTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/assignment/planning/ZoneAwareAssignmentPlannerTests.java @@ -36,7 +36,7 @@ public class ZoneAwareAssignmentPlannerTests extends ESTestCase { public void testGivenOneModel_OneNode_OneZone_DoesNotFit() { Node node = new Node("n_1", 100, 1); - AssignmentPlan.Deployment deployment = new AssignmentPlan.Deployment("m_1", 100, 1, 2, Map.of(), 0, 0, 0); + AssignmentPlan.Deployment deployment = new AssignmentPlan.Deployment("m_1", 100, 1, 2, Map.of(), 0, null, 0, 0); AssignmentPlan plan = new ZoneAwareAssignmentPlanner(Map.of(List.of(), List.of(node)), List.of(deployment)).computePlan(); @@ -52,6 +52,7 @@ public void testGivenOneModel_OneNode_OneZone_FullyFits() { 2, Map.of(), 0, + null, 0, 0 ); @@ -70,6 +71,7 @@ public void testGivenOneModel_OneNode_OneZone_PartiallyFits() { 2, Map.of(), 0, + null, 0, 0 ); @@ -91,6 +93,7 @@ public void testGivenOneModelWithSingleAllocation_OneNode_TwoZones() { 2, Map.of(), 0, + null, 0, 0 ); @@ -118,6 +121,7 @@ public void testGivenOneModel_OneNodePerZone_TwoZones_FullyFits() { 2, Map.of(), 0, + null, 0, 0 ); @@ -144,6 +148,7 @@ public void testGivenOneModel_OneNodePerZone_TwoZones_PartiallyFits() { 3, Map.of(), 0, + null, 0, 0 ); @@ -168,9 +173,9 @@ public void testGivenThreeModels_TwoNodesPerZone_ThreeZones_FullyFit() { Node node4 = new Node("n_4", ByteSizeValue.ofMb(1000).getBytes(), 4); Node node5 = new Node("n_5", ByteSizeValue.ofMb(1000).getBytes(), 4); Node node6 = new Node("n_6", ByteSizeValue.ofMb(1000).getBytes(), 4); - Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(30).getBytes(), 4, 1, Map.of(), 0, 0, 0); - Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(30).getBytes(), 6, 2, Map.of(), 0, 0, 0); - Deployment deployment3 = new Deployment("m_3", ByteSizeValue.ofMb(30).getBytes(), 2, 3, Map.of(), 0, 0, 0); + Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(30).getBytes(), 4, 1, Map.of(), 0, null, 0, 0); + Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(30).getBytes(), 6, 2, Map.of(), 0, null, 0, 0); + Deployment deployment3 = new Deployment("m_3", ByteSizeValue.ofMb(30).getBytes(), 2, 3, Map.of(), 0, null, 0, 0); Map, List> nodesByZone = Map.of( List.of("z_1"), @@ -216,8 +221,8 @@ public void testGivenTwoModelsWithSingleAllocation_OneNode_ThreeZones() { Node node1 = new Node("n_1", ByteSizeValue.ofMb(1000).getBytes(), 4); Node node2 = new Node("n_2", ByteSizeValue.ofMb(1000).getBytes(), 4); Node node3 = new Node("n_3", ByteSizeValue.ofMb(1000).getBytes(), 4); - Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(30).getBytes(), 1, 1, Map.of(), 0, 0, 0); - Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(30).getBytes(), 1, 1, Map.of(), 0, 0, 0); + Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(30).getBytes(), 1, 1, Map.of(), 0, null, 0, 0); + Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(30).getBytes(), 1, 1, Map.of(), 0, null, 0, 0); AssignmentPlan plan = new ZoneAwareAssignmentPlanner( Map.of(List.of("z1"), List.of(node1), List.of("z2"), List.of(node2), List.of("z3"), List.of(node3)), @@ -255,6 +260,7 @@ public void testPreviousAssignmentsGetAtLeastAsManyAllocationsAfterAddingNewMode m.threadsPerAllocation(), previousAssignments, 0, + null, 0, 0 ) @@ -270,9 +276,9 @@ public void testPreviousAssignmentsGetAtLeastAsManyAllocationsAfterAddingNewMode public void testGivenClusterResize_GivenOneZone_ShouldAllocateEachModelAtLeastOnce() { Node node1 = new Node("n_1", ByteSizeValue.ofMb(2580).getBytes(), 2); Node node2 = new Node("n_2", ByteSizeValue.ofMb(2580).getBytes(), 2); - Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(800).getBytes(), 2, 1, Map.of(), 0, 0, 0); - Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(800).getBytes(), 1, 1, Map.of(), 0, 0, 0); - Deployment deployment3 = new Deployment("m_3", ByteSizeValue.ofMb(250).getBytes(), 4, 1, Map.of(), 0, 0, 0); + Deployment deployment1 = new Deployment("m_1", ByteSizeValue.ofMb(800).getBytes(), 2, 1, Map.of(), 0, null, 0, 0); + Deployment deployment2 = new Deployment("m_2", ByteSizeValue.ofMb(800).getBytes(), 1, 1, Map.of(), 0, null, 0, 0); + Deployment deployment3 = new Deployment("m_3", ByteSizeValue.ofMb(250).getBytes(), 4, 1, Map.of(), 0, null, 0, 0); // First only start m_1 AssignmentPlan assignmentPlan = new ZoneAwareAssignmentPlanner(Map.of(List.of(), List.of(node1, node2)), List.of(deployment1)) diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchResultProcessorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchResultProcessorTests.java index 860da3140f4fe..7eb9d7e940dda 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchResultProcessorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/inference/pytorch/process/PyTorchResultProcessorTests.java @@ -276,10 +276,15 @@ public void testsTimeDependentStats() { var timeSupplier = new TimeSupplier(resultTimestamps); var processor = new PyTorchResultProcessor("foo", s -> {}, timeSupplier); + for (int i = 0; i < 10; i++) { + processor.registerRequest("foo" + i, ActionListener.noop()); + } + // 1st period - processor.processInferenceResult(wrapInferenceResult("foo", false, 200L)); - processor.processInferenceResult(wrapInferenceResult("foo", false, 200L)); - processor.processInferenceResult(wrapInferenceResult("foo", false, 200L)); + processor.processInferenceResult(wrapInferenceResult("foo0", false, 200L)); + processor.processInferenceResult(wrapInferenceResult("foo1", false, 200L)); + processor.processInferenceResult(wrapInferenceResult("foo2", false, 200L)); + // first call has no results as is in the same period var stats = processor.getResultStats(); assertThat(stats.recentStats().requestsProcessed(), equalTo(0L)); @@ -293,7 +298,7 @@ public void testsTimeDependentStats() { assertThat(stats.peakThroughput(), equalTo(3L)); // 2nd period - processor.processInferenceResult(wrapInferenceResult("foo", false, 100L)); + processor.processInferenceResult(wrapInferenceResult("foo3", false, 100L)); stats = processor.getResultStats(); assertNotNull(stats.recentStats()); assertThat(stats.recentStats().requestsProcessed(), equalTo(1L)); @@ -305,7 +310,7 @@ public void testsTimeDependentStats() { assertThat(stats.recentStats().requestsProcessed(), equalTo(0L)); // 4th period - processor.processInferenceResult(wrapInferenceResult("foo", false, 300L)); + processor.processInferenceResult(wrapInferenceResult("foo4", false, 300L)); stats = processor.getResultStats(); assertNotNull(stats.recentStats()); assertThat(stats.recentStats().requestsProcessed(), equalTo(1L)); @@ -313,8 +318,8 @@ public void testsTimeDependentStats() { assertThat(stats.lastUsed(), equalTo(Instant.ofEpochMilli(resultTimestamps[9]))); // 7th period - processor.processInferenceResult(wrapInferenceResult("foo", false, 410L)); - processor.processInferenceResult(wrapInferenceResult("foo", false, 390L)); + processor.processInferenceResult(wrapInferenceResult("foo5", false, 410L)); + processor.processInferenceResult(wrapInferenceResult("foo6", false, 390L)); stats = processor.getResultStats(); assertThat(stats.recentStats().requestsProcessed(), equalTo(0L)); assertThat(stats.recentStats().avgInferenceTime(), nullValue()); @@ -325,9 +330,9 @@ public void testsTimeDependentStats() { assertThat(stats.lastUsed(), equalTo(Instant.ofEpochMilli(resultTimestamps[12]))); // 8th period - processor.processInferenceResult(wrapInferenceResult("foo", false, 510L)); - processor.processInferenceResult(wrapInferenceResult("foo", false, 500L)); - processor.processInferenceResult(wrapInferenceResult("foo", false, 490L)); + processor.processInferenceResult(wrapInferenceResult("foo7", false, 510L)); + processor.processInferenceResult(wrapInferenceResult("foo8", false, 500L)); + processor.processInferenceResult(wrapInferenceResult("foo9", false, 490L)); stats = processor.getResultStats(); assertNotNull(stats.recentStats()); assertThat(stats.recentStats().requestsProcessed(), equalTo(3L)); diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/NodeLoadDetectorTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/NodeLoadDetectorTests.java index fef9b07429702..c3ad54427f70c 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/NodeLoadDetectorTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/job/NodeLoadDetectorTests.java @@ -133,7 +133,8 @@ public void testNodeLoadDetection() { Priority.NORMAL, 0L, 0L - ) + ), + null ) .addRoutingEntry("_node_id4", new RoutingInfo(1, 1, RoutingState.STARTING, "")) .addRoutingEntry("_node_id2", new RoutingInfo(1, 1, RoutingState.FAILED, "test")) diff --git a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/rest/inference/RestUpdateTrainedModelDeploymentActionTests.java b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/rest/inference/RestUpdateTrainedModelDeploymentActionTests.java index 2bb10d66d3d58..cce6b284a524d 100644 --- a/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/rest/inference/RestUpdateTrainedModelDeploymentActionTests.java +++ b/x-pack/plugin/ml/src/test/java/org/elasticsearch/xpack/ml/rest/inference/RestUpdateTrainedModelDeploymentActionTests.java @@ -30,7 +30,7 @@ public void testNumberOfAllocationInParam() { assertThat(actionRequest, instanceOf(UpdateTrainedModelDeploymentAction.Request.class)); var request = (UpdateTrainedModelDeploymentAction.Request) actionRequest; - assertEquals(request.getNumberOfAllocations(), 5); + assertEquals(request.getNumberOfAllocations().intValue(), 5); executeCalled.set(true); return mock(CreateTrainedModelAssignmentAction.Response.class); @@ -53,7 +53,7 @@ public void testNumberOfAllocationInBody() { assertThat(actionRequest, instanceOf(UpdateTrainedModelDeploymentAction.Request.class)); var request = (UpdateTrainedModelDeploymentAction.Request) actionRequest; - assertEquals(request.getNumberOfAllocations(), 6); + assertEquals(request.getNumberOfAllocations().intValue(), 6); executeCalled.set(true); return mock(CreateTrainedModelAssignmentAction.Response.class); diff --git a/x-pack/plugin/monitoring/src/main/java/org/elasticsearch/xpack/monitoring/MonitoringTemplateRegistry.java b/x-pack/plugin/monitoring/src/main/java/org/elasticsearch/xpack/monitoring/MonitoringTemplateRegistry.java index 12eeaf8732235..e0433ea6fdd71 100644 --- a/x-pack/plugin/monitoring/src/main/java/org/elasticsearch/xpack/monitoring/MonitoringTemplateRegistry.java +++ b/x-pack/plugin/monitoring/src/main/java/org/elasticsearch/xpack/monitoring/MonitoringTemplateRegistry.java @@ -77,7 +77,7 @@ public class MonitoringTemplateRegistry extends IndexTemplateRegistry { * writes monitoring data in ECS format as of 8.0. These templates define the ECS schema as well as alias fields for the old monitoring * mappings that point to the corresponding ECS fields. */ - public static final int STACK_MONITORING_REGISTRY_VERSION = 8_00_00_99 + 17; + public static final int STACK_MONITORING_REGISTRY_VERSION = 8_00_00_99 + 18; private static final String STACK_MONITORING_REGISTRY_VERSION_VARIABLE = "xpack.stack.monitoring.template.release.version"; private static final String STACK_TEMPLATE_VERSION = "8"; private static final String STACK_TEMPLATE_VERSION_VARIABLE = "xpack.stack.monitoring.template.version"; diff --git a/x-pack/plugin/rollup/src/main/java/org/elasticsearch/xpack/rollup/action/TransportRollupSearchAction.java b/x-pack/plugin/rollup/src/main/java/org/elasticsearch/xpack/rollup/action/TransportRollupSearchAction.java index 6bd29ddb52301..4108b0f6d3c83 100644 --- a/x-pack/plugin/rollup/src/main/java/org/elasticsearch/xpack/rollup/action/TransportRollupSearchAction.java +++ b/x-pack/plugin/rollup/src/main/java/org/elasticsearch/xpack/rollup/action/TransportRollupSearchAction.java @@ -128,7 +128,8 @@ public AggregationReduceContext forPartialReduction() { bigArrays, scriptService, ((CancellableTask) task)::isCancelled, - request.source().aggregations() + request.source().aggregations(), + b -> {} ); } diff --git a/x-pack/plugin/rollup/src/test/java/org/elasticsearch/xpack/rollup/job/RollupIndexerStateTests.java b/x-pack/plugin/rollup/src/test/java/org/elasticsearch/xpack/rollup/job/RollupIndexerStateTests.java index 105711c4057a6..7a947fcb5ce02 100644 --- a/x-pack/plugin/rollup/src/test/java/org/elasticsearch/xpack/rollup/job/RollupIndexerStateTests.java +++ b/x-pack/plugin/rollup/src/test/java/org/elasticsearch/xpack/rollup/job/RollupIndexerStateTests.java @@ -556,7 +556,7 @@ public void testMultipleJobTriggering() throws Exception { assertThat(indexer.getState(), equalTo(IndexerState.STARTED)); // This may take more than one attempt due to a cleanup/transition phase // that happens after state change to STARTED (`isJobFinishing`). - assertBusy(() -> indexer.maybeTriggerAsyncJob(System.currentTimeMillis())); + assertBusy(() -> assertTrue(indexer.maybeTriggerAsyncJob(System.currentTimeMillis()))); assertThat(indexer.getState(), equalTo(IndexerState.INDEXING)); assertFalse(indexer.maybeTriggerAsyncJob(System.currentTimeMillis())); assertThat(indexer.getState(), equalTo(IndexerState.INDEXING)); @@ -566,7 +566,7 @@ public void testMultipleJobTriggering() throws Exception { assertThat(indexer.getStats().getNumPages(), equalTo((long) i + 1)); } final CountDownLatch latch = indexer.newLatch(); - assertBusy(() -> indexer.maybeTriggerAsyncJob(System.currentTimeMillis())); + assertBusy(() -> assertTrue(indexer.maybeTriggerAsyncJob(System.currentTimeMillis()))); assertThat(indexer.stop(), equalTo(IndexerState.STOPPING)); assertThat(indexer.getState(), Matchers.either(Matchers.is(IndexerState.STOPPING)).or(Matchers.is(IndexerState.STOPPED))); latch.countDown(); diff --git a/x-pack/plugin/security/cli/src/main/java/org/elasticsearch/xpack/security/cli/AutoConfigureNode.java b/x-pack/plugin/security/cli/src/main/java/org/elasticsearch/xpack/security/cli/AutoConfigureNode.java index 29828fba085d8..3994fb50c7fc6 100644 --- a/x-pack/plugin/security/cli/src/main/java/org/elasticsearch/xpack/security/cli/AutoConfigureNode.java +++ b/x-pack/plugin/security/cli/src/main/java/org/elasticsearch/xpack/security/cli/AutoConfigureNode.java @@ -114,7 +114,8 @@ */ public class AutoConfigureNode extends EnvironmentAwareCommand { - public static final String AUTO_CONFIG_ALT_DN = "CN=Elasticsearch security auto-configuration HTTP CA"; + public static final String AUTO_CONFIG_HTTP_ALT_DN = "CN=Elasticsearch security auto-configuration HTTP CA"; + public static final String AUTO_CONFIG_TRANSPORT_ALT_DN = "CN=Elasticsearch security auto-configuration transport CA"; // the transport keystore is also used as a truststore private static final String SIGNATURE_ALGORITHM = "SHA256withRSA"; private static final String TRANSPORT_AUTOGENERATED_KEYSTORE_NAME = "transport"; @@ -272,7 +273,8 @@ public void execute(Terminal terminal, OptionSet options, Environment env, Proce final List transportAddresses; final String cnValue = NODE_NAME_SETTING.exists(env.settings()) ? NODE_NAME_SETTING.get(env.settings()) : System.getenv("HOSTNAME"); final X500Principal certificatePrincipal = new X500Principal("CN=" + cnValue); - final X500Principal caPrincipal = new X500Principal(AUTO_CONFIG_ALT_DN); + final X500Principal httpCaPrincipal = new X500Principal(AUTO_CONFIG_HTTP_ALT_DN); + final X500Principal transportCaPrincipal = new X500Principal(AUTO_CONFIG_TRANSPORT_ALT_DN); if (inEnrollmentMode) { // this is an enrolling node, get HTTP CA key/certificate and transport layer key/certificate from another node @@ -402,7 +404,7 @@ public void execute(Terminal terminal, OptionSet options, Environment env, Proce final KeyPair transportCaKeyPair = CertGenUtils.generateKeyPair(TRANSPORT_CA_KEY_SIZE); final PrivateKey transportCaKey = transportCaKeyPair.getPrivate(); transportCaCert = CertGenUtils.generateSignedCertificate( - caPrincipal, + transportCaPrincipal, null, transportCaKeyPair, null, @@ -429,7 +431,7 @@ public void execute(Terminal terminal, OptionSet options, Environment env, Proce httpCaKey = httpCaKeyPair.getPrivate(); // self-signed CA httpCaCert = CertGenUtils.generateSignedCertificate( - caPrincipal, + httpCaPrincipal, null, httpCaKeyPair, null, diff --git a/x-pack/plugin/security/cli/src/test/java/org/elasticsearch/xpack/security/cli/AutoConfigureNodeTests.java b/x-pack/plugin/security/cli/src/test/java/org/elasticsearch/xpack/security/cli/AutoConfigureNodeTests.java index d1dbe9d037756..129d85d0818b2 100644 --- a/x-pack/plugin/security/cli/src/test/java/org/elasticsearch/xpack/security/cli/AutoConfigureNodeTests.java +++ b/x-pack/plugin/security/cli/src/test/java/org/elasticsearch/xpack/security/cli/AutoConfigureNodeTests.java @@ -18,6 +18,7 @@ import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.ssl.KeyStoreUtil; import org.elasticsearch.core.IOUtils; +import org.elasticsearch.core.Tuple; import org.elasticsearch.env.Environment; import org.elasticsearch.env.TestEnvironment; import org.elasticsearch.http.HttpTransportSettings; @@ -32,6 +33,8 @@ import java.util.List; import static java.nio.file.StandardOpenOption.CREATE_NEW; +import static org.elasticsearch.xpack.security.cli.AutoConfigureNode.AUTO_CONFIG_HTTP_ALT_DN; +import static org.elasticsearch.xpack.security.cli.AutoConfigureNode.AUTO_CONFIG_TRANSPORT_ALT_DN; import static org.elasticsearch.xpack.security.cli.AutoConfigureNode.anyRemoteHostNodeAddress; import static org.elasticsearch.xpack.security.cli.AutoConfigureNode.removePreviousAutoconfiguration; import static org.hamcrest.Matchers.equalTo; @@ -131,6 +134,21 @@ public void testRemovePreviousAutoconfigurationRetainsUserAdded() throws Excepti assertEquals(file1, removePreviousAutoconfiguration(file2)); } + public void testSubjectAndIssuerForGeneratedCertificates() throws Exception { + // test no publish settings + Path tempDir = createTempDir(); + try { + Files.createDirectory(tempDir.resolve("config")); + // empty yml file, it just has to exist + Files.write(tempDir.resolve("config").resolve("elasticsearch.yml"), List.of(), CREATE_NEW); + Tuple generatedCerts = runAutoConfigAndReturnCertificates(tempDir, Settings.EMPTY); + assertThat(checkSubjectAndIssuerDN(generatedCerts.v1(), "CN=dummy.test.hostname", AUTO_CONFIG_HTTP_ALT_DN), is(true)); + assertThat(checkSubjectAndIssuerDN(generatedCerts.v2(), "CN=dummy.test.hostname", AUTO_CONFIG_TRANSPORT_ALT_DN), is(true)); + } finally { + deleteDirectory(tempDir); + } + } + public void testGeneratedHTTPCertificateSANs() throws Exception { // test no publish settings Path tempDir = createTempDir(); @@ -262,6 +280,14 @@ private boolean checkGeneralNameSan(X509Certificate certificate, String generalN return false; } + private boolean checkSubjectAndIssuerDN(X509Certificate certificate, String subjectName, String issuerName) throws Exception { + if (certificate.getSubjectX500Principal().getName().equals(subjectName) + && certificate.getIssuerX500Principal().getName().equals(issuerName)) { + return true; + } + return false; + } + private void verifyExtendedKeyUsage(X509Certificate httpCertificate) throws Exception { List extendedKeyUsage = httpCertificate.getExtendedKeyUsage(); assertEquals("Only one extended key usage expected for HTTP certificate.", 1, extendedKeyUsage.size()); @@ -270,6 +296,11 @@ private void verifyExtendedKeyUsage(X509Certificate httpCertificate) throws Exce } private X509Certificate runAutoConfigAndReturnHTTPCertificate(Path configDir, Settings settings) throws Exception { + Tuple generatedCertificates = runAutoConfigAndReturnCertificates(configDir, settings); + return generatedCertificates.v1(); + } + + private Tuple runAutoConfigAndReturnCertificates(Path configDir, Settings settings) throws Exception { final Environment env = TestEnvironment.newEnvironment(Settings.builder().put("path.home", configDir).put(settings).build()); // runs the command to auto-generate the config files and the keystore new AutoConfigureNode(false).execute(MockTerminal.create(), new OptionParser().parse(), env, null); @@ -278,16 +309,28 @@ private X509Certificate runAutoConfigAndReturnHTTPCertificate(Path configDir, Se nodeKeystore.decrypt(new char[0]); // the keystore is always bootstrapped with an empty password SecureString httpKeystorePassword = nodeKeystore.getString("xpack.security.http.ssl.keystore.secure_password"); + SecureString transportKeystorePassword = nodeKeystore.getString("xpack.security.transport.ssl.keystore.secure_password"); final Settings newSettings = Settings.builder().loadFromPath(env.configFile().resolve("elasticsearch.yml")).build(); final String httpKeystorePath = newSettings.get("xpack.security.http.ssl.keystore.path"); + final String transportKeystorePath = newSettings.get("xpack.security.transport.ssl.keystore.path"); KeyStore httpKeystore = KeyStoreUtil.readKeyStore( configDir.resolve("config").resolve(httpKeystorePath), "PKCS12", httpKeystorePassword.getChars() ); - return (X509Certificate) httpKeystore.getCertificate("http"); + + KeyStore transportKeystore = KeyStoreUtil.readKeyStore( + configDir.resolve("config").resolve(transportKeystorePath), + "PKCS12", + transportKeystorePassword.getChars() + ); + + X509Certificate httpCertificate = (X509Certificate) httpKeystore.getCertificate("http"); + X509Certificate transportCertificate = (X509Certificate) transportKeystore.getCertificate("transport"); + + return new Tuple<>(httpCertificate, transportCertificate); } private void deleteDirectory(Path directory) throws IOException { diff --git a/x-pack/plugin/security/licenses/nimbus-jose-jwt-LICENSE.txt b/x-pack/plugin/security/licenses/nimbus-jose-jwt-LICENSE.txt new file mode 100644 index 0000000000000..d645695673349 --- /dev/null +++ b/x-pack/plugin/security/licenses/nimbus-jose-jwt-LICENSE.txt @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/x-pack/plugin/core/licenses/nimbus-jose-jwt-NOTICE.txt b/x-pack/plugin/security/licenses/nimbus-jose-jwt-NOTICE.txt similarity index 100% rename from x-pack/plugin/core/licenses/nimbus-jose-jwt-NOTICE.txt rename to x-pack/plugin/security/licenses/nimbus-jose-jwt-NOTICE.txt diff --git a/x-pack/plugin/security/qa/security-trial/src/javaRestTest/java/org/elasticsearch/xpack/security/role/BulkPutRoleRestIT.java b/x-pack/plugin/security/qa/security-trial/src/javaRestTest/java/org/elasticsearch/xpack/security/role/BulkPutRoleRestIT.java index 0297abad7a508..88b952f33394e 100644 --- a/x-pack/plugin/security/qa/security-trial/src/javaRestTest/java/org/elasticsearch/xpack/security/role/BulkPutRoleRestIT.java +++ b/x-pack/plugin/security/qa/security-trial/src/javaRestTest/java/org/elasticsearch/xpack/security/role/BulkPutRoleRestIT.java @@ -181,15 +181,74 @@ public void testPutNoValidRoles() throws Exception { public void testBulkUpdates() throws Exception { String request = """ {"roles": {"test1": {"cluster": ["all"],"indices": [{"names": ["*"],"privileges": ["all"]}]}, "test2": - {"cluster": ["all"],"indices": [{"names": ["*"],"privileges": ["read"]}]}, "test3": - {"cluster": ["all"],"indices": [{"names": ["*"],"privileges": ["write"]}]}}}"""; - + {"cluster": ["all"],"indices": [{"names": ["*"],"privileges": ["read"]}], "description": "something"}, "test3": + {"cluster": ["all"],"indices": [{"names": ["*"],"privileges": ["write"]}], "remote_indices":[{"names":["logs-*"], + "privileges":["read"],"clusters":["my_cluster*","other_cluster"]}]}}}"""; { Map responseMap = upsertRoles(request); assertThat(responseMap, not(hasKey("errors"))); List> items = (List>) responseMap.get("created"); assertEquals(3, items.size()); + + fetchRoleAndAssertEqualsExpected( + "test1", + new RoleDescriptor( + "test1", + new String[] { "all" }, + new RoleDescriptor.IndicesPrivileges[] { + RoleDescriptor.IndicesPrivileges.builder().indices("*").privileges("all").build() }, + null, + null, + null, + null, + null, + null, + null, + null, + null + ) + ); + fetchRoleAndAssertEqualsExpected( + "test2", + new RoleDescriptor( + "test2", + new String[] { "all" }, + new RoleDescriptor.IndicesPrivileges[] { + RoleDescriptor.IndicesPrivileges.builder().indices("*").privileges("read").build() }, + null, + null, + null, + null, + null, + null, + null, + null, + "something" + ) + ); + fetchRoleAndAssertEqualsExpected( + "test3", + new RoleDescriptor( + "test3", + new String[] { "all" }, + new RoleDescriptor.IndicesPrivileges[] { + RoleDescriptor.IndicesPrivileges.builder().indices("*").privileges("write").build() }, + null, + null, + null, + null, + null, + new RoleDescriptor.RemoteIndicesPrivileges[] { + RoleDescriptor.RemoteIndicesPrivileges.builder("my_cluster*", "other_cluster") + .indices("logs-*") + .privileges("read") + .build() }, + null, + null, + null + ) + ); } { Map responseMap = upsertRoles(request); @@ -200,7 +259,7 @@ public void testBulkUpdates() throws Exception { } { request = """ - {"roles": {"test1": {"cluster": ["all"],"indices": [{"names": ["*"],"privileges": ["read"]}]}, "test2": + {"roles": {"test1": {}, "test2": {"cluster": ["all"],"indices": [{"names": ["*"],"privileges": ["all"]}]}, "test3": {"cluster": ["all"],"indices": [{"names": ["*"],"privileges": ["all"]}]}}}"""; @@ -208,6 +267,49 @@ public void testBulkUpdates() throws Exception { assertThat(responseMap, not(hasKey("errors"))); List> items = (List>) responseMap.get("updated"); assertEquals(3, items.size()); + + assertThat(responseMap, not(hasKey("errors"))); + + fetchRoleAndAssertEqualsExpected( + "test1", + new RoleDescriptor("test1", null, null, null, null, null, null, null, null, null, null, null) + ); + fetchRoleAndAssertEqualsExpected( + "test2", + new RoleDescriptor( + "test2", + new String[] { "all" }, + new RoleDescriptor.IndicesPrivileges[] { + RoleDescriptor.IndicesPrivileges.builder().indices("*").privileges("all").build() }, + null, + null, + null, + null, + null, + null, + null, + null, + null + ) + ); + fetchRoleAndAssertEqualsExpected( + "test3", + new RoleDescriptor( + "test3", + new String[] { "all" }, + new RoleDescriptor.IndicesPrivileges[] { + RoleDescriptor.IndicesPrivileges.builder().indices("*").privileges("all").build() }, + null, + null, + null, + null, + null, + null, + null, + null, + null + ) + ); } } } diff --git a/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmSingleNodeTests.java b/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmSingleNodeTests.java index 2ced54a513146..435706dce7019 100644 --- a/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmSingleNodeTests.java +++ b/x-pack/plugin/security/src/internalClusterTest/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmSingleNodeTests.java @@ -52,7 +52,6 @@ import org.elasticsearch.xpack.core.security.action.user.AuthenticateResponse; import org.elasticsearch.xpack.core.security.authc.Authentication; import org.elasticsearch.xpack.core.security.authc.Realm; -import org.elasticsearch.xpack.core.security.authc.jwt.JwtAuthenticationToken; import org.elasticsearch.xpack.core.security.authc.jwt.JwtRealmSettings; import org.elasticsearch.xpack.security.LocalStateSecurity; import org.elasticsearch.xpack.security.Security; diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/TransportGrantAction.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/TransportGrantAction.java index 667b513555594..fffcb476abaa4 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/TransportGrantAction.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/action/TransportGrantAction.java @@ -7,24 +7,33 @@ package org.elasticsearch.xpack.security.action; +import org.elasticsearch.ElasticsearchSecurityException; import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionResponse; import org.elasticsearch.action.support.ActionFilters; import org.elasticsearch.action.support.TransportAction; +import org.elasticsearch.common.settings.SecureString; import org.elasticsearch.common.util.concurrent.ThreadContext; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.tasks.Task; import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.core.security.action.Grant; import org.elasticsearch.xpack.core.security.action.GrantRequest; import org.elasticsearch.xpack.core.security.action.user.AuthenticateAction; import org.elasticsearch.xpack.core.security.action.user.AuthenticateRequest; import org.elasticsearch.xpack.core.security.authc.Authentication; import org.elasticsearch.xpack.core.security.authc.AuthenticationServiceField; import org.elasticsearch.xpack.core.security.authc.AuthenticationToken; +import org.elasticsearch.xpack.core.security.authc.support.BearerToken; +import org.elasticsearch.xpack.core.security.authc.support.UsernamePasswordToken; import org.elasticsearch.xpack.security.authc.AuthenticationService; +import org.elasticsearch.xpack.security.authc.jwt.JwtAuthenticationToken; import org.elasticsearch.xpack.security.authz.AuthorizationService; +import static org.elasticsearch.xpack.core.security.action.Grant.ACCESS_TOKEN_GRANT_TYPE; +import static org.elasticsearch.xpack.core.security.action.Grant.PASSWORD_GRANT_TYPE; + public abstract class TransportGrantAction extends TransportAction< Request, Response> { @@ -50,7 +59,7 @@ public TransportGrantAction( @Override public final void doExecute(Task task, Request request, ActionListener listener) { try (ThreadContext.StoredContext ignore = threadContext.stashContext()) { - final AuthenticationToken authenticationToken = request.getGrant().getAuthenticationToken(); + final AuthenticationToken authenticationToken = getAuthenticationToken(request.getGrant()); assert authenticationToken != null : "authentication token must not be null"; final String runAsUsername = request.getGrant().getRunAsUsername(); @@ -109,4 +118,30 @@ protected abstract void doExecuteWithGrantAuthentication( Authentication authentication, ActionListener listener ); + + public static AuthenticationToken getAuthenticationToken(Grant grant) { + assert grant.validate(null) == null : "grant is invalid"; + return switch (grant.getType()) { + case PASSWORD_GRANT_TYPE -> new UsernamePasswordToken(grant.getUsername(), grant.getPassword()); + case ACCESS_TOKEN_GRANT_TYPE -> { + SecureString clientAuthentication = grant.getClientAuthentication() != null + ? grant.getClientAuthentication().value() + : null; + AuthenticationToken token = JwtAuthenticationToken.tryParseJwt(grant.getAccessToken(), clientAuthentication); + if (token != null) { + yield token; + } + if (clientAuthentication != null) { + clientAuthentication.close(); + throw new ElasticsearchSecurityException( + "[client_authentication] not supported with the supplied access_token type", + RestStatus.BAD_REQUEST + ); + } + // here we effectively assume it's an ES access token (from the {@code TokenService}) + yield new BearerToken(grant.getAccessToken()); + } + default -> throw new ElasticsearchSecurityException("the grant type [{}] is not supported", grant.getType()); + }; + } } diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/audit/logfile/LoggingAuditTrail.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/audit/logfile/LoggingAuditTrail.java index 01104806c4a1c..bc5cc4a5e6b3f 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/audit/logfile/LoggingAuditTrail.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/audit/logfile/LoggingAuditTrail.java @@ -44,6 +44,7 @@ import org.elasticsearch.xcontent.json.JsonStringEncoder; import org.elasticsearch.xcontent.json.JsonXContent; import org.elasticsearch.xpack.core.security.SecurityContext; +import org.elasticsearch.xpack.core.security.action.ActionTypes; import org.elasticsearch.xpack.core.security.action.Grant; import org.elasticsearch.xpack.core.security.action.apikey.AbstractCreateApiKeyRequest; import org.elasticsearch.xpack.core.security.action.apikey.BaseSingleUpdateApiKeyRequest; @@ -72,6 +73,8 @@ import org.elasticsearch.xpack.core.security.action.profile.SetProfileEnabledRequest; import org.elasticsearch.xpack.core.security.action.profile.UpdateProfileDataAction; import org.elasticsearch.xpack.core.security.action.profile.UpdateProfileDataRequest; +import org.elasticsearch.xpack.core.security.action.role.BulkDeleteRolesRequest; +import org.elasticsearch.xpack.core.security.action.role.BulkPutRolesRequest; import org.elasticsearch.xpack.core.security.action.role.DeleteRoleAction; import org.elasticsearch.xpack.core.security.action.role.DeleteRoleRequest; import org.elasticsearch.xpack.core.security.action.role.PutRoleAction; @@ -291,6 +294,8 @@ public class LoggingAuditTrail implements AuditTrail, ClusterStateListener { PutUserAction.NAME, PutRoleAction.NAME, PutRoleMappingAction.NAME, + ActionTypes.BULK_PUT_ROLES.name(), + ActionTypes.BULK_DELETE_ROLES.name(), TransportSetEnabledAction.TYPE.name(), TransportChangePasswordAction.TYPE.name(), CreateApiKeyAction.NAME, @@ -731,6 +736,11 @@ public void accessGranted( } else if (msg instanceof PutRoleRequest) { assert PutRoleAction.NAME.equals(action); securityChangeLogEntryBuilder(requestId).withRequestBody((PutRoleRequest) msg).build(); + } else if (msg instanceof BulkPutRolesRequest bulkPutRolesRequest) { + assert ActionTypes.BULK_PUT_ROLES.name().equals(action); + for (RoleDescriptor roleDescriptor : bulkPutRolesRequest.getRoles()) { + securityChangeLogEntryBuilder(requestId).withRequestBody(roleDescriptor.getName(), roleDescriptor).build(); + } } else if (msg instanceof PutRoleMappingRequest) { assert PutRoleMappingAction.NAME.equals(action); securityChangeLogEntryBuilder(requestId).withRequestBody((PutRoleMappingRequest) msg).build(); @@ -755,6 +765,11 @@ public void accessGranted( } else if (msg instanceof DeleteRoleRequest) { assert DeleteRoleAction.NAME.equals(action); securityChangeLogEntryBuilder(requestId).withRequestBody((DeleteRoleRequest) msg).build(); + } else if (msg instanceof BulkDeleteRolesRequest bulkDeleteRolesRequest) { + assert ActionTypes.BULK_DELETE_ROLES.name().equals(action); + for (String roleName : bulkDeleteRolesRequest.getRoleNames()) { + securityChangeLogEntryBuilder(requestId).withDeleteRole(roleName).build(); + } } else if (msg instanceof DeleteRoleMappingRequest) { assert DeleteRoleMappingAction.NAME.equals(action); securityChangeLogEntryBuilder(requestId).withRequestBody((DeleteRoleMappingRequest) msg).build(); @@ -1160,15 +1175,19 @@ LogEntryBuilder withRequestBody(ChangePasswordRequest changePasswordRequest) thr } LogEntryBuilder withRequestBody(PutRoleRequest putRoleRequest) throws IOException { + return withRequestBody(putRoleRequest.name(), putRoleRequest.roleDescriptor()); + } + + LogEntryBuilder withRequestBody(String roleName, RoleDescriptor roleDescriptor) throws IOException { logEntry.with(EVENT_ACTION_FIELD_NAME, "put_role"); XContentBuilder builder = JsonXContent.contentBuilder().humanReadable(true); builder.startObject() .startObject("role") - .field("name", putRoleRequest.name()) + .field("name", roleName) // the "role_descriptor" nested structure, where the "name" is left out, is closer to the event structure // for creating API Keys .field("role_descriptor"); - withRoleDescriptor(builder, putRoleRequest.roleDescriptor()); + withRoleDescriptor(builder, roleDescriptor); builder.endObject() // role .endObject(); logEntry.with(PUT_CONFIG_FIELD_NAME, Strings.toString(builder)); @@ -1350,7 +1369,7 @@ private static void withRoleDescriptor(XContentBuilder builder, RoleDescriptor r withIndicesPrivileges(builder, indicesPrivileges); } builder.endArray(); - // the toXContent method of the {@code RoleDescriptor.ApplicationResourcePrivileges) does a good job + // the toXContent method of the {@code RoleDescriptor.ApplicationResourcePrivileges} does a good job builder.xContentList(RoleDescriptor.Fields.APPLICATIONS.getPreferredName(), roleDescriptor.getApplicationPrivileges()); builder.array(RoleDescriptor.Fields.RUN_AS.getPreferredName(), roleDescriptor.getRunAs()); if (roleDescriptor.getMetadata() != null && false == roleDescriptor.getMetadata().isEmpty()) { @@ -1401,15 +1420,7 @@ LogEntryBuilder withRequestBody(DeleteUserRequest deleteUserRequest) throws IOEx } LogEntryBuilder withRequestBody(DeleteRoleRequest deleteRoleRequest) throws IOException { - logEntry.with(EVENT_ACTION_FIELD_NAME, "delete_role"); - XContentBuilder builder = JsonXContent.contentBuilder().humanReadable(true); - builder.startObject() - .startObject("role") - .field("name", deleteRoleRequest.name()) - .endObject() // role - .endObject(); - logEntry.with(DELETE_CONFIG_FIELD_NAME, Strings.toString(builder)); - return this; + return withDeleteRole(deleteRoleRequest.name()); } LogEntryBuilder withRequestBody(DeleteRoleMappingRequest deleteRoleMappingRequest) throws IOException { @@ -1532,6 +1543,18 @@ LogEntryBuilder withRequestBody(SetProfileEnabledRequest setProfileEnabledReques return this; } + LogEntryBuilder withDeleteRole(String roleName) throws IOException { + logEntry.with(EVENT_ACTION_FIELD_NAME, "delete_role"); + XContentBuilder builder = JsonXContent.contentBuilder().humanReadable(true); + builder.startObject() + .startObject("role") + .field("name", roleName) + .endObject() // role + .endObject(); + logEntry.with(DELETE_CONFIG_FIELD_NAME, Strings.toString(builder)); + return this; + } + static void withGrant(XContentBuilder builder, Grant grant) throws IOException { builder.startObject("grant").field("type", grant.getType()); if (grant.getUsername() != null) { diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwkSetLoader.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwkSetLoader.java index 0266fc7488e29..063cc85ea0187 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwkSetLoader.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwkSetLoader.java @@ -22,7 +22,6 @@ import org.elasticsearch.xpack.core.security.authc.RealmConfig; import org.elasticsearch.xpack.core.security.authc.RealmSettings; import org.elasticsearch.xpack.core.security.authc.jwt.JwtRealmSettings; -import org.elasticsearch.xpack.core.security.authc.jwt.JwtUtil; import org.elasticsearch.xpack.core.ssl.SSLService; import java.io.IOException; diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwkValidateUtil.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwkValidateUtil.java index cc07b7dfa8381..89391f91a2731 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwkValidateUtil.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwkValidateUtil.java @@ -24,7 +24,6 @@ import org.elasticsearch.common.Strings; import org.elasticsearch.common.settings.SettingsException; import org.elasticsearch.xpack.core.security.authc.jwt.JwtRealmSettings; -import org.elasticsearch.xpack.core.security.authc.jwt.JwtUtil; import java.nio.charset.StandardCharsets; import java.security.PublicKey; diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authc/jwt/JwtAuthenticationToken.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtAuthenticationToken.java similarity index 98% rename from x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authc/jwt/JwtAuthenticationToken.java rename to x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtAuthenticationToken.java index ebfaae72b9df2..cfef9aed5967a 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authc/jwt/JwtAuthenticationToken.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtAuthenticationToken.java @@ -4,7 +4,7 @@ * 2.0; you may not use this file except in compliance with the Elastic License * 2.0. */ -package org.elasticsearch.xpack.core.security.authc.jwt; +package org.elasticsearch.xpack.security.authc.jwt; import com.nimbusds.jwt.JWTClaimsSet; import com.nimbusds.jwt.SignedJWT; diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtAuthenticator.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtAuthenticator.java index b06aba1c9d87a..2345add07ba51 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtAuthenticator.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtAuthenticator.java @@ -19,7 +19,6 @@ import org.elasticsearch.core.Releasable; import org.elasticsearch.core.TimeValue; import org.elasticsearch.xpack.core.security.authc.RealmConfig; -import org.elasticsearch.xpack.core.security.authc.jwt.JwtAuthenticationToken; import org.elasticsearch.xpack.core.security.authc.jwt.JwtRealmSettings; import org.elasticsearch.xpack.core.ssl.SSLService; diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealm.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealm.java index 30a7e438e70b0..7613e7b3972af 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealm.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealm.java @@ -31,9 +31,7 @@ import org.elasticsearch.xpack.core.security.authc.Realm; import org.elasticsearch.xpack.core.security.authc.RealmConfig; import org.elasticsearch.xpack.core.security.authc.RealmSettings; -import org.elasticsearch.xpack.core.security.authc.jwt.JwtAuthenticationToken; import org.elasticsearch.xpack.core.security.authc.jwt.JwtRealmSettings; -import org.elasticsearch.xpack.core.security.authc.jwt.JwtUtil; import org.elasticsearch.xpack.core.security.authc.support.CachingRealm; import org.elasticsearch.xpack.core.security.authc.support.UserRoleMapper; import org.elasticsearch.xpack.core.security.support.CacheIteratorHelper; diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtSignatureValidator.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtSignatureValidator.java index e183ee7d73ac2..b1ee1b77998ec 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtSignatureValidator.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtSignatureValidator.java @@ -35,14 +35,13 @@ import org.elasticsearch.xpack.core.security.authc.RealmConfig; import org.elasticsearch.xpack.core.security.authc.RealmSettings; import org.elasticsearch.xpack.core.security.authc.jwt.JwtRealmSettings; -import org.elasticsearch.xpack.core.security.authc.jwt.JwtUtil; import org.elasticsearch.xpack.core.ssl.SSLService; import java.util.Arrays; import java.util.List; import java.util.stream.Stream; -import static org.elasticsearch.xpack.core.security.authc.jwt.JwtUtil.toStringRedactSignature; +import static org.elasticsearch.xpack.security.authc.jwt.JwtUtil.toStringRedactSignature; public interface JwtSignatureValidator extends Releasable { diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authc/jwt/JwtUtil.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtUtil.java similarity index 99% rename from x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authc/jwt/JwtUtil.java rename to x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtUtil.java index d70b76f8bc574..928ecd7fa265d 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/authc/jwt/JwtUtil.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/jwt/JwtUtil.java @@ -5,7 +5,7 @@ * 2.0. */ -package org.elasticsearch.xpack.core.security.authc.jwt; +package org.elasticsearch.xpack.security.authc.jwt; import com.nimbusds.jose.JWSObject; import com.nimbusds.jose.jwk.JWK; @@ -47,6 +47,7 @@ import org.elasticsearch.env.Environment; import org.elasticsearch.xpack.core.security.authc.RealmConfig; import org.elasticsearch.xpack.core.security.authc.RealmSettings; +import org.elasticsearch.xpack.core.security.authc.jwt.JwtRealmSettings; import org.elasticsearch.xpack.core.ssl.SSLService; import java.io.InputStream; diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/oidc/OpenIdConnectAuthenticator.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/oidc/OpenIdConnectAuthenticator.java index e637bda19d886..0f34850b861b7 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/oidc/OpenIdConnectAuthenticator.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authc/oidc/OpenIdConnectAuthenticator.java @@ -91,9 +91,9 @@ import org.elasticsearch.watcher.ResourceWatcherService; import org.elasticsearch.xpack.core.security.authc.RealmConfig; import org.elasticsearch.xpack.core.security.authc.RealmSettings; -import org.elasticsearch.xpack.core.security.authc.jwt.JwtUtil; import org.elasticsearch.xpack.core.security.authc.oidc.OpenIdConnectRealmSettings; import org.elasticsearch.xpack.core.ssl.SSLService; +import org.elasticsearch.xpack.security.authc.jwt.JwtUtil; import java.io.IOException; import java.net.URI; diff --git a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authz/store/NativeRolesStore.java b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authz/store/NativeRolesStore.java index adeada6cbf6cf..a2d2b21b489ea 100644 --- a/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authz/store/NativeRolesStore.java +++ b/x-pack/plugin/security/src/main/java/org/elasticsearch/xpack/security/authz/store/NativeRolesStore.java @@ -59,6 +59,7 @@ import org.elasticsearch.xpack.core.security.action.role.RoleDescriptorRequestValidator; import org.elasticsearch.xpack.core.security.authz.RoleDescriptor; import org.elasticsearch.xpack.core.security.authz.RoleDescriptor.IndicesPrivileges; +import org.elasticsearch.xpack.core.security.authz.permission.RemoteClusterPermissions; import org.elasticsearch.xpack.core.security.authz.store.RoleRetrievalResult; import org.elasticsearch.xpack.core.security.authz.support.DLSRoleQueryValidator; import org.elasticsearch.xpack.core.security.support.NativeRealmValidationUtil; @@ -607,16 +608,41 @@ private DeleteRequest createRoleDeleteRequest(final String roleName) { return client.prepareDelete(SECURITY_MAIN_ALIAS, getIdForRole(roleName)).request(); } - private XContentBuilder createRoleXContentBuilder(RoleDescriptor role) throws IOException { + // Package private for testing + XContentBuilder createRoleXContentBuilder(RoleDescriptor role) throws IOException { assert NativeRealmValidationUtil.validateRoleName(role.getName(), false) == null : "Role name was invalid or reserved: " + role.getName(); assert false == role.hasRestriction() : "restriction is not supported for native roles"; - return role.toXContent( - jsonBuilder(), - ToXContent.EMPTY_PARAMS, - true, - featureService.clusterHasFeature(clusterService.state(), SECURITY_ROLES_METADATA_FLATTENED) - ); + + XContentBuilder builder = jsonBuilder().startObject(); + role.innerToXContent(builder, ToXContent.EMPTY_PARAMS, true); + + if (featureService.clusterHasFeature(clusterService.state(), SECURITY_ROLES_METADATA_FLATTENED)) { + builder.field(RoleDescriptor.Fields.METADATA_FLATTENED.getPreferredName(), role.getMetadata()); + } + + // When role descriptor XContent is generated for the security index all empty fields need to have default values to make sure + // existing values are overwritten if not present since the request to update could be an UpdateRequest + // (update provided fields in existing document or create document) or IndexRequest (replace and reindex document) + if (role.hasConfigurableClusterPrivileges() == false) { + builder.startObject(RoleDescriptor.Fields.GLOBAL.getPreferredName()).endObject(); + } + + if (role.hasRemoteIndicesPrivileges() == false) { + builder.field(RoleDescriptor.Fields.REMOTE_INDICES.getPreferredName(), RoleDescriptor.RemoteIndicesPrivileges.NONE); + } + + if (role.hasRemoteClusterPermissions() == false + && clusterService.state().getMinTransportVersion().onOrAfter(ROLE_REMOTE_CLUSTER_PRIVS)) { + builder.array(RoleDescriptor.Fields.REMOTE_CLUSTER.getPreferredName(), RemoteClusterPermissions.NONE); + } + if (role.hasDescription() == false + && clusterService.state().getMinTransportVersion().onOrAfter(TransportVersions.SECURITY_ROLE_DESCRIPTION)) { + builder.field(RoleDescriptor.Fields.DESCRIPTION.getPreferredName(), ""); + } + + builder.endObject(); + return builder; } public void usageStats(ActionListener> listener) { diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/audit/logfile/LoggingAuditTrailTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/audit/logfile/LoggingAuditTrailTests.java index a3292a6ab5f4e..17bad90415e7c 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/audit/logfile/LoggingAuditTrailTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/audit/logfile/LoggingAuditTrailTests.java @@ -47,6 +47,7 @@ import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.XPackSettings; +import org.elasticsearch.xpack.core.security.action.ActionTypes; import org.elasticsearch.xpack.core.security.action.apikey.ApiKeyTests; import org.elasticsearch.xpack.core.security.action.apikey.BulkUpdateApiKeyAction; import org.elasticsearch.xpack.core.security.action.apikey.BulkUpdateApiKeyRequest; @@ -73,6 +74,8 @@ import org.elasticsearch.xpack.core.security.action.profile.SetProfileEnabledRequest; import org.elasticsearch.xpack.core.security.action.profile.UpdateProfileDataAction; import org.elasticsearch.xpack.core.security.action.profile.UpdateProfileDataRequest; +import org.elasticsearch.xpack.core.security.action.role.BulkDeleteRolesRequest; +import org.elasticsearch.xpack.core.security.action.role.BulkPutRolesRequest; import org.elasticsearch.xpack.core.security.action.role.DeleteRoleAction; import org.elasticsearch.xpack.core.security.action.role.DeleteRoleRequest; import org.elasticsearch.xpack.core.security.action.role.PutRoleAction; @@ -772,20 +775,19 @@ public void testSecurityConfigChangeEventFormattingForRoles() throws IOException auditTrail.accessGranted(requestId, authentication, PutRoleAction.NAME, putRoleRequest, authorizationInfo); output = CapturingLogger.output(logger.getName(), Level.INFO); assertThat(output.size(), is(2)); - String generatedPutRoleAuditEventString = output.get(1); - String expectedPutRoleAuditEventString = Strings.format(""" - "put":{"role":{"name":"%s","role_descriptor":%s}}\ - """, putRoleRequest.name(), auditedRolesMap.get(putRoleRequest.name())); - assertThat(generatedPutRoleAuditEventString, containsString(expectedPutRoleAuditEventString)); - generatedPutRoleAuditEventString = generatedPutRoleAuditEventString.replace(", " + expectedPutRoleAuditEventString, ""); - checkedFields = new HashMap<>(commonFields); - checkedFields.remove(LoggingAuditTrail.ORIGIN_ADDRESS_FIELD_NAME); - checkedFields.remove(LoggingAuditTrail.ORIGIN_TYPE_FIELD_NAME); - checkedFields.put("type", "audit"); - checkedFields.put(LoggingAuditTrail.EVENT_TYPE_FIELD_NAME, "security_config_change"); - checkedFields.put(LoggingAuditTrail.EVENT_ACTION_FIELD_NAME, "put_role"); - checkedFields.put(LoggingAuditTrail.REQUEST_ID_FIELD_NAME, requestId); - assertMsg(generatedPutRoleAuditEventString, checkedFields); + assertPutRoleAuditLogLine(putRoleRequest.name(), output.get(1), auditedRolesMap, requestId); + // clear log + CapturingLogger.output(logger.getName(), Level.INFO).clear(); + + BulkPutRolesRequest bulkPutRolesRequest = new BulkPutRolesRequest(allTestRoleDescriptors); + bulkPutRolesRequest.setRefreshPolicy(randomFrom(WriteRequest.RefreshPolicy.values())); + auditTrail.accessGranted(requestId, authentication, ActionTypes.BULK_PUT_ROLES.name(), bulkPutRolesRequest, authorizationInfo); + output = CapturingLogger.output(logger.getName(), Level.INFO); + assertThat(output.size(), is(allTestRoleDescriptors.size() + 1)); + + for (int i = 0; i < allTestRoleDescriptors.size(); i++) { + assertPutRoleAuditLogLine(allTestRoleDescriptors.get(i).getName(), output.get(i + 1), auditedRolesMap, requestId); + } // clear log CapturingLogger.output(logger.getName(), Level.INFO).clear(); @@ -795,25 +797,64 @@ public void testSecurityConfigChangeEventFormattingForRoles() throws IOException auditTrail.accessGranted(requestId, authentication, DeleteRoleAction.NAME, deleteRoleRequest, authorizationInfo); output = CapturingLogger.output(logger.getName(), Level.INFO); assertThat(output.size(), is(2)); - String generatedDeleteRoleAuditEventString = output.get(1); + assertDeleteRoleAuditLogLine(putRoleRequest.name(), output.get(1), requestId); + // clear log + CapturingLogger.output(logger.getName(), Level.INFO).clear(); + + BulkDeleteRolesRequest bulkDeleteRolesRequest = new BulkDeleteRolesRequest( + allTestRoleDescriptors.stream().map(RoleDescriptor::getName).toList() + ); + bulkDeleteRolesRequest.setRefreshPolicy(randomFrom(WriteRequest.RefreshPolicy.values())); + auditTrail.accessGranted( + requestId, + authentication, + ActionTypes.BULK_DELETE_ROLES.name(), + bulkDeleteRolesRequest, + authorizationInfo + ); + output = CapturingLogger.output(logger.getName(), Level.INFO); + assertThat(output.size(), is(allTestRoleDescriptors.size() + 1)); + for (int i = 0; i < allTestRoleDescriptors.size(); i++) { + assertDeleteRoleAuditLogLine(allTestRoleDescriptors.get(i).getName(), output.get(i + 1), requestId); + } + } + + private void assertPutRoleAuditLogLine(String roleName, String logLine, Map expectedLogByRoleName, String requestId) { + String expectedPutRoleAuditEventString = Strings.format(""" + "put":{"role":{"name":"%s","role_descriptor":%s}}\ + """, roleName, expectedLogByRoleName.get(roleName)); + + assertThat(logLine, containsString(expectedPutRoleAuditEventString)); + String reducedLogLine = logLine.replace(", " + expectedPutRoleAuditEventString, ""); + Map checkedFields = new HashMap<>(commonFields); + checkedFields.remove(LoggingAuditTrail.ORIGIN_ADDRESS_FIELD_NAME); + checkedFields.remove(LoggingAuditTrail.ORIGIN_TYPE_FIELD_NAME); + checkedFields.put("type", "audit"); + checkedFields.put(LoggingAuditTrail.EVENT_TYPE_FIELD_NAME, "security_config_change"); + checkedFields.put(LoggingAuditTrail.EVENT_ACTION_FIELD_NAME, "put_role"); + checkedFields.put(LoggingAuditTrail.REQUEST_ID_FIELD_NAME, requestId); + assertMsg(reducedLogLine, checkedFields); + } + + private void assertDeleteRoleAuditLogLine(String roleName, String logLine, String requestId) { StringBuilder deleteRoleStringBuilder = new StringBuilder().append("\"delete\":{\"role\":{\"name\":"); - if (deleteRoleRequest.name() == null) { + if (roleName == null) { deleteRoleStringBuilder.append("null"); } else { - deleteRoleStringBuilder.append("\"").append(deleteRoleRequest.name()).append("\""); + deleteRoleStringBuilder.append("\"").append(roleName).append("\""); } deleteRoleStringBuilder.append("}}"); String expectedDeleteRoleAuditEventString = deleteRoleStringBuilder.toString(); - assertThat(generatedDeleteRoleAuditEventString, containsString(expectedDeleteRoleAuditEventString)); - generatedDeleteRoleAuditEventString = generatedDeleteRoleAuditEventString.replace(", " + expectedDeleteRoleAuditEventString, ""); - checkedFields = new HashMap<>(commonFields); + assertThat(logLine, containsString(expectedDeleteRoleAuditEventString)); + String reducedLogLine = logLine.replace(", " + expectedDeleteRoleAuditEventString, ""); + Map checkedFields = new HashMap<>(commonFields); checkedFields.remove(LoggingAuditTrail.ORIGIN_ADDRESS_FIELD_NAME); checkedFields.remove(LoggingAuditTrail.ORIGIN_TYPE_FIELD_NAME); checkedFields.put("type", "audit"); checkedFields.put(LoggingAuditTrail.EVENT_TYPE_FIELD_NAME, "security_config_change"); checkedFields.put(LoggingAuditTrail.EVENT_ACTION_FIELD_NAME, "delete_role"); checkedFields.put(LoggingAuditTrail.REQUEST_ID_FIELD_NAME, requestId); - assertMsg(generatedDeleteRoleAuditEventString, checkedFields); + assertMsg(reducedLogLine, checkedFields); } public void testSecurityConfigChangeEventForCrossClusterApiKeys() throws IOException { @@ -1975,6 +2016,11 @@ public void testSecurityConfigChangedEventSelection() { Tuple actionAndRequest = randomFrom( new Tuple<>(PutUserAction.NAME, new PutUserRequest()), new Tuple<>(PutRoleAction.NAME, new PutRoleRequest()), + new Tuple<>( + ActionTypes.BULK_PUT_ROLES.name(), + new BulkPutRolesRequest(List.of(new RoleDescriptor(randomAlphaOfLength(20), null, null, null))) + ), + new Tuple<>(ActionTypes.BULK_DELETE_ROLES.name(), new BulkDeleteRolesRequest(List.of(randomAlphaOfLength(20)))), new Tuple<>(PutRoleMappingAction.NAME, new PutRoleMappingRequest()), new Tuple<>(TransportSetEnabledAction.TYPE.name(), new SetEnabledRequest()), new Tuple<>(TransportChangePasswordAction.TYPE.name(), new ChangePasswordRequest()), diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtAuthenticatorTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtAuthenticatorTests.java index 7a44ebae95738..6d4861212e286 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtAuthenticatorTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtAuthenticatorTests.java @@ -24,7 +24,6 @@ import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.core.security.authc.RealmConfig; import org.elasticsearch.xpack.core.security.authc.RealmSettings; -import org.elasticsearch.xpack.core.security.authc.jwt.JwtAuthenticationToken; import org.elasticsearch.xpack.core.security.authc.jwt.JwtRealmSettings; import org.elasticsearch.xpack.core.ssl.SSLService; import org.junit.Before; diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtIssuer.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtIssuer.java index 3d4d9eae6acd0..789ac04c40622 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtIssuer.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtIssuer.java @@ -14,7 +14,6 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.elasticsearch.xpack.core.security.authc.jwt.JwtRealmSettings; -import org.elasticsearch.xpack.core.security.authc.jwt.JwtUtil; import org.elasticsearch.xpack.core.security.user.User; import java.io.Closeable; diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmAuthenticateTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmAuthenticateTests.java index bf6c64242701b..4f7b82a16e8f1 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmAuthenticateTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmAuthenticateTests.java @@ -25,7 +25,6 @@ import org.elasticsearch.xpack.core.security.authc.AuthenticationToken; import org.elasticsearch.xpack.core.security.authc.Realm; import org.elasticsearch.xpack.core.security.authc.RealmSettings; -import org.elasticsearch.xpack.core.security.authc.jwt.JwtAuthenticationToken; import org.elasticsearch.xpack.core.security.authc.jwt.JwtRealmSettings; import org.elasticsearch.xpack.core.security.user.User; diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmGenerateTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmGenerateTests.java index 7a0e138305b83..8a5daa642002e 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmGenerateTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmGenerateTests.java @@ -23,7 +23,6 @@ import org.elasticsearch.xpack.core.security.authc.RealmConfig; import org.elasticsearch.xpack.core.security.authc.RealmSettings; import org.elasticsearch.xpack.core.security.authc.jwt.JwtRealmSettings; -import org.elasticsearch.xpack.core.security.authc.jwt.JwtUtil; import org.elasticsearch.xpack.core.security.authc.support.DelegatedAuthorizationSettings; import org.elasticsearch.xpack.core.security.authc.support.UserRoleMapper; import org.elasticsearch.xpack.core.security.user.User; diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmInspector.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmInspector.java index 40a613a0907c8..7697849179acf 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmInspector.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmInspector.java @@ -11,7 +11,6 @@ import org.elasticsearch.common.settings.SecureString; import org.elasticsearch.common.settings.Setting; import org.elasticsearch.xpack.core.security.authc.jwt.JwtRealmSettings; -import org.elasticsearch.xpack.core.security.authc.jwt.JwtUtil; import org.elasticsearch.xpack.core.security.authc.support.ClaimSetting; import java.net.URI; diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmTestCase.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmTestCase.java index 1bc49cb628464..ffc1fec1f5788 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmTestCase.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtRealmTestCase.java @@ -28,7 +28,6 @@ import org.elasticsearch.xpack.core.security.authc.Realm; import org.elasticsearch.xpack.core.security.authc.RealmConfig; import org.elasticsearch.xpack.core.security.authc.RealmSettings; -import org.elasticsearch.xpack.core.security.authc.jwt.JwtAuthenticationToken; import org.elasticsearch.xpack.core.security.authc.jwt.JwtRealmSettings; import org.elasticsearch.xpack.core.security.authc.jwt.JwtRealmSettings.ClientAuthenticationType; import org.elasticsearch.xpack.core.security.authc.support.DelegatedAuthorizationSettings; diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtUtilTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtUtilTests.java index 7d90dffd7517c..6fab33b4d6adf 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtUtilTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authc/jwt/JwtUtilTests.java @@ -10,7 +10,6 @@ import org.elasticsearch.common.settings.SecureString; import org.elasticsearch.common.settings.SettingsException; import org.elasticsearch.xpack.core.security.authc.jwt.JwtRealmSettings; -import org.elasticsearch.xpack.core.security.authc.jwt.JwtUtil; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.is; diff --git a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authz/store/NativeRolesStoreTests.java b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authz/store/NativeRolesStoreTests.java index a4ee449438fe0..bfa358d0b7d6e 100644 --- a/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authz/store/NativeRolesStoreTests.java +++ b/x-pack/plugin/security/src/test/java/org/elasticsearch/xpack/security/authz/store/NativeRolesStoreTests.java @@ -55,6 +55,7 @@ import org.elasticsearch.threadpool.TestThreadPool; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xcontent.NamedXContentRegistry; +import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xcontent.ToXContent; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentType; @@ -78,6 +79,7 @@ import org.mockito.Mockito; import java.io.IOException; +import java.lang.reflect.Field; import java.nio.charset.Charset; import java.nio.file.Files; import java.nio.file.Path; @@ -138,7 +140,7 @@ private NativeRolesStore createRoleStoreForTest() { private NativeRolesStore createRoleStoreForTest(Settings settings) { new ReservedRolesStore(Set.of("superuser")); - final ClusterService clusterService = mock(ClusterService.class); + final ClusterService clusterService = mockClusterServiceWithMinNodeVersion(TransportVersion.current()); final SecuritySystemIndices systemIndices = new SecuritySystemIndices(settings); final FeatureService featureService = mock(FeatureService.class); systemIndices.init(client, featureService, clusterService); @@ -807,6 +809,62 @@ public void testBulkDeleteReservedRole() { verify(client, times(0)).bulk(any(BulkRequest.class), any()); } + /** + * Make sure all top level fields for a RoleDescriptor have default values to make sure they can be set to empty in an upsert + * call to the roles API + */ + public void testAllTopFieldsHaveEmptyDefaultsForUpsert() throws IOException, IllegalAccessException { + final NativeRolesStore rolesStore = createRoleStoreForTest(); + RoleDescriptor allNullDescriptor = new RoleDescriptor( + "all-null-descriptor", + null, + null, + null, + null, + null, + null, + null, + null, + null, + null, + null + ); + + Set fieldsWithoutDefaultValue = Set.of( + RoleDescriptor.Fields.INDEX, + RoleDescriptor.Fields.NAMES, + RoleDescriptor.Fields.ALLOW_RESTRICTED_INDICES, + RoleDescriptor.Fields.RESOURCES, + RoleDescriptor.Fields.QUERY, + RoleDescriptor.Fields.PRIVILEGES, + RoleDescriptor.Fields.CLUSTERS, + RoleDescriptor.Fields.APPLICATION, + RoleDescriptor.Fields.FIELD_PERMISSIONS, + RoleDescriptor.Fields.FIELD_PERMISSIONS_2X, + RoleDescriptor.Fields.GRANT_FIELDS, + RoleDescriptor.Fields.EXCEPT_FIELDS, + RoleDescriptor.Fields.METADATA_FLATTENED, + RoleDescriptor.Fields.TRANSIENT_METADATA, + RoleDescriptor.Fields.RESTRICTION, + RoleDescriptor.Fields.WORKFLOWS + ); + + String serializedOutput = Strings.toString(rolesStore.createRoleXContentBuilder(allNullDescriptor)); + Field[] fields = RoleDescriptor.Fields.class.getFields(); + + for (Field field : fields) { + ParseField fieldValue = (ParseField) field.get(null); + if (fieldsWithoutDefaultValue.contains(fieldValue) == false) { + assertThat( + "New RoleDescriptor field without a default value detected. " + + "Set a value or add to excluded list if not expected to be set to empty through role APIs", + serializedOutput, + containsString(fieldValue.getPreferredName()) + ); + } + } + } + private ClusterService mockClusterServiceWithMinNodeVersion(TransportVersion transportVersion) { final ClusterService clusterService = mock(ClusterService.class, Mockito.RETURNS_DEEP_STUBS); when(clusterService.state().getMinTransportVersion()).thenReturn(transportVersion); diff --git a/x-pack/plugin/snapshot-repo-test-kit/qa/rest/src/yamlRestTest/resources/rest-api-spec/test/10_analyze.yml b/x-pack/plugin/snapshot-repo-test-kit/qa/rest/src/yamlRestTest/resources/rest-api-spec/test/10_analyze.yml index e5babad76eb05..bcee1691e033c 100644 --- a/x-pack/plugin/snapshot-repo-test-kit/qa/rest/src/yamlRestTest/resources/rest-api-spec/test/10_analyze.yml +++ b/x-pack/plugin/snapshot-repo-test-kit/qa/rest/src/yamlRestTest/resources/rest-api-spec/test/10_analyze.yml @@ -175,6 +175,6 @@ setup: - match: { status: 500 } - match: { error.type: repository_verification_exception } - - match: { error.reason: "/.*test_repo_slow..analysis.failed.*/" } + - match: { error.reason: "/.*test_repo_slow..Repository.analysis.timed.out.*/" } - match: { error.root_cause.0.type: repository_verification_exception } - match: { error.root_cause.0.reason: "/.*test_repo_slow..analysis.timed.out.after..1s.*/" } diff --git a/x-pack/plugin/snapshot-repo-test-kit/src/internalClusterTest/java/org/elasticsearch/repositories/blobstore/testkit/RepositoryAnalysisFailureIT.java b/x-pack/plugin/snapshot-repo-test-kit/src/internalClusterTest/java/org/elasticsearch/repositories/blobstore/testkit/RepositoryAnalysisFailureIT.java index 7715b9e8d42b8..2ca5685c83db3 100644 --- a/x-pack/plugin/snapshot-repo-test-kit/src/internalClusterTest/java/org/elasticsearch/repositories/blobstore/testkit/RepositoryAnalysisFailureIT.java +++ b/x-pack/plugin/snapshot-repo-test-kit/src/internalClusterTest/java/org/elasticsearch/repositories/blobstore/testkit/RepositoryAnalysisFailureIT.java @@ -11,6 +11,7 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.cluster.metadata.RepositoryMetadata; import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.ReferenceDocs; import org.elasticsearch.common.blobstore.BlobContainer; import org.elasticsearch.common.blobstore.BlobPath; import org.elasticsearch.common.blobstore.BlobStore; @@ -363,6 +364,17 @@ public BytesReference onContendedCompareAndExchange(BytesRegister register, Byte } } + private static void assertAnalysisFailureMessage(String message) { + assertThat( + message, + allOf( + containsString("Elasticsearch observed the storage system underneath this repository behaved incorrectly"), + containsString("not suitable for use with Elasticsearch snapshots"), + containsString(ReferenceDocs.SNAPSHOT_REPOSITORY_ANALYSIS.toString()) + ) + ); + } + public void testTimesOutSpinningRegisterAnalysis() { final RepositoryAnalyzeAction.Request request = new RepositoryAnalyzeAction.Request("test-repo"); request.timeout(TimeValue.timeValueMillis(between(1, 1000))); @@ -375,7 +387,13 @@ public boolean compareAndExchangeReturnsWitness(String key) { } }); final var exception = expectThrows(RepositoryVerificationException.class, () -> analyseRepository(request)); - assertThat(exception.getMessage(), containsString("analysis failed")); + assertThat( + exception.getMessage(), + allOf( + containsString("Repository analysis timed out. Consider specifying a longer timeout"), + containsString(ReferenceDocs.SNAPSHOT_REPOSITORY_ANALYSIS.toString()) + ) + ); assertThat( asInstanceOf(RepositoryVerificationException.class, exception.getCause()).getMessage(), containsString("analysis timed out") @@ -391,7 +409,7 @@ public boolean compareAndExchangeReturnsWitness(String key) { } }); final var exception = expectThrows(RepositoryVerificationException.class, () -> analyseRepository(request)); - assertThat(exception.getMessage(), containsString("analysis failed")); + assertAnalysisFailureMessage(exception.getMessage()); assertThat( asInstanceOf(RepositoryVerificationException.class, ExceptionsHelper.unwrapCause(exception.getCause())).getMessage(), allOf(containsString("uncontended register operation failed"), containsString("did not observe any value")) @@ -407,7 +425,7 @@ public boolean acceptsEmptyRegister() { } }); final var exception = expectThrows(RepositoryVerificationException.class, () -> analyseRepository(request)); - assertThat(exception.getMessage(), containsString("analysis failed")); + assertAnalysisFailureMessage(exception.getMessage()); final var cause = ExceptionsHelper.unwrapCause(exception.getCause()); if (cause instanceof IOException ioException) { assertThat(ioException.getMessage(), containsString("empty register update rejected")); diff --git a/x-pack/plugin/snapshot-repo-test-kit/src/main/java/org/elasticsearch/repositories/blobstore/testkit/RepositoryAnalyzeAction.java b/x-pack/plugin/snapshot-repo-test-kit/src/main/java/org/elasticsearch/repositories/blobstore/testkit/RepositoryAnalyzeAction.java index 7b82b69a682fa..494d1d3fedcd9 100644 --- a/x-pack/plugin/snapshot-repo-test-kit/src/main/java/org/elasticsearch/repositories/blobstore/testkit/RepositoryAnalyzeAction.java +++ b/x-pack/plugin/snapshot-repo-test-kit/src/main/java/org/elasticsearch/repositories/blobstore/testkit/RepositoryAnalyzeAction.java @@ -28,6 +28,7 @@ import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.node.DiscoveryNodes; import org.elasticsearch.cluster.service.ClusterService; +import org.elasticsearch.common.ReferenceDocs; import org.elasticsearch.common.Strings; import org.elasticsearch.common.UUIDs; import org.elasticsearch.common.blobstore.BlobContainer; @@ -387,6 +388,9 @@ public static class AsyncAction { private final List responses; private final RepositoryPerformanceSummary.Builder summary = new RepositoryPerformanceSummary.Builder(); + private final RepositoryVerificationException analysisCancelledException; + private final RepositoryVerificationException analysisTimedOutException; + public AsyncAction( TransportService transportService, BlobStoreRepository repository, @@ -410,6 +414,12 @@ public AsyncAction( this.listener = ActionListener.runBefore(listener, () -> cancellationListener.onResponse(null)); responses = new ArrayList<>(request.blobCount); + + this.analysisCancelledException = new RepositoryVerificationException(request.repositoryName, "analysis cancelled"); + this.analysisTimedOutException = new RepositoryVerificationException( + request.repositoryName, + "analysis timed out after [" + request.getTimeout() + "]" + ); } private boolean setFirstFailure(Exception e) { @@ -453,12 +463,7 @@ public void onFailure(Exception e) { assert e instanceof ElasticsearchTimeoutException : e; if (isRunning()) { // if this CAS fails then we're already failing for some other reason, nbd - setFirstFailure( - new RepositoryVerificationException( - request.repositoryName, - "analysis timed out after [" + request.getTimeout() + "]" - ) - ); + setFirstFailure(analysisTimedOutException); } } } @@ -472,7 +477,7 @@ public void run() { cancellationListener.addTimeout(request.getTimeout(), repository.threadPool(), EsExecutors.DIRECT_EXECUTOR_SERVICE); cancellationListener.addListener(new CheckForCancelListener()); - task.addListener(() -> setFirstFailure(new RepositoryVerificationException(request.repositoryName, "analysis cancelled"))); + task.addListener(() -> setFirstFailure(analysisCancelledException)); final Random random = new Random(request.getSeed()); final List nodes = getSnapshotNodes(discoveryNodes); @@ -873,13 +878,20 @@ private void sendResponse(final long listingStartTimeNanos, final long deleteSta ); } else { logger.debug(() -> "analysis of repository [" + request.repositoryName + "] failed", exception); - listener.onFailure( - new RepositoryVerificationException( - request.getRepositoryName(), - "analysis failed, you may need to manually remove [" + blobPath + "]", - exception - ) - ); + + final String failureDetail; + if (exception == analysisCancelledException) { + failureDetail = "Repository analysis was cancelled."; + } else if (exception == analysisTimedOutException) { + failureDetail = Strings.format(""" + Repository analysis timed out. Consider specifying a longer timeout using the [?timeout] request parameter. See \ + [%s] for more information.""", ReferenceDocs.SNAPSHOT_REPOSITORY_ANALYSIS); + } else { + failureDetail = repository.getAnalysisFailureExtraDetail(); + } + listener.onFailure(new RepositoryVerificationException(request.getRepositoryName(), Strings.format(""" + %s Elasticsearch attempted to remove the data it wrote at [%s] but may have left some behind. If so, \ + please now remove this data manually.""", failureDetail, blobPath), exception)); } } } diff --git a/x-pack/plugin/spatial/src/main/java/org/elasticsearch/xpack/spatial/index/mapper/GeoShapeWithDocValuesFieldMapper.java b/x-pack/plugin/spatial/src/main/java/org/elasticsearch/xpack/spatial/index/mapper/GeoShapeWithDocValuesFieldMapper.java index adafacf92fe4f..04b194c2ec208 100644 --- a/x-pack/plugin/spatial/src/main/java/org/elasticsearch/xpack/spatial/index/mapper/GeoShapeWithDocValuesFieldMapper.java +++ b/x-pack/plugin/spatial/src/main/java/org/elasticsearch/xpack/spatial/index/mapper/GeoShapeWithDocValuesFieldMapper.java @@ -24,7 +24,6 @@ import org.elasticsearch.common.geo.Orientation; import org.elasticsearch.common.geo.ShapeRelation; import org.elasticsearch.common.logging.DeprecationCategory; -import org.elasticsearch.common.logging.DeprecationLogger; import org.elasticsearch.geometry.Geometry; import org.elasticsearch.geometry.utils.GeometryValidator; import org.elasticsearch.geometry.utils.WellKnownBinary; @@ -36,7 +35,6 @@ import org.elasticsearch.index.mapper.AbstractShapeGeometryFieldMapper; import org.elasticsearch.index.mapper.DocumentParserContext; import org.elasticsearch.index.mapper.FieldMapper; -import org.elasticsearch.index.mapper.GeoShapeFieldMapper; import org.elasticsearch.index.mapper.GeoShapeIndexer; import org.elasticsearch.index.mapper.GeoShapeParser; import org.elasticsearch.index.mapper.GeoShapeQueryable; @@ -80,14 +78,8 @@ import java.util.function.Function; /** - * Extension of {@link org.elasticsearch.index.mapper.GeoShapeFieldMapper} that supports docValues - * * FieldMapper for indexing {@link LatLonShape}s. *

- * Currently Shapes can only be indexed and can only be queried using - * {@link org.elasticsearch.index.query.GeoShapeQueryBuilder}, consequently - * a lot of behavior in this Mapper is disabled. - *

* Format supported: *

* "field" : { @@ -104,8 +96,6 @@ public class GeoShapeWithDocValuesFieldMapper extends AbstractShapeGeometryFieldMapper { public static final String CONTENT_TYPE = "geo_shape"; - private static final DeprecationLogger DEPRECATION_LOGGER = DeprecationLogger.getLogger(GeoShapeFieldMapper.class); - private static Builder builder(FieldMapper in) { return ((GeoShapeWithDocValuesFieldMapper) in).builder; } diff --git a/x-pack/plugin/spatial/src/main/java/org/elasticsearch/xpack/spatial/index/mapper/PointFieldMapper.java b/x-pack/plugin/spatial/src/main/java/org/elasticsearch/xpack/spatial/index/mapper/PointFieldMapper.java index de49e0c5a5563..d98fe7fdfc6ec 100644 --- a/x-pack/plugin/spatial/src/main/java/org/elasticsearch/xpack/spatial/index/mapper/PointFieldMapper.java +++ b/x-pack/plugin/spatial/src/main/java/org/elasticsearch/xpack/spatial/index/mapper/PointFieldMapper.java @@ -14,7 +14,6 @@ import org.elasticsearch.common.geo.GeometryFormatterFactory; import org.elasticsearch.common.geo.ShapeRelation; import org.elasticsearch.common.logging.DeprecationCategory; -import org.elasticsearch.common.logging.DeprecationLogger; import org.elasticsearch.core.CheckedFunction; import org.elasticsearch.geometry.Geometry; import org.elasticsearch.geometry.Point; @@ -23,7 +22,6 @@ import org.elasticsearch.index.mapper.AbstractPointGeometryFieldMapper; import org.elasticsearch.index.mapper.DocumentParserContext; import org.elasticsearch.index.mapper.FieldMapper; -import org.elasticsearch.index.mapper.GeoShapeFieldMapper; import org.elasticsearch.index.mapper.MappedFieldType; import org.elasticsearch.index.mapper.MapperBuilderContext; import org.elasticsearch.index.query.SearchExecutionContext; @@ -49,8 +47,6 @@ public class PointFieldMapper extends AbstractPointGeometryFieldMapper { public static final String CONTENT_TYPE = "point"; - private static final DeprecationLogger DEPRECATION_LOGGER = DeprecationLogger.getLogger(GeoShapeFieldMapper.class); - private static Builder builder(FieldMapper in) { return ((PointFieldMapper) in).builder; } diff --git a/x-pack/plugin/spatial/src/main/java/org/elasticsearch/xpack/spatial/index/mapper/ShapeFieldMapper.java b/x-pack/plugin/spatial/src/main/java/org/elasticsearch/xpack/spatial/index/mapper/ShapeFieldMapper.java index 4cc983592d0c1..91a118f964064 100644 --- a/x-pack/plugin/spatial/src/main/java/org/elasticsearch/xpack/spatial/index/mapper/ShapeFieldMapper.java +++ b/x-pack/plugin/spatial/src/main/java/org/elasticsearch/xpack/spatial/index/mapper/ShapeFieldMapper.java @@ -15,7 +15,6 @@ import org.elasticsearch.common.geo.Orientation; import org.elasticsearch.common.geo.ShapeRelation; import org.elasticsearch.common.logging.DeprecationCategory; -import org.elasticsearch.common.logging.DeprecationLogger; import org.elasticsearch.geometry.Geometry; import org.elasticsearch.index.IndexVersion; import org.elasticsearch.index.IndexVersions; @@ -25,7 +24,6 @@ import org.elasticsearch.index.mapper.AbstractShapeGeometryFieldMapper; import org.elasticsearch.index.mapper.DocumentParserContext; import org.elasticsearch.index.mapper.FieldMapper; -import org.elasticsearch.index.mapper.GeoShapeFieldMapper; import org.elasticsearch.index.mapper.MappedFieldType; import org.elasticsearch.index.mapper.MapperBuilderContext; import org.elasticsearch.index.query.SearchExecutionContext; @@ -70,8 +68,6 @@ public class ShapeFieldMapper extends AbstractShapeGeometryFieldMapper { public static final String CONTENT_TYPE = "shape"; - private static final DeprecationLogger DEPRECATION_LOGGER = DeprecationLogger.getLogger(GeoShapeFieldMapper.class); - private static Builder builder(FieldMapper in) { return ((ShapeFieldMapper) in).builder; } diff --git a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/100_bug_fix.yml b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/100_bug_fix.yml index b91343d03d3d4..cffc161b11539 100644 --- a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/100_bug_fix.yml +++ b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/100_bug_fix.yml @@ -303,3 +303,38 @@ - match: { values.0.2: [1, 2] } - match: { values.0.3: [1, 2] } - match: { values.0.4: [1.1, 2.2] } + + +--- +"grok with duplicate names and different types #110533": + - requires: + test_runner_features: [capabilities] + capabilities: + - method: POST + path: /_query + parameters: [] + capabilities: [grok_validation] + reason: "fixed grok validation with patterns containing the same attribute multiple times with different types" + - do: + indices.create: + index: test_grok + body: + mappings: + properties: + first_name : + type : keyword + last_name: + type: keyword + + - do: + bulk: + refresh: true + body: + - { "index": { "_index": "test_grok" } } + - { "first_name": "Georgi", "last_name":"Facello" } + + - do: + catch: '/Invalid GROK pattern \[%\{NUMBER:foo\} %\{WORD:foo\}\]: the attribute \[foo\] is defined multiple times with different types/' + esql.query: + body: + query: 'FROM test_grok | KEEP name | WHERE last_name == "Facello" | EVAL name = concat("1 ", last_name) | GROK name "%{NUMBER:foo} %{WORD:foo}"' diff --git a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/160_union_types.yml b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/160_union_types.yml index f3403ca8751c0..aac60d9aaa8d0 100644 --- a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/160_union_types.yml +++ b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/esql/160_union_types.yml @@ -147,6 +147,9 @@ setup: - '{"index": {}}' - '{"@timestamp": "2023-10-23T12:15:03.360Z", "client_ip": "172.21.2.162", "event_duration": "3450233", "message": "Connected to 10.1.0.3"}' +############################################################################################################ +# Test a single index as a control of the expected results + --- load single index ip_long: - do: @@ -173,9 +176,6 @@ load single index ip_long: - match: { values.0.3: 1756467 } - match: { values.0.4: "Connected to 10.1.0.1" } -############################################################################################################ -# Test a single index as a control of the expected results - --- load single index keyword_keyword: - do: @@ -202,6 +202,83 @@ load single index keyword_keyword: - match: { values.0.3: "1756467" } - match: { values.0.4: "Connected to 10.1.0.1" } +--- +load single index ip_long and aggregate by client_ip: + - requires: + capabilities: + - method: POST + path: /_query + parameters: [method, path, parameters, capabilities] + capabilities: [casting_operator] + reason: "Casting operator and introduced in 8.15.0" + - do: + allowed_warnings_regex: + - "No limit defined, adding default limit of \\[.*\\]" + esql.query: + body: + query: 'FROM events_ip_long | STATS count = COUNT(*) BY client_ip::ip | SORT count DESC, `client_ip::ip` ASC' + + - match: { columns.0.name: "count" } + - match: { columns.0.type: "long" } + - match: { columns.1.name: "client_ip::ip" } + - match: { columns.1.type: "ip" } + - length: { values: 4 } + - match: { values.0.0: 4 } + - match: { values.0.1: "172.21.3.15" } + - match: { values.1.0: 1 } + - match: { values.1.1: "172.21.0.5" } + - match: { values.2.0: 1 } + - match: { values.2.1: "172.21.2.113" } + - match: { values.3.0: 1 } + - match: { values.3.1: "172.21.2.162" } + +--- +load single index ip_long and aggregate client_ip my message: + - requires: + capabilities: + - method: POST + path: /_query + parameters: [method, path, parameters, capabilities] + capabilities: [casting_operator] + reason: "Casting operator and introduced in 8.15.0" + - do: + allowed_warnings_regex: + - "No limit defined, adding default limit of \\[.*\\]" + esql.query: + body: + query: 'FROM events_ip_long | STATS count = COUNT(client_ip::ip) BY message | SORT count DESC, message ASC' + + - match: { columns.0.name: "count" } + - match: { columns.0.type: "long" } + - match: { columns.1.name: "message" } + - match: { columns.1.type: "keyword" } + - length: { values: 5 } + - match: { values.0.0: 3 } + - match: { values.0.1: "Connection error" } + - match: { values.1.0: 1 } + - match: { values.1.1: "Connected to 10.1.0.1" } + - match: { values.2.0: 1 } + - match: { values.2.1: "Connected to 10.1.0.2" } + - match: { values.3.0: 1 } + - match: { values.3.1: "Connected to 10.1.0.3" } + - match: { values.4.0: 1 } + - match: { values.4.1: "Disconnected" } + +--- +load single index ip_long stats invalid grouping: + - requires: + capabilities: + - method: POST + path: /_query + parameters: [method, path, parameters, capabilities] + capabilities: [casting_operator] + reason: "Casting operator and introduced in 8.15.0" + - do: + catch: '/Unknown column \[x\]/' + esql.query: + body: + query: 'FROM events_ip_long | STATS count = COUNT(client_ip::ip) BY x' + ############################################################################################################ # Test two indices where the event_duration is mapped as a LONG and as a KEYWORD @@ -512,6 +589,83 @@ load two indices, convert, rename but not drop ambiguous field client_ip: - match: { values.1.5: "172.21.3.15" } - match: { values.1.6: "172.21.3.15" } +--- +load two indexes and group by converted client_ip: + - requires: + capabilities: + - method: POST + path: /_query + parameters: [method, path, parameters, capabilities] + capabilities: [casting_operator, union_types_agg_cast] + reason: "Casting operator and Union types introduced in 8.15.0" + - do: + allowed_warnings_regex: + - "No limit defined, adding default limit of \\[.*\\]" + esql.query: + body: + query: 'FROM events_*_long | STATS count = COUNT(*) BY client_ip::ip | SORT count DESC, `client_ip::ip` ASC' + + - match: { columns.0.name: "count" } + - match: { columns.0.type: "long" } + - match: { columns.1.name: "client_ip::ip" } + - match: { columns.1.type: "ip" } + - length: { values: 4 } + - match: { values.0.0: 8 } + - match: { values.0.1: "172.21.3.15" } + - match: { values.1.0: 2 } + - match: { values.1.1: "172.21.0.5" } + - match: { values.2.0: 2 } + - match: { values.2.1: "172.21.2.113" } + - match: { values.3.0: 2 } + - match: { values.3.1: "172.21.2.162" } + +--- +load two indexes and aggregate converted client_ip: + - requires: + capabilities: + - method: POST + path: /_query + parameters: [method, path, parameters, capabilities] + capabilities: [casting_operator, union_types_agg_cast] + reason: "Casting operator and Union types introduced in 8.15.0" + - do: + allowed_warnings_regex: + - "No limit defined, adding default limit of \\[.*\\]" + esql.query: + body: + query: 'FROM events_*_long | STATS count = COUNT(client_ip::ip) BY message | SORT count DESC, message ASC' + + - match: { columns.0.name: "count" } + - match: { columns.0.type: "long" } + - match: { columns.1.name: "message" } + - match: { columns.1.type: "keyword" } + - length: { values: 5 } + - match: { values.0.0: 6 } + - match: { values.0.1: "Connection error" } + - match: { values.1.0: 2 } + - match: { values.1.1: "Connected to 10.1.0.1" } + - match: { values.2.0: 2 } + - match: { values.2.1: "Connected to 10.1.0.2" } + - match: { values.3.0: 2 } + - match: { values.3.1: "Connected to 10.1.0.3" } + - match: { values.4.0: 2 } + - match: { values.4.1: "Disconnected" } + +--- +load two indexes, convert client_ip and group by something invalid: + - requires: + capabilities: + - method: POST + path: /_query + parameters: [method, path, parameters, capabilities] + capabilities: [casting_operator, union_types_agg_cast] + reason: "Casting operator and Union types introduced in 8.15.0" + - do: + catch: '/Unknown column \[x\]/' + esql.query: + body: + query: 'FROM events_*_long | STATS count = COUNT(client_ip::ip) BY x' + ############################################################################################################ # Test four indices with both the client_IP (IP and KEYWORD) and event_duration (LONG and KEYWORD) mappings diff --git a/x-pack/plugin/stack/src/main/java/org/elasticsearch/xpack/stack/StackTemplateRegistry.java b/x-pack/plugin/stack/src/main/java/org/elasticsearch/xpack/stack/StackTemplateRegistry.java index aa1e8858163a5..648146ccdcc61 100644 --- a/x-pack/plugin/stack/src/main/java/org/elasticsearch/xpack/stack/StackTemplateRegistry.java +++ b/x-pack/plugin/stack/src/main/java/org/elasticsearch/xpack/stack/StackTemplateRegistry.java @@ -47,7 +47,7 @@ public class StackTemplateRegistry extends IndexTemplateRegistry { // The stack template registry version. This number must be incremented when we make changes // to built-in templates. - public static final int REGISTRY_VERSION = 11; + public static final int REGISTRY_VERSION = 12; public static final String TEMPLATE_VERSION_VARIABLE = "xpack.stack.template.version"; public static final Setting STACK_TEMPLATES_ENABLED = Setting.boolSetting( diff --git a/x-pack/plugin/transform/src/main/java/org/elasticsearch/xpack/transform/transforms/common/AbstractCompositeAggFunction.java b/x-pack/plugin/transform/src/main/java/org/elasticsearch/xpack/transform/transforms/common/AbstractCompositeAggFunction.java index 3412be813dcf6..23bab56de5ec9 100644 --- a/x-pack/plugin/transform/src/main/java/org/elasticsearch/xpack/transform/transforms/common/AbstractCompositeAggFunction.java +++ b/x-pack/plugin/transform/src/main/java/org/elasticsearch/xpack/transform/transforms/common/AbstractCompositeAggFunction.java @@ -7,6 +7,8 @@ package org.elasticsearch.xpack.transform.transforms.common; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; import org.elasticsearch.ElasticsearchException; import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; @@ -45,6 +47,7 @@ * Basic abstract class for implementing a transform function that utilizes composite aggregations */ public abstract class AbstractCompositeAggFunction implements Function { + private static final Logger logger = LogManager.getLogger(AbstractCompositeAggFunction.class); public static final int TEST_QUERY_PAGE_SIZE = 50; public static final String COMPOSITE_AGGREGATION_NAME = "_transform"; @@ -78,7 +81,7 @@ public void preview( ClientHelper.TRANSFORM_ORIGIN, client, TransportSearchAction.TYPE, - buildSearchRequest(sourceConfig, timeout, numberOfBuckets), + buildSearchRequestForValidation("preview", sourceConfig, timeout, numberOfBuckets), ActionListener.wrap(r -> { try { final InternalAggregations aggregations = r.getAggregations(); @@ -116,7 +119,7 @@ public void validateQuery( TimeValue timeout, ActionListener listener ) { - SearchRequest searchRequest = buildSearchRequest(sourceConfig, timeout, TEST_QUERY_PAGE_SIZE); + SearchRequest searchRequest = buildSearchRequestForValidation("validate", sourceConfig, timeout, TEST_QUERY_PAGE_SIZE); ClientHelper.executeWithHeadersAsync( headers, ClientHelper.TRANSFORM_ORIGIN, @@ -193,11 +196,12 @@ protected abstract Stream> extractResults( TransformProgress progress ); - private SearchRequest buildSearchRequest(SourceConfig sourceConfig, TimeValue timeout, int pageSize) { + private SearchRequest buildSearchRequestForValidation(String logId, SourceConfig sourceConfig, TimeValue timeout, int pageSize) { SearchSourceBuilder sourceBuilder = new SearchSourceBuilder().query(sourceConfig.getQueryConfig().getQuery()) .runtimeMappings(sourceConfig.getRuntimeMappings()) .timeout(timeout); buildSearchQuery(sourceBuilder, null, pageSize); + logger.debug("[{}] Querying {} for data: {}", logId, sourceConfig.getIndex(), sourceBuilder); return new SearchRequest(sourceConfig.getIndex()).source(sourceBuilder).indicesOptions(IndicesOptions.LENIENT_EXPAND_OPEN); } diff --git a/x-pack/plugin/watcher/src/main/java/org/elasticsearch/xpack/watcher/trigger/schedule/engine/TickerScheduleTriggerEngine.java b/x-pack/plugin/watcher/src/main/java/org/elasticsearch/xpack/watcher/trigger/schedule/engine/TickerScheduleTriggerEngine.java index ba07c3137340d..ced131640f0ee 100644 --- a/x-pack/plugin/watcher/src/main/java/org/elasticsearch/xpack/watcher/trigger/schedule/engine/TickerScheduleTriggerEngine.java +++ b/x-pack/plugin/watcher/src/main/java/org/elasticsearch/xpack/watcher/trigger/schedule/engine/TickerScheduleTriggerEngine.java @@ -34,6 +34,7 @@ import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicBoolean; import static org.elasticsearch.common.settings.Setting.positiveTimeSetting; @@ -50,6 +51,7 @@ public class TickerScheduleTriggerEngine extends ScheduleTriggerEngine { private final TimeValue tickInterval; private final Map schedules = new ConcurrentHashMap<>(); private final Ticker ticker; + private final AtomicBoolean isRunning = new AtomicBoolean(false); public TickerScheduleTriggerEngine(Settings settings, ScheduleRegistry scheduleRegistry, Clock clock) { super(scheduleRegistry, clock); @@ -60,7 +62,8 @@ public TickerScheduleTriggerEngine(Settings settings, ScheduleRegistry scheduleR @Override public synchronized void start(Collection jobs) { long startTime = clock.millis(); - logger.info("Watcher starting watches at {}", WatcherDateTimeUtils.dateTimeFormatter.formatMillis(startTime)); + isRunning.set(true); + logger.info("Starting watcher engine at {}", WatcherDateTimeUtils.dateTimeFormatter.formatMillis(startTime)); Map startingSchedules = Maps.newMapWithExpectedSize(jobs.size()); for (Watch job : jobs) { if (job.trigger() instanceof ScheduleTrigger trigger) { @@ -81,17 +84,22 @@ public synchronized void start(Collection jobs) { @Override public void stop() { + logger.info("Stopping watcher engine"); + isRunning.set(false); schedules.clear(); ticker.close(); } @Override - public synchronized void pauseExecution() { + public void pauseExecution() { + logger.info("Pausing watcher engine"); + isRunning.set(false); schedules.clear(); } @Override public void add(Watch watch) { + logger.trace("Adding watch [{}] to engine (engine is running: {})", watch.id(), isRunning.get()); assert watch.trigger() instanceof ScheduleTrigger; ScheduleTrigger trigger = (ScheduleTrigger) watch.trigger(); ActiveSchedule currentSchedule = schedules.get(watch.id()); @@ -106,13 +114,25 @@ public void add(Watch watch) { @Override public boolean remove(String jobId) { + logger.debug("Removing watch [{}] from engine (engine is running: {})", jobId, isRunning.get()); return schedules.remove(jobId) != null; } void checkJobs() { + if (isRunning.get() == false) { + logger.debug( + "Watcher not running because the engine is paused. Currently scheduled watches being skipped: {}", + schedules.size() + ); + return; + } long triggeredTime = clock.millis(); List events = new ArrayList<>(); for (ActiveSchedule schedule : schedules.values()) { + if (isRunning.get() == false) { + logger.debug("Watcher paused while running [{}]", schedule.name); + break; + } long scheduledTime = schedule.check(triggeredTime); if (scheduledTime > 0) { ZonedDateTime triggeredDateTime = utcDateTimeAtEpochMillis(triggeredTime);