diff --git a/docs/changelog/109256.yaml b/docs/changelog/109256.yaml new file mode 100644 index 0000000000000..30c15ed77f9b9 --- /dev/null +++ b/docs/changelog/109256.yaml @@ -0,0 +1,7 @@ +pr: 109256 +summary: "[ESQL] Migrate `SimplifyComparisonArithmetics` optimization" +area: ES|QL +type: bug +issues: + - 108388 + - 108743 diff --git a/docs/changelog/109276.yaml b/docs/changelog/109276.yaml new file mode 100644 index 0000000000000..d73e68e3c8f7b --- /dev/null +++ b/docs/changelog/109276.yaml @@ -0,0 +1,5 @@ +pr: 109276 +summary: Add remove index setting command +area: Infra/Settings +type: enhancement +issues: [] diff --git a/docs/reference/commands/node-tool.asciidoc b/docs/reference/commands/node-tool.asciidoc index 810de4a71fffb..cdd2bb8f0f9d7 100644 --- a/docs/reference/commands/node-tool.asciidoc +++ b/docs/reference/commands/node-tool.asciidoc @@ -31,6 +31,10 @@ This tool has a number of modes: from the cluster state in case where it contains incompatible settings that prevent the cluster from forming. +* `elasticsearch-node remove-index-settings` can be used to remove index settings + from the cluster state in case where it contains incompatible index settings that + prevent the cluster from forming. + * `elasticsearch-node remove-customs` can be used to remove custom metadata from the cluster state in case where it contains broken metadata that prevents the cluster state from being loaded. @@ -107,6 +111,26 @@ The intended use is: * Repeat for all other master-eligible nodes * Start the nodes +[discrete] +==== Removing index settings + +There may be situations where an index contains index settings +that prevent the cluster from forming. Since the cluster cannot form, +it is not possible to remove these settings using the +<> API. + +The `elasticsearch-node remove-index-settings` tool allows you to forcefully remove +those index settings from the on-disk cluster state. The tool takes a +list of index settings as parameters that should be removed, and also supports +wildcard patterns. + +The intended use is: + +* Stop the node +* Run `elasticsearch-node remove-index-settings name-of-index-setting-to-remove` on the node +* Repeat for all nodes +* Start the nodes + [discrete] ==== Removing custom metadata from the cluster state @@ -436,6 +460,37 @@ You can also use wildcards to remove multiple settings, for example using node$ ./bin/elasticsearch-node remove-settings xpack.monitoring.* ---- +[discrete] +==== Removing index settings + +If your indices contain index settings that prevent the cluster +from forming, you can run the following command to remove one +or more index settings. + +[source,txt] +---- +node$ ./bin/elasticsearch-node remove-index-settings index.my_plugin.foo + + WARNING: Elasticsearch MUST be stopped before running this tool. + +You should only run this tool if you have incompatible index settings in the +cluster state that prevent the cluster from forming. +This tool can cause data loss and its use should be your last resort. + +Do you want to proceed? + +Confirm [y/N] y + +Index settings were successfully removed from the cluster state +---- + +You can also use wildcards to remove multiple index settings, for example using + +[source,txt] +---- +node$ ./bin/elasticsearch-node remove-index-settings index.my_plugin.* +---- + [discrete] ==== Removing custom metadata from the cluster state diff --git a/docs/reference/docs/reindex.asciidoc b/docs/reference/docs/reindex.asciidoc index 146b519b05e80..dc27e40ecd90b 100644 --- a/docs/reference/docs/reindex.asciidoc +++ b/docs/reference/docs/reindex.asciidoc @@ -1035,7 +1035,7 @@ ignored, only the host and port are used. For example: [source,yaml] -------------------------------------------------- -reindex.remote.whitelist: "otherhost:9200, another:9200, 127.0.10.*:9200, localhost:*" +reindex.remote.whitelist: [otherhost:9200, another:9200, 127.0.10.*:9200, localhost:*"] -------------------------------------------------- The list of allowed hosts must be configured on any nodes that will coordinate the reindex. diff --git a/docs/reference/esql/esql-across-clusters.asciidoc b/docs/reference/esql/esql-across-clusters.asciidoc index 95278314b0253..6231b4f4f0a69 100644 --- a/docs/reference/esql/esql-across-clusters.asciidoc +++ b/docs/reference/esql/esql-across-clusters.asciidoc @@ -1,6 +1,5 @@ [[esql-cross-clusters]] === Using {esql} across clusters - ++++ Using {esql} across clusters ++++ @@ -11,6 +10,8 @@ preview::["{ccs-cap} for {esql} is in technical preview and may be changed or re With {esql}, you can execute a single query across multiple clusters. +[discrete] +[[esql-ccs-prerequisites]] ==== Prerequisites include::{es-ref-dir}/search/search-your-data/search-across-clusters.asciidoc[tag=ccs-prereqs] @@ -19,9 +20,101 @@ include::{es-ref-dir}/search/search-your-data/search-across-clusters.asciidoc[ta include::{es-ref-dir}/search/search-your-data/search-across-clusters.asciidoc[tag=ccs-proxy-mode] +[discrete] +[[esql-ccs-security-model]] +==== Security model + +{es} supports two security models for cross-cluster search (CCS): + +* <> +* <> + +[TIP] +==== +To check which security model is being used to connect your clusters, run `GET _remote/info`. +If you're using the API key authentication method, you'll see the `"cluster_credentials"` key in the response. +==== + +[discrete] +[[esql-ccs-security-model-certificate]] +===== TLS certificate authentication + +TLS certificate authentication secures remote clusters with mutual TLS. +This could be the preferred model when a single administrator has full control over both clusters. +We generally recommend that roles and their privileges be identical in both clusters. + +Refer to <> for prerequisites and detailed setup instructions. + +[discrete] +[[esql-ccs-security-model-api-key]] +===== API key authentication + +[NOTE] +==== +`ENRICH` is *not supported* in this version when using {esql} with the API key based security model. +==== + +The following information pertains to using {esql} across clusters with the <>. You'll need to follow the steps on that page for the *full setup instructions*. This page only contains additional information specific to {esql}. + +API key based cross-cluster search (CCS) enables more granular control over allowed actions between clusters. +This may be the preferred model when you have different administrators for different clusters and want more control over who can access what data. In this model, cluster administrators must explicitly define the access given to clusters and users. + +You will need to: + +* Create an API key on the *remote cluster* using the <> API or using the {kibana-ref}/api-keys.html[Kibana API keys UI]. +* Add the API key to the keystore on the *local cluster*, as part of the steps in <>. All cross-cluster requests from the local cluster are bound by the API key’s privileges. + +Using {esql} with the API key based security model requires some additional permissions that may not be needed when using the traditional query DSL based search. +The following example API call creates a role that can query remote indices using {esql} when using the API key based security model. + +[source,console] +---- +POST /_security/role/remote1 +{ + "cluster": ["cross_cluster_search"], <1> + "indices": [ + { + "names" : [""], <2> + "privileges": ["read"] + } + ], + "remote_indices": [ <3> + { + "names": [ "logs-*" ], + "privileges": [ "read","read_cross_cluster" ], <4> + "clusters" : ["my_remote_cluster"] <5> + } + ] +} +---- + +<1> The `cross_cluster_search` cluster privilege is required for the _local_ cluster. +<2> Typically, users will have permissions to read both local and remote indices. However, for cases where the role is intended to ONLY search the remote cluster, the `read` permission is still required for the local cluster. To provide read access to the local cluster, but disallow reading any indices in the local cluster, the `names` field may be an empty string. +<3> The indices allowed read access to the remote cluster. The configured <> must also allow this index to be read. +<4> The `read_cross_cluster` privilege is always required when using {esql} across clusters with the API key based security model. +<5> The remote clusters to which these privileges apply. +This remote cluster must be configured with a <> and connected to the remote cluster before the remote index can be queried. +Verify connection using the <> API. + +You will then need a user or API key with the permissions you created above. The following example API call creates a user with the `remote1` role. + +[source,console] +---- +POST /_security/user/remote_user +{ + "password" : "", + "roles" : [ "remote1" ] +} +---- + +Remember that all cross-cluster requests from the local cluster are bound by the cross cluster API key’s privileges, which are controlled by the remote cluster's administrator. + [discrete] [[ccq-remote-cluster-setup]] ==== Remote cluster setup + +Once the security model is configured, you can add remote clusters. + include::{es-ref-dir}/search/search-your-data/search-across-clusters.asciidoc[tag=ccs-remote-cluster-setup] <1> Since `skip_unavailable` was not set on `cluster_three`, it uses @@ -71,13 +164,18 @@ FROM *:my-index-000001 Enrich in {esql} across clusters operates similarly to <>. If the enrich policy and its enrich indices are consistent across all clusters, simply write the enrich command as you would without remote clusters. In this default mode, -{esql} can execute the enrich command on either the querying cluster or the fulfilling +{esql} can execute the enrich command on either the local cluster or the remote clusters, aiming to minimize computation or inter-cluster data transfer. Ensuring that -the policy exists with consistent data on both the querying cluster and the fulfilling +the policy exists with consistent data on both the local cluster and the remote clusters is critical for ES|QL to produce a consistent query result. +[NOTE] +==== +Enrich across clusters is *not supported* in this version when using {esql} with the <>. +==== + In the following example, the enrich with `hosts` policy can be executed on -either the querying cluster or the remote cluster `cluster_one`. +either the local cluster or the remote cluster `cluster_one`. [source,esql] ---- @@ -87,8 +185,8 @@ FROM my-index-000001,cluster_one:my-index-000001 ---- Enrich with an {esql} query against remote clusters only can also happen on -the querying cluster. This means the below query requires the `hosts` enrich -policy to exist on the querying cluster as well. +the local cluster. This means the below query requires the `hosts` enrich +policy to exist on the local cluster as well. [source,esql] ---- @@ -99,10 +197,10 @@ FROM cluster_one:my-index-000001,cluster_two:my-index-000001 [discrete] [[esql-enrich-coordinator]] -==== Enrich with coordinator mode +===== Enrich with coordinator mode {esql} provides the enrich `_coordinator` mode to force {esql} to execute the enrich -command on the querying cluster. This mode should be used when the enrich policy is +command on the local cluster. This mode should be used when the enrich policy is not available on the remote clusters or maintaining consistency of enrich indices across clusters is challenging. @@ -118,21 +216,21 @@ FROM my-index-000001,cluster_one:my-index-000001 [IMPORTANT] ==== Enrich with the `_coordinator` mode usually increases inter-cluster data transfer and -workload on the querying cluster. +workload on the local cluster. ==== [discrete] [[esql-enrich-remote]] -==== Enrich with remote mode +===== Enrich with remote mode {esql} also provides the enrich `_remote` mode to force {esql} to execute the enrich -command independently on each fulfilling cluster where the target indices reside. +command independently on each remote cluster where the target indices reside. This mode is useful for managing different enrich data on each cluster, such as detailed information of hosts for each region where the target (main) indices contain log events from these hosts. In the below example, the `hosts` enrich policy is required to exist on all -fulfilling clusters: the `querying` cluster (as local indices are included), +remote clusters: the `querying` cluster (as local indices are included), the remote cluster `cluster_one`, and `cluster_two`. [source,esql] @@ -157,12 +255,12 @@ FROM my-index-000001,cluster_one:my-index-000001,cluster_two:my-index-000001 [discrete] [[esql-multi-enrich]] -==== Multiple enrich commands +===== Multiple enrich commands You can include multiple enrich commands in the same query with different modes. {esql} will attempt to execute them accordingly. For example, this query performs two enriches, first with the `hosts` policy on any cluster -and then with the `vendors` policy on the querying cluster. +and then with the `vendors` policy on the local cluster. [source,esql] ---- diff --git a/docs/reference/inference/delete-inference.asciidoc b/docs/reference/inference/delete-inference.asciidoc index 89f76e6cef841..dca800c98ca2e 100644 --- a/docs/reference/inference/delete-inference.asciidoc +++ b/docs/reference/inference/delete-inference.asciidoc @@ -43,6 +43,21 @@ The unique identifier of the {infer} endpoint to delete. The type of {infer} task that the model performs. +[discrete] +[[delete-inference-query-parms]] +== {api-query-parms-title} + +`dry_run`:: +(Optional, Boolean) +When `true`, checks the {infer} processors that reference the endpoint and +returns them in a list, but does not deletes the endpoint. Defaults to `false`. + +`force`:: +(Optional, Boolean) +Deletes the endpoint regardless if it's used in an {infer} pipeline or a in a +`semantic_text` field. + + [discrete] [[delete-inference-api-example]] ==== {api-examples-title} diff --git a/docs/reference/modules/cluster/remote-clusters-api-key.asciidoc b/docs/reference/modules/cluster/remote-clusters-api-key.asciidoc index 5f462b14405ba..4aa97ce375d9f 100644 --- a/docs/reference/modules/cluster/remote-clusters-api-key.asciidoc +++ b/docs/reference/modules/cluster/remote-clusters-api-key.asciidoc @@ -63,6 +63,7 @@ information, refer to https://www.elastic.co/subscriptions. NOTE: If a remote cluster is part of an {ess} deployment, it has a valid certificate by default. You can therefore skip steps related to certificates in these instructions. +[[remote-clusters-security-api-key-remote-action]] ===== On the remote cluster // tag::remote-cluster-steps[] @@ -155,6 +156,7 @@ to the indices you want to use for {ccs} or {ccr}. You can use the need it to connect to the remote cluster later. // end::remote-cluster-steps[] +[[remote-clusters-security-api-key-local-actions]] ===== On the local cluster // tag::local-cluster-steps[] diff --git a/muted-tests.yml b/muted-tests.yml index d7d47a48fee23..fd1999d201bf0 100644 --- a/muted-tests.yml +++ b/muted-tests.yml @@ -19,7 +19,8 @@ tests: method: "testGuessIsDayFirstFromLocale" - class: "org.elasticsearch.test.rest.ClientYamlTestSuiteIT" issue: "https://github.com/elastic/elasticsearch/issues/108857" - method: "test {yaml=search/180_locale_dependent_mapping/Test Index and Search locale dependent mappings / dates}" + method: "test {yaml=search/180_locale_dependent_mapping/Test Index and Search locale\ + \ dependent mappings / dates}" - class: "org.elasticsearch.upgrades.SearchStatesIT" issue: "https://github.com/elastic/elasticsearch/issues/108991" method: "testCanMatch" @@ -28,7 +29,8 @@ tests: method: "testTrainedModelInference" - class: "org.elasticsearch.xpack.security.CoreWithSecurityClientYamlTestSuiteIT" issue: "https://github.com/elastic/elasticsearch/issues/109188" - method: "test {yaml=search/180_locale_dependent_mapping/Test Index and Search locale dependent mappings / dates}" + method: "test {yaml=search/180_locale_dependent_mapping/Test Index and Search locale\ + \ dependent mappings / dates}" - class: "org.elasticsearch.xpack.esql.qa.mixed.EsqlClientYamlIT" issue: "https://github.com/elastic/elasticsearch/issues/109189" method: "test {p0=esql/70_locale/Date format with Italian locale}" @@ -44,6 +46,14 @@ tests: - class: org.elasticsearch.analysis.common.CommonAnalysisClientYamlTestSuiteIT method: org.elasticsearch.analysis.common.CommonAnalysisClientYamlTestSuiteIT issue: https://github.com/elastic/elasticsearch/issues/109266 +- class: "org.elasticsearch.index.engine.frozen.FrozenIndexIT" + issue: "https://github.com/elastic/elasticsearch/issues/109315" + method: "testTimestampFieldTypeExposedByAllIndicesServices" +- class: "org.elasticsearch.analysis.common.CommonAnalysisClientYamlTestSuiteIT" + issue: "https://github.com/elastic/elasticsearch/issues/109318" +- class: "org.elasticsearch.upgrades.AggregationsIT" + issue: "https://github.com/elastic/elasticsearch/issues/109322" + method: "testTerms" # Examples: # diff --git a/server/src/internalClusterTest/java/org/elasticsearch/cluster/coordination/RemoveIndexSettingsCommandIT.java b/server/src/internalClusterTest/java/org/elasticsearch/cluster/coordination/RemoveIndexSettingsCommandIT.java new file mode 100644 index 0000000000000..a5e445270ccc4 --- /dev/null +++ b/server/src/internalClusterTest/java/org/elasticsearch/cluster/coordination/RemoveIndexSettingsCommandIT.java @@ -0,0 +1,162 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ +package org.elasticsearch.cluster.coordination; + +import joptsimple.OptionSet; + +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.cli.MockTerminal; +import org.elasticsearch.cli.ProcessInfo; +import org.elasticsearch.cli.UserException; +import org.elasticsearch.common.settings.Setting; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.util.CollectionUtils; +import org.elasticsearch.env.Environment; +import org.elasticsearch.env.TestEnvironment; +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.test.ESIntegTestCase; + +import java.util.Collection; +import java.util.List; +import java.util.Map; + +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.not; + +@ESIntegTestCase.ClusterScope(scope = ESIntegTestCase.Scope.TEST, numDataNodes = 0, autoManageMasterNodes = false) +public class RemoveIndexSettingsCommandIT extends ESIntegTestCase { + + static final Setting FOO = Setting.intSetting("index.foo", 1, Setting.Property.IndexScope, Setting.Property.Dynamic); + static final Setting BAR = Setting.intSetting("index.bar", 2, Setting.Property.IndexScope, Setting.Property.Final); + + public static class ExtraSettingsPlugin extends Plugin { + @Override + public List> getSettings() { + return List.of(FOO, BAR); + } + } + + @Override + protected Collection> nodePlugins() { + return CollectionUtils.appendToCopy(super.nodePlugins(), ExtraSettingsPlugin.class); + } + + public void testRemoveSettingsAbortedByUser() throws Exception { + internalCluster().setBootstrapMasterNodeIndex(0); + var node = internalCluster().startNode(); + createIndex("test-index", Settings.builder().put(FOO.getKey(), 101).put(BAR.getKey(), 102).build()); + ensureYellow("test-index"); + Settings dataPathSettings = internalCluster().dataPathSettings(node); + ensureStableCluster(1); + internalCluster().stopRandomDataNode(); + + Settings nodeSettings = Settings.builder().put(internalCluster().getDefaultSettings()).put(dataPathSettings).build(); + ElasticsearchException error = expectThrows( + ElasticsearchException.class, + () -> removeIndexSettings(TestEnvironment.newEnvironment(nodeSettings), true, "index.foo") + ); + assertThat(error.getMessage(), equalTo(ElasticsearchNodeCommand.ABORTED_BY_USER_MSG)); + internalCluster().startNode(nodeSettings); + } + + public void testRemoveSettingsSuccessful() throws Exception { + internalCluster().setBootstrapMasterNodeIndex(0); + var node = internalCluster().startNode(); + Settings dataPathSettings = internalCluster().dataPathSettings(node); + + int numIndices = randomIntBetween(1, 10); + int[] barValues = new int[numIndices]; + for (int i = 0; i < numIndices; i++) { + String index = "test-index-" + i; + barValues[i] = between(1, 1000); + createIndex(index, Settings.builder().put(FOO.getKey(), between(1, 1000)).put(BAR.getKey(), barValues[i]).build()); + } + int moreIndices = randomIntBetween(1, 10); + for (int i = 0; i < moreIndices; i++) { + createIndex("more-index-" + i, Settings.EMPTY); + } + internalCluster().stopNode(node); + + Environment environment = TestEnvironment.newEnvironment( + Settings.builder().put(internalCluster().getDefaultSettings()).put(dataPathSettings).build() + ); + + MockTerminal terminal = removeIndexSettings(environment, false, "index.foo"); + assertThat(terminal.getOutput(), containsString(RemoveIndexSettingsCommand.SETTINGS_REMOVED_MSG)); + for (int i = 0; i < numIndices; i++) { + assertThat(terminal.getOutput(), containsString("Index setting [index.foo] will be removed from index [[test-index-" + i)); + } + for (int i = 0; i < moreIndices; i++) { + assertThat(terminal.getOutput(), not(containsString("Index setting [index.foo] will be removed from index [[more-index-" + i))); + } + Settings nodeSettings = Settings.builder().put(internalCluster().getDefaultSettings()).put(dataPathSettings).build(); + internalCluster().startNode(nodeSettings); + + Map getIndexSettings = client().admin().indices().prepareGetSettings("test-index-*").get().getIndexToSettings(); + for (int i = 0; i < numIndices; i++) { + String index = "test-index-" + i; + Settings indexSettings = getIndexSettings.get(index); + assertFalse(indexSettings.hasValue("index.foo")); + assertThat(indexSettings.get("index.bar"), equalTo(Integer.toString(barValues[i]))); + } + getIndexSettings = client().admin().indices().prepareGetSettings("more-index-*").get().getIndexToSettings(); + for (int i = 0; i < moreIndices; i++) { + assertNotNull(getIndexSettings.get("more-index-" + i)); + } + } + + public void testSettingDoesNotMatch() throws Exception { + internalCluster().setBootstrapMasterNodeIndex(0); + var node = internalCluster().startNode(); + createIndex("test-index", Settings.builder().put(FOO.getKey(), 101).put(BAR.getKey(), 102).build()); + ensureYellow("test-index"); + Settings dataPathSettings = internalCluster().dataPathSettings(node); + ensureStableCluster(1); + internalCluster().stopRandomDataNode(); + + Settings nodeSettings = Settings.builder().put(internalCluster().getDefaultSettings()).put(dataPathSettings).build(); + UserException error = expectThrows( + UserException.class, + () -> removeIndexSettings(TestEnvironment.newEnvironment(nodeSettings), true, "index.not_foo") + ); + assertThat(error.getMessage(), containsString("No index setting matching [index.not_foo] were found on this node")); + internalCluster().startNode(nodeSettings); + } + + private MockTerminal executeCommand(ElasticsearchNodeCommand command, Environment environment, boolean abort, String... args) + throws Exception { + final MockTerminal terminal = MockTerminal.create(); + final OptionSet options = command.getParser().parse(args); + final ProcessInfo processInfo = new ProcessInfo(Map.of(), Map.of(), createTempDir()); + final String input; + + if (abort) { + input = randomValueOtherThanMany(c -> c.equalsIgnoreCase("y"), () -> randomAlphaOfLength(1)); + } else { + input = randomBoolean() ? "y" : "Y"; + } + + terminal.addTextInput(input); + + try { + command.execute(terminal, options, environment, processInfo); + } finally { + assertThat(terminal.getOutput(), containsString(ElasticsearchNodeCommand.STOP_WARNING_MSG)); + } + + return terminal; + } + + private MockTerminal removeIndexSettings(Environment environment, boolean abort, String... args) throws Exception { + final MockTerminal terminal = executeCommand(new RemoveIndexSettingsCommand(), environment, abort, args); + assertThat(terminal.getOutput(), containsString(RemoveIndexSettingsCommand.CONFIRMATION_MSG)); + assertThat(terminal.getOutput(), containsString(RemoveIndexSettingsCommand.SETTINGS_REMOVED_MSG)); + return terminal; + } +} diff --git a/server/src/main/java/org/elasticsearch/TransportVersions.java b/server/src/main/java/org/elasticsearch/TransportVersions.java index e5428497c136d..07579161a85c8 100644 --- a/server/src/main/java/org/elasticsearch/TransportVersions.java +++ b/server/src/main/java/org/elasticsearch/TransportVersions.java @@ -181,6 +181,8 @@ static TransportVersion def(int id) { public static final TransportVersion ML_INFERENCE_GOOGLE_AI_STUDIO_COMPLETION_ADDED = def(8_672_00_0); public static final TransportVersion WATCHER_REQUEST_TIMEOUTS = def(8_673_00_0); public static final TransportVersion ML_INFERENCE_ENHANCE_DELETE_ENDPOINT = def(8_674_00_0); + public static final TransportVersion ML_INFERENCE_GOOGLE_AI_STUDIO_EMBEDDINGS_ADDED = def(8_675_00_0); + public static final TransportVersion ADD_MISTRAL_EMBEDDINGS_INFERENCE = def(8_676_00_0); /* * STOP! READ THIS FIRST! No, really, diff --git a/server/src/main/java/org/elasticsearch/cluster/coordination/NodeToolCli.java b/server/src/main/java/org/elasticsearch/cluster/coordination/NodeToolCli.java index 58f37ec220669..81044e8e3ad51 100644 --- a/server/src/main/java/org/elasticsearch/cluster/coordination/NodeToolCli.java +++ b/server/src/main/java/org/elasticsearch/cluster/coordination/NodeToolCli.java @@ -20,6 +20,7 @@ class NodeToolCli extends MultiCommand { subcommands.put("detach-cluster", new DetachClusterCommand()); subcommands.put("override-version", new OverrideNodeVersionCommand()); subcommands.put("remove-settings", new RemoveSettingsCommand()); + subcommands.put("remove-index-settings", new RemoveIndexSettingsCommand()); subcommands.put("remove-customs", new RemoveCustomsCommand()); } } diff --git a/server/src/main/java/org/elasticsearch/cluster/coordination/RemoveIndexSettingsCommand.java b/server/src/main/java/org/elasticsearch/cluster/coordination/RemoveIndexSettingsCommand.java new file mode 100644 index 0000000000000..c6514f9cb4a0b --- /dev/null +++ b/server/src/main/java/org/elasticsearch/cluster/coordination/RemoveIndexSettingsCommand.java @@ -0,0 +1,100 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0 and the Server Side Public License, v 1; you may not use this file except + * in compliance with, at your election, the Elastic License 2.0 or the Server + * Side Public License, v 1. + */ +package org.elasticsearch.cluster.coordination; + +import joptsimple.OptionSet; +import joptsimple.OptionSpec; + +import org.elasticsearch.cli.ExitCodes; +import org.elasticsearch.cli.Terminal; +import org.elasticsearch.cli.UserException; +import org.elasticsearch.cluster.ClusterState; +import org.elasticsearch.cluster.metadata.IndexMetadata; +import org.elasticsearch.cluster.metadata.Metadata; +import org.elasticsearch.common.regex.Regex; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.Tuple; +import org.elasticsearch.env.Environment; +import org.elasticsearch.gateway.PersistedClusterStateService; + +import java.io.IOException; +import java.nio.file.Path; +import java.util.List; + +public class RemoveIndexSettingsCommand extends ElasticsearchNodeCommand { + + static final String SETTINGS_REMOVED_MSG = "Index settings were successfully removed from the cluster state"; + static final String CONFIRMATION_MSG = DELIMITER + + "\n" + + "You should only run this tool if you have incompatible index settings in the\n" + + "cluster state that prevent the cluster from forming.\n" + + "This tool can cause data loss and its use should be your last resort.\n" + + "\n" + + "Do you want to proceed?\n"; + + private final OptionSpec arguments; + + public RemoveIndexSettingsCommand() { + super("Removes index settings from the cluster state"); + arguments = parser.nonOptions("index setting names"); + } + + @Override + protected void processDataPaths(Terminal terminal, Path[] dataPaths, OptionSet options, Environment env) throws IOException, + UserException { + final List settingsToRemove = arguments.values(options); + if (settingsToRemove.isEmpty()) { + throw new UserException(ExitCodes.USAGE, "Must supply at least one index setting to remove"); + } + + final PersistedClusterStateService persistedClusterStateService = createPersistedClusterStateService(env.settings(), dataPaths); + + terminal.println(Terminal.Verbosity.VERBOSE, "Loading cluster state"); + final Tuple termAndClusterState = loadTermAndClusterState(persistedClusterStateService, env); + final ClusterState oldClusterState = termAndClusterState.v2(); + final Metadata.Builder newMetadataBuilder = Metadata.builder(oldClusterState.metadata()); + int changes = 0; + for (IndexMetadata indexMetadata : oldClusterState.metadata()) { + Settings oldSettings = indexMetadata.getSettings(); + Settings.Builder newSettings = Settings.builder().put(oldSettings); + boolean removed = false; + for (String settingToRemove : settingsToRemove) { + for (String settingKey : oldSettings.keySet()) { + if (Regex.simpleMatch(settingToRemove, settingKey)) { + terminal.println( + "Index setting [" + settingKey + "] will be removed from index [" + indexMetadata.getIndex() + "]" + ); + newSettings.remove(settingKey); + removed = true; + } + } + } + if (removed) { + newMetadataBuilder.put(IndexMetadata.builder(indexMetadata).settings(newSettings)); + changes++; + } + } + if (changes == 0) { + throw new UserException(ExitCodes.USAGE, "No index setting matching " + settingsToRemove + " were found on this node"); + } + + final ClusterState newClusterState = ClusterState.builder(oldClusterState).metadata(newMetadataBuilder).build(); + terminal.println( + Terminal.Verbosity.VERBOSE, + "[old cluster state = " + oldClusterState + ", new cluster state = " + newClusterState + "]" + ); + + confirm(terminal, CONFIRMATION_MSG); + + try (PersistedClusterStateService.Writer writer = persistedClusterStateService.createWriter()) { + writer.writeFullStateAndCommit(termAndClusterState.v1(), newClusterState); + } + + terminal.println(SETTINGS_REMOVED_MSG); + } +} diff --git a/server/src/main/java/org/elasticsearch/index/mapper/DocumentMapper.java b/server/src/main/java/org/elasticsearch/index/mapper/DocumentMapper.java index 1b07d93295fe1..0136175cc6391 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/DocumentMapper.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/DocumentMapper.java @@ -126,7 +126,13 @@ public void validate(IndexSettings settings, boolean checkLimits) { * Build an empty source loader to validate that the mapping is compatible * with the source loading strategy declared on the source field mapper. */ - sourceMapper().newSourceLoader(mapping(), mapperMetrics.sourceFieldMetrics()); + try { + sourceMapper().newSourceLoader(mapping(), mapperMetrics.sourceFieldMetrics()); + } catch (IllegalArgumentException e) { + mapperMetrics.sourceFieldMetrics().recordSyntheticSourceIncompatibleMapping(); + throw e; + } + if (settings.getIndexSortConfig().hasIndexSort() && mappers().nestedLookup() != NestedLookup.EMPTY) { throw new IllegalArgumentException("cannot have nested fields when index sort is activated"); } diff --git a/server/src/main/java/org/elasticsearch/index/mapper/SourceFieldMetrics.java b/server/src/main/java/org/elasticsearch/index/mapper/SourceFieldMetrics.java index 0e6ce79fd2170..eaccdbc9e44ce 100644 --- a/server/src/main/java/org/elasticsearch/index/mapper/SourceFieldMetrics.java +++ b/server/src/main/java/org/elasticsearch/index/mapper/SourceFieldMetrics.java @@ -9,6 +9,7 @@ package org.elasticsearch.index.mapper; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.telemetry.metric.LongCounter; import org.elasticsearch.telemetry.metric.LongHistogram; import org.elasticsearch.telemetry.metric.MeterRegistry; @@ -21,10 +22,12 @@ public class SourceFieldMetrics { public static final SourceFieldMetrics NOOP = new SourceFieldMetrics(MeterRegistry.NOOP, () -> 0); public static final String SYNTHETIC_SOURCE_LOAD_LATENCY = "es.mapper.synthetic_source.load.latency.histogram"; + public static final String SYNTHETIC_SOURCE_INCOMPATIBLE_MAPPING = "es.mapper.synthetic_source.incompatible_mapping.total"; private final LongSupplier relativeTimeSupplier; private final LongHistogram syntheticSourceLoadLatency; + private final LongCounter syntheticSourceIncompatibleMapping; public SourceFieldMetrics(MeterRegistry meterRegistry, LongSupplier relativeTimeSupplier) { this.syntheticSourceLoadLatency = meterRegistry.registerLongHistogram( @@ -32,6 +35,11 @@ public SourceFieldMetrics(MeterRegistry meterRegistry, LongSupplier relativeTime "Time it takes to load fields and construct synthetic source", "ms" ); + this.syntheticSourceIncompatibleMapping = meterRegistry.registerLongCounter( + SYNTHETIC_SOURCE_INCOMPATIBLE_MAPPING, + "Number of create/update index operations using mapping not compatible with synthetic source", + "count" + ); this.relativeTimeSupplier = relativeTimeSupplier; } @@ -42,4 +50,8 @@ public LongSupplier getRelativeTimeSupplier() { public void recordSyntheticSourceLoadLatency(TimeValue value) { this.syntheticSourceLoadLatency.record(value.millis()); } + + public void recordSyntheticSourceIncompatibleMapping() { + this.syntheticSourceIncompatibleMapping.increment(); + } } diff --git a/server/src/test/java/org/elasticsearch/index/mapper/SourceLoaderTelemetryTests.java b/server/src/test/java/org/elasticsearch/index/mapper/SourceFieldMetricsTests.java similarity index 70% rename from server/src/test/java/org/elasticsearch/index/mapper/SourceLoaderTelemetryTests.java rename to server/src/test/java/org/elasticsearch/index/mapper/SourceFieldMetricsTests.java index 1c88cbb0d8592..f569a69246d9f 100644 --- a/server/src/test/java/org/elasticsearch/index/mapper/SourceLoaderTelemetryTests.java +++ b/server/src/test/java/org/elasticsearch/index/mapper/SourceFieldMetricsTests.java @@ -20,7 +20,7 @@ import java.util.Collection; import java.util.List; -public class SourceLoaderTelemetryTests extends MapperServiceTestCase { +public class SourceFieldMetricsTests extends MapperServiceTestCase { private final TestTelemetryPlugin telemetryPlugin = new TestTelemetryPlugin(); @Override @@ -34,8 +34,8 @@ public void testFieldHasValue() {} @Override public void testFieldHasValueWithEmptyFieldInfos() {} - public void testSyntheticSourceTelemetry() throws IOException { - var mapping = syntheticSourceMapping(b -> { b.startObject("kwd").field("type", "keyword").endObject(); }); + public void testSyntheticSourceLoadLatency() throws IOException { + var mapping = syntheticSourceMapping(b -> b.startObject("kwd").field("type", "keyword").endObject()); var mapper = createDocumentMapper(mapping); try (Directory directory = newDirectory()) { @@ -58,4 +58,15 @@ public void testSyntheticSourceTelemetry() throws IOException { // test implementation of time provider always has a gap of 1 between values assertEquals(measurements.get(0).getLong(), 1); } + + public void testSyntheticSourceIncompatibleMapping() throws IOException { + var mapping = syntheticSourceMapping(b -> b.startObject("kwd").field("type", "text").field("store", "false").endObject()); + var mapperMetrics = createTestMapperMetrics(); + var mapperService = new TestMapperServiceBuilder().mapperMetrics(mapperMetrics).build(); + assertThrows(IllegalArgumentException.class, () -> withMapping(mapperService, mapping)); + + var measurements = telemetryPlugin.getLongCounterMeasurement(SourceFieldMetrics.SYNTHETIC_SOURCE_INCOMPATIBLE_MAPPING); + assertEquals(1, measurements.size()); + assertEquals(measurements.get(0).getLong(), 1); + } } diff --git a/test/framework/src/main/java/org/elasticsearch/index/mapper/MapperServiceTestCase.java b/test/framework/src/main/java/org/elasticsearch/index/mapper/MapperServiceTestCase.java index 1d08f28d47a88..388d8d6fa6ffd 100644 --- a/test/framework/src/main/java/org/elasticsearch/index/mapper/MapperServiceTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/index/mapper/MapperServiceTestCase.java @@ -61,6 +61,7 @@ import org.elasticsearch.plugins.TelemetryPlugin; import org.elasticsearch.plugins.internal.DocumentSizeObserver; import org.elasticsearch.script.Script; +import org.elasticsearch.script.ScriptCompiler; import org.elasticsearch.script.ScriptContext; import org.elasticsearch.search.aggregations.Aggregator; import org.elasticsearch.search.aggregations.support.AggregationContext; @@ -198,41 +199,85 @@ protected final MapperService createMapperService( BooleanSupplier idFieldDataEnabled, XContentBuilder mapping ) throws IOException { - MapperService mapperService = createMapperService(version, settings, idFieldDataEnabled); - merge(mapperService, mapping); - return mapperService; + return withMapping(mapperService, mapping); } protected final MapperService createMapperService(IndexVersion version, Settings settings, BooleanSupplier idFieldDataEnabled) { - IndexSettings indexSettings = createIndexSettings(version, settings); - MapperRegistry mapperRegistry = new IndicesModule( - getPlugins().stream().filter(p -> p instanceof MapperPlugin).map(p -> (MapperPlugin) p).collect(toList()) - ).getMapperRegistry(); + return new TestMapperServiceBuilder().indexVersion(version).settings(settings).idFieldDataEnabled(idFieldDataEnabled).build(); + } - SimilarityService similarityService = new SimilarityService(indexSettings, null, Map.of()); - BitsetFilterCache bitsetFilterCache = new BitsetFilterCache(indexSettings, new BitsetFilterCache.Listener() { - @Override - public void onCache(ShardId shardId, Accountable accountable) {} + protected final MapperService withMapping(MapperService mapperService, XContentBuilder mapping) throws IOException { + merge(mapperService, mapping); + return mapperService; + }; + + protected class TestMapperServiceBuilder { + private IndexVersion indexVersion; + private Settings settings; + private BooleanSupplier idFieldDataEnabled; + private ScriptCompiler scriptCompiler; + private MapperMetrics mapperMetrics; + + public TestMapperServiceBuilder() { + indexVersion = getVersion(); + settings = getIndexSettings(); + idFieldDataEnabled = () -> true; + scriptCompiler = MapperServiceTestCase.this::compileScript; + mapperMetrics = MapperMetrics.NOOP; + } - @Override - public void onRemoval(ShardId shardId, Accountable accountable) {} - }); - return new MapperService( - () -> TransportVersion.current(), - indexSettings, - createIndexAnalyzers(indexSettings), - parserConfig(), - similarityService, - mapperRegistry, - () -> { - throw new UnsupportedOperationException(); - }, - indexSettings.getMode().buildIdFieldMapper(idFieldDataEnabled), - this::compileScript, - bitsetFilterCache::getBitSetProducer, - MapperMetrics.NOOP - ); + public TestMapperServiceBuilder indexVersion(IndexVersion indexVersion) { + this.indexVersion = indexVersion; + return this; + } + + public TestMapperServiceBuilder settings(Settings settings) { + this.settings = settings; + return this; + } + + public TestMapperServiceBuilder idFieldDataEnabled(BooleanSupplier idFieldDataEnabled) { + this.idFieldDataEnabled = idFieldDataEnabled; + return this; + } + + public TestMapperServiceBuilder mapperMetrics(MapperMetrics mapperMetrics) { + this.mapperMetrics = mapperMetrics; + return this; + } + + public MapperService build() { + IndexSettings indexSettings = createIndexSettings(indexVersion, settings); + SimilarityService similarityService = new SimilarityService(indexSettings, null, Map.of()); + MapperRegistry mapperRegistry = new IndicesModule( + getPlugins().stream().filter(p -> p instanceof MapperPlugin).map(p -> (MapperPlugin) p).collect(toList()) + ).getMapperRegistry(); + + BitsetFilterCache bitsetFilterCache = new BitsetFilterCache(indexSettings, new BitsetFilterCache.Listener() { + @Override + public void onCache(ShardId shardId, Accountable accountable) {} + + @Override + public void onRemoval(ShardId shardId, Accountable accountable) {} + }); + + return new MapperService( + () -> TransportVersion.current(), + indexSettings, + createIndexAnalyzers(indexSettings), + parserConfig(), + similarityService, + mapperRegistry, + () -> { + throw new UnsupportedOperationException(); + }, + indexSettings.getMode().buildIdFieldMapper(idFieldDataEnabled), + scriptCompiler, + bitsetFilterCache::getBitSetProducer, + mapperMetrics + ); + } } /** diff --git a/x-pack/plugin/core/build.gradle b/x-pack/plugin/core/build.gradle index fb35b34fd4dfd..f0a00b7aa7e75 100644 --- a/x-pack/plugin/core/build.gradle +++ b/x-pack/plugin/core/build.gradle @@ -180,3 +180,16 @@ if (BuildParams.inFipsJvm) { // Test clusters run with security disabled tasks.named("javaRestTest").configure { enabled = false } } + +//this specific test requires a test only system property to be set, so we run it in a different JVM via a separate task +tasks.register('testAutomatonPatterns', Test) { + include '**/AutomatonPatternsTests.class' + systemProperty 'tests.automaton.record.patterns', 'true' + testClassesDirs = sourceSets.test.output.classesDirs + classpath = sourceSets.test.runtimeClasspath +} + +tasks.named('test').configure { + exclude '**/AutomatonPatternsTests.class' + dependsOn testAutomatonPatterns //to ensure testAutomatonPatterns are run with the test task +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/support/Automatons.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/support/Automatons.java index f601aa144aa00..a6347d8b7ec77 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/support/Automatons.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/security/support/Automatons.java @@ -23,12 +23,16 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.Collection; +import java.util.HashMap; import java.util.HashSet; import java.util.List; +import java.util.Locale; +import java.util.Map; import java.util.Set; import java.util.concurrent.ExecutionException; import java.util.function.Function; import java.util.function.Predicate; +import java.util.stream.Collectors; import static org.apache.lucene.util.automaton.Operations.DEFAULT_DETERMINIZE_WORK_LIMIT; import static org.apache.lucene.util.automaton.Operations.concatenate; @@ -69,6 +73,10 @@ public final class Automatons { static final char WILDCARD_CHAR = '?'; // Char equality with support for wildcards static final char WILDCARD_ESCAPE = '\\'; // Escape character + // for testing only -Dtests.jvm.argline="-Dtests.automaton.record.patterns=true" + public static boolean recordPatterns = System.getProperty("tests.automaton.record.patterns", "false").equals("true"); + private static final Map> patternsMap = new HashMap<>(); + private Automatons() {} /** @@ -87,10 +95,13 @@ public static Automaton patterns(Collection patterns) { return EMPTY; } if (cache == null) { - return buildAutomaton(patterns); + return maybeRecordPatterns(buildAutomaton(patterns), patterns); } else { try { - return cache.computeIfAbsent(Sets.newHashSet(patterns), p -> buildAutomaton((Set) p)); + return cache.computeIfAbsent( + Sets.newHashSet(patterns), + p -> maybeRecordPatterns(buildAutomaton((Set) p), patterns) + ); } catch (ExecutionException e) { throw unwrapCacheException(e); } @@ -338,4 +349,23 @@ public static void addSettings(List> settingsList) { settingsList.add(CACHE_SIZE); settingsList.add(CACHE_TTL); } + + private static Automaton maybeRecordPatterns(Automaton automaton, Collection patterns) { + if (recordPatterns) { + patternsMap.put( + automaton, + patterns.stream().map(String::trim).map(s -> s.toLowerCase(Locale.ROOT)).sorted().collect(Collectors.toList()) + ); + } + return automaton; + } + + // test only + static List getPatterns(Automaton automaton) { + if (recordPatterns) { + return patternsMap.get(automaton); + } else { + throw new IllegalArgumentException("recordPatterns is set to false"); + } + } } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/security/authz/privilege/IndexPrivilegeTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/security/authz/privilege/IndexPrivilegeTests.java index b05f7065ff63c..265714ee6ea16 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/security/authz/privilege/IndexPrivilegeTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/security/authz/privilege/IndexPrivilegeTests.java @@ -144,4 +144,5 @@ public void testCrossClusterReplicationPrivileges() { ); assertThat(Operations.subsetOf(crossClusterReplicationInternal.automaton, IndexPrivilege.get(Set.of("all")).automaton), is(true)); } + } diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/security/support/AutomatonPatternsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/security/support/AutomatonPatternsTests.java new file mode 100644 index 0000000000000..1539651b1aed6 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/security/support/AutomatonPatternsTests.java @@ -0,0 +1,81 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.core.security.support; + +import org.apache.lucene.util.automaton.Automaton; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.security.authz.privilege.IndexPrivilege; + +import java.util.Arrays; +import java.util.Locale; +import java.util.Set; + +import static org.elasticsearch.xpack.core.security.action.apikey.CrossClusterApiKeyRoleDescriptorBuilder.CCR_INDICES_PRIVILEGE_NAMES; +import static org.elasticsearch.xpack.core.security.action.apikey.CrossClusterApiKeyRoleDescriptorBuilder.CCS_INDICES_PRIVILEGE_NAMES; + +public class AutomatonPatternsTests extends ESTestCase { + + /** + * RCS 2.0 allows a single API key to define "replication" and "search" blocks. If both are defined, this results in an API key with 2 + * sets of indices permissions. Due to the way API keys (and roles) work across the multiple index permission, the set of index + * patterns allowed are effectively the most generous of the sets of index patterns since the index patterns are OR'ed together. For + * example, `foo` OR `*` results in access to `*`. So, if you have "search" access defined as `foo`, but replication access defined + * as `*`, the API key effectively allows access to index pattern `*`. This means that the access for API keys that define both + * "search" and "replication", the action names used are the primary means by which we can constrain CCS to the set of "search" indices + * as well as how we constrain CCR to the set "replication" indices. For example, if "replication" ever allowed access to + * `indices:data/read/get` for `*` , then the "replication" permissions would effectively enable users of CCS to get documents, + * even if "search" is never defined in the RCS 2.0 API key. This obviously is not desirable and in practice when both "search" and + * "replication" are defined the isolation between CCS and CCR is only achieved because the action names for the workflows do not + * overlap. This test helps to ensure that the actions names used for RCS 2.0 do not bleed over between search and replication. + */ + public void testRemoteClusterPrivsDoNotOverlap() { + + // check that the action patterns for remote CCS are not allowed by remote CCR privileges + Arrays.stream(CCS_INDICES_PRIVILEGE_NAMES).forEach(ccsPrivilege -> { + Automaton ccsAutomaton = IndexPrivilege.get(Set.of(ccsPrivilege)).getAutomaton(); + Automatons.getPatterns(ccsAutomaton).forEach(ccsPattern -> { + // emulate an action name that could be allowed by a CCS privilege + String actionName = ccsPattern.replaceAll("\\*", randomAlphaOfLengthBetween(1, 8)); + Arrays.stream(CCR_INDICES_PRIVILEGE_NAMES).forEach(ccrPrivileges -> { + String errorMessage = String.format( + Locale.ROOT, + "CCR privilege \"%s\" allows CCS action \"%s\". This could result in an " + + "accidental bleeding of permission between RCS 2.0's search and replication index permissions", + ccrPrivileges, + ccsPattern + ); + assertFalse(errorMessage, IndexPrivilege.get(Set.of(ccrPrivileges)).predicate().test(actionName)); + }); + }); + }); + + // check that the action patterns for remote CCR are not allowed by remote CCS privileges + Arrays.stream(CCR_INDICES_PRIVILEGE_NAMES).forEach(ccrPrivilege -> { + Automaton ccrAutomaton = IndexPrivilege.get(Set.of(ccrPrivilege)).getAutomaton(); + Automatons.getPatterns(ccrAutomaton).forEach(ccrPattern -> { + // emulate an action name that could be allowed by a CCR privilege + String actionName = ccrPattern.replaceAll("\\*", randomAlphaOfLengthBetween(1, 8)); + Arrays.stream(CCS_INDICES_PRIVILEGE_NAMES).forEach(ccsPrivileges -> { + if ("indices:data/read/xpack/ccr/shard_changes*".equals(ccrPattern)) { + // do nothing, this action is only applicable to CCR workflows and is a moot point if CCS technically has + // access to the index pattern for this action granted by CCR + } else { + String errorMessage = String.format( + Locale.ROOT, + "CCS privilege \"%s\" allows CCR action \"%s\". This could result in an accidental bleeding of " + + "permission between RCS 2.0's search and replication index permissions", + ccsPrivileges, + ccrPattern + ); + assertFalse(errorMessage, IndexPrivilege.get(Set.of(ccsPrivileges)).predicate().test(actionName)); + } + }); + }); + }); + } +} diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/Attribute.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/Attribute.java index 0736ad3d1e296..5326825ec1105 100644 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/Attribute.java +++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/Attribute.java @@ -6,6 +6,8 @@ */ package org.elasticsearch.xpack.esql.core.expression; +import org.elasticsearch.common.io.stream.NamedWriteable; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.core.Tuple; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; @@ -28,7 +30,11 @@ * is a named expression (an {@code Alias} will be created automatically for it). * The rest are not as they are not part of the projection and thus are not part of the derived table. */ -public abstract class Attribute extends NamedExpression { +public abstract class Attribute extends NamedExpression implements NamedWriteable { + public static List getNamedWriteables() { + // TODO add UnsupportedAttribute when these are moved to the same project + return List.of(FieldAttribute.ENTRY, MetadataAttribute.ENTRY, ReferenceAttribute.ENTRY); + } // empty - such as a top level attribute in SELECT cause // present - table name or a table name alias diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/EmptyAttribute.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/EmptyAttribute.java index 56e5c65b179f6..7a724eaa2be65 100644 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/EmptyAttribute.java +++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/EmptyAttribute.java @@ -7,22 +7,34 @@ package org.elasticsearch.xpack.esql.core.expression; +import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.xpack.esql.core.tree.NodeInfo; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; import org.elasticsearch.xpack.esql.core.type.DataTypes; import org.elasticsearch.xpack.esql.core.util.StringUtils; +import java.io.IOException; + /** * Marker for optional attributes. Acting as a dummy placeholder to avoid using null * in the tree (which is not allowed). */ public class EmptyAttribute extends Attribute { - public EmptyAttribute(Source source) { super(source, StringUtils.EMPTY, null, null); } + @Override + public void writeTo(StreamOutput out) throws IOException { + throw new UnsupportedOperationException("doesn't escape the node"); + } + + @Override + public String getWriteableName() { + throw new UnsupportedOperationException("doesn't escape the node"); + } + @Override protected Attribute clone( Source source, diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/FieldAttribute.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/FieldAttribute.java index 9ee9732b542a0..35fe402035f69 100644 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/FieldAttribute.java +++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/FieldAttribute.java @@ -7,12 +7,17 @@ package org.elasticsearch.xpack.esql.core.expression; import org.elasticsearch.common.Strings; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.xpack.esql.core.tree.NodeInfo; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; import org.elasticsearch.xpack.esql.core.type.EsField; +import org.elasticsearch.xpack.esql.core.util.PlanStreamInput; import org.elasticsearch.xpack.esql.core.util.StringUtils; +import java.io.IOException; import java.util.Objects; /** @@ -24,6 +29,11 @@ * - nestedParent - if nested, what's the parent (which might not be the immediate one) */ public class FieldAttribute extends TypedAttribute { + static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry( + Attribute.class, + "FieldAttribute", + FieldAttribute::new + ); private final FieldAttribute parent; private final String path; @@ -67,6 +77,47 @@ public FieldAttribute( this.field = field; } + @SuppressWarnings("unchecked") + public FieldAttribute(StreamInput in) throws IOException { + /* + * The funny casting dance with `` and `(S) in` is required + * because we're in esql-core here and the real PlanStreamInput is in + * esql-proper. And because NamedWriteableRegistry.Entry needs StreamInput, + * not a PlanStreamInput. And we need PlanStreamInput to handle Source + * and NameId. This should become a hard cast when we move everything out + * of esql-core. + */ + this( + Source.readFrom((S) in), + in.readOptionalWriteable(FieldAttribute::new), + in.readString(), + DataType.readFrom(in), + in.readNamedWriteable(EsField.class), + in.readOptionalString(), + in.readEnum(Nullability.class), + NameId.readFrom((S) in), + in.readBoolean() + ); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + Source.EMPTY.writeTo(out); + out.writeOptionalWriteable(parent); + out.writeString(name()); + dataType().writeTo(out); + out.writeNamedWriteable(field); + out.writeOptionalString(qualifier()); + out.writeEnum(nullable()); + id().writeTo(out); + out.writeBoolean(synthetic()); + } + + @Override + public String getWriteableName() { + return ENTRY.name; + } + @Override protected NodeInfo info() { return NodeInfo.create(this, FieldAttribute::new, parent, name(), dataType(), field, qualifier(), nullable(), id(), synthetic()); diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/MetadataAttribute.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/MetadataAttribute.java index e6777a9ab4bb4..9cbee26f443ba 100644 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/MetadataAttribute.java +++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/MetadataAttribute.java @@ -7,6 +7,9 @@ package org.elasticsearch.xpack.esql.core.expression; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.core.Tuple; import org.elasticsearch.index.mapper.IdFieldMapper; import org.elasticsearch.index.mapper.IgnoredFieldMapper; @@ -15,12 +18,20 @@ import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; import org.elasticsearch.xpack.esql.core.type.DataTypes; +import org.elasticsearch.xpack.esql.core.util.PlanStreamInput; +import java.io.IOException; import java.util.Map; +import java.util.Objects; import static org.elasticsearch.core.Tuple.tuple; public class MetadataAttribute extends TypedAttribute { + static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry( + Attribute.class, + "MetadataAttribute", + MetadataAttribute::new + ); private static final Map> ATTRIBUTES_MAP = Map.of( "_version", @@ -55,6 +66,45 @@ public MetadataAttribute(Source source, String name, DataType dataType, boolean this(source, name, dataType, null, Nullability.TRUE, null, false, searchable); } + @SuppressWarnings("unchecked") + public MetadataAttribute(StreamInput in) throws IOException { + /* + * The funny casting dance with `` and `(S) in` is required + * because we're in esql-core here and the real PlanStreamInput is in + * esql-proper. And because NamedWriteableRegistry.Entry needs StreamInput, + * not a PlanStreamInput. And we need PlanStreamInput to handle Source + * and NameId. This should become a hard cast when we move everything out + * of esql-core. + */ + this( + Source.readFrom((S) in), + in.readString(), + DataType.readFrom(in), + in.readOptionalString(), + in.readEnum(Nullability.class), + NameId.readFrom((S) in), + in.readBoolean(), + in.readBoolean() + ); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + Source.EMPTY.writeTo(out); + out.writeString(name()); + dataType().writeTo(out); + out.writeOptionalString(qualifier()); + out.writeEnum(nullable()); + id().writeTo(out); + out.writeBoolean(synthetic()); + out.writeBoolean(searchable); + } + + @Override + public String getWriteableName() { + return ENTRY.name; + } + @Override protected MetadataAttribute clone( Source source, @@ -99,4 +149,18 @@ public static DataType dataType(String name) { public static boolean isSupported(String name) { return ATTRIBUTES_MAP.containsKey(name); } + + @Override + public boolean equals(Object obj) { + if (false == super.equals(obj)) { + return false; + } + MetadataAttribute other = (MetadataAttribute) obj; + return searchable == other.searchable; + } + + @Override + public int hashCode() { + return Objects.hash(super.hashCode(), searchable); + } } diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/NameId.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/NameId.java index 2aa70397075d0..d2d01857a1f73 100644 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/NameId.java +++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/NameId.java @@ -55,6 +55,11 @@ public String toString() { } public static NameId readFrom(S in) throws IOException { + /* + * The funny typing dance with `` is required we're in esql-core + * here and the real PlanStreamInput is in esql-proper. And we need PlanStreamInput + * to properly map NameIds. + */ long unmappedId = in.readLong(); return in.mapNameId(unmappedId); } diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/ReferenceAttribute.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/ReferenceAttribute.java index d9311dfa27edd..8bac20e9347bc 100644 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/ReferenceAttribute.java +++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/ReferenceAttribute.java @@ -6,14 +6,25 @@ */ package org.elasticsearch.xpack.esql.core.expression; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.xpack.esql.core.tree.NodeInfo; import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.core.util.PlanStreamInput; + +import java.io.IOException; /** * Attribute based on a reference to an expression. */ public class ReferenceAttribute extends TypedAttribute { + static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry( + Attribute.class, + "ReferenceAttribute", + ReferenceAttribute::new + ); public ReferenceAttribute(Source source, String name, DataType dataType) { this(source, name, dataType, null, Nullability.FALSE, null, false); @@ -31,6 +42,43 @@ public ReferenceAttribute( super(source, name, dataType, qualifier, nullability, id, synthetic); } + @SuppressWarnings("unchecked") + public ReferenceAttribute(StreamInput in) throws IOException { + /* + * The funny casting dance with `` and `(S) in` is required + * because we're in esql-core here and the real PlanStreamInput is in + * esql-proper. And because NamedWriteableRegistry.Entry needs StreamInput, + * not a PlanStreamInput. And we need PlanStreamInput to handle Source + * and NameId. This should become a hard cast when we move everything out + * of esql-core. + */ + this( + Source.readFrom((S) in), + in.readString(), + DataType.readFrom(in), + in.readOptionalString(), + in.readEnum(Nullability.class), + NameId.readFrom((S) in), + in.readBoolean() + ); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + Source.EMPTY.writeTo(out); + out.writeString(name()); + dataType().writeTo(out); + out.writeOptionalString(qualifier()); + out.writeEnum(nullable()); + id().writeTo(out); + out.writeBoolean(synthetic()); + } + + @Override + public String getWriteableName() { + return ENTRY.name; + } + @Override protected Attribute clone( Source source, diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/UnresolvedAttribute.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/UnresolvedAttribute.java index 923a72d311166..87ef37cb84d1f 100644 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/UnresolvedAttribute.java +++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/UnresolvedAttribute.java @@ -6,6 +6,7 @@ */ package org.elasticsearch.xpack.esql.core.expression; +import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.xpack.esql.core.capabilities.Unresolvable; import org.elasticsearch.xpack.esql.core.capabilities.UnresolvedException; import org.elasticsearch.xpack.esql.core.tree.NodeInfo; @@ -13,12 +14,12 @@ import org.elasticsearch.xpack.esql.core.type.DataType; import org.elasticsearch.xpack.esql.core.util.CollectionUtils; +import java.io.IOException; import java.util.List; import java.util.Objects; // unfortunately we can't use UnresolvedNamedExpression public class UnresolvedAttribute extends Attribute implements Unresolvable { - private final String unresolvedMsg; private final boolean customMessage; private final Object resolutionMetadata; @@ -50,6 +51,16 @@ public UnresolvedAttribute( this.resolutionMetadata = resolutionMetadata; } + @Override + public void writeTo(StreamOutput out) throws IOException { + throw new UnsupportedOperationException("doesn't escape the node"); + } + + @Override + public String getWriteableName() { + throw new UnsupportedOperationException("doesn't escape the node"); + } + @Override protected NodeInfo info() { return NodeInfo.create(this, UnresolvedAttribute::new, name(), qualifier(), id(), unresolvedMsg, resolutionMetadata); diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/optimizer/OptimizerRules.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/optimizer/OptimizerRules.java index 4ce5f4ae8652e..12b496e51fa1b 100644 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/optimizer/OptimizerRules.java +++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/optimizer/OptimizerRules.java @@ -29,10 +29,6 @@ import org.elasticsearch.xpack.esql.core.expression.predicate.logical.Or; import org.elasticsearch.xpack.esql.core.expression.predicate.nulls.IsNotNull; import org.elasticsearch.xpack.esql.core.expression.predicate.nulls.IsNull; -import org.elasticsearch.xpack.esql.core.expression.predicate.operator.arithmetic.ArithmeticOperation; -import org.elasticsearch.xpack.esql.core.expression.predicate.operator.arithmetic.BinaryComparisonInversible; -import org.elasticsearch.xpack.esql.core.expression.predicate.operator.arithmetic.Neg; -import org.elasticsearch.xpack.esql.core.expression.predicate.operator.arithmetic.Sub; import org.elasticsearch.xpack.esql.core.expression.predicate.operator.comparison.BinaryComparison; import org.elasticsearch.xpack.esql.core.expression.predicate.operator.comparison.Equals; import org.elasticsearch.xpack.esql.core.expression.predicate.operator.comparison.GreaterThan; @@ -49,12 +45,10 @@ import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.core.plan.logical.OrderBy; import org.elasticsearch.xpack.esql.core.rule.Rule; -import org.elasticsearch.xpack.esql.core.type.DataType; import org.elasticsearch.xpack.esql.core.type.DataTypes; import org.elasticsearch.xpack.esql.core.util.CollectionUtils; import org.elasticsearch.xpack.esql.core.util.ReflectionUtils; -import java.time.DateTimeException; import java.time.ZoneId; import java.util.ArrayList; import java.util.Iterator; @@ -66,8 +60,6 @@ import java.util.Set; import java.util.function.BiFunction; -import static java.lang.Math.signum; -import static java.util.Arrays.asList; import static java.util.Collections.emptySet; import static org.elasticsearch.xpack.esql.core.expression.Literal.FALSE; import static org.elasticsearch.xpack.esql.core.expression.Literal.TRUE; @@ -77,12 +69,6 @@ import static org.elasticsearch.xpack.esql.core.expression.predicate.Predicates.splitAnd; import static org.elasticsearch.xpack.esql.core.expression.predicate.Predicates.splitOr; import static org.elasticsearch.xpack.esql.core.expression.predicate.Predicates.subtract; -import static org.elasticsearch.xpack.esql.core.expression.predicate.operator.arithmetic.DefaultBinaryArithmeticOperation.ADD; -import static org.elasticsearch.xpack.esql.core.expression.predicate.operator.arithmetic.DefaultBinaryArithmeticOperation.DIV; -import static org.elasticsearch.xpack.esql.core.expression.predicate.operator.arithmetic.DefaultBinaryArithmeticOperation.MOD; -import static org.elasticsearch.xpack.esql.core.expression.predicate.operator.arithmetic.DefaultBinaryArithmeticOperation.MUL; -import static org.elasticsearch.xpack.esql.core.expression.predicate.operator.arithmetic.DefaultBinaryArithmeticOperation.SUB; -import static org.elasticsearch.xpack.esql.core.tree.Source.EMPTY; import static org.elasticsearch.xpack.esql.core.util.CollectionUtils.combine; public final class OptimizerRules { @@ -1282,216 +1268,6 @@ protected Expression rule(Expression e) { } } - // Simplifies arithmetic expressions with BinaryComparisons and fixed point fields, such as: (int + 2) / 3 > 4 => int > 10 - public static final class SimplifyComparisonsArithmetics extends OptimizerExpressionRule { - BiFunction typesCompatible; - - public SimplifyComparisonsArithmetics(BiFunction typesCompatible) { - super(TransformDirection.UP); - this.typesCompatible = typesCompatible; - } - - @Override - protected Expression rule(BinaryComparison bc) { - // optimize only once the expression has a literal on the right side of the binary comparison - if (bc.right() instanceof Literal) { - if (bc.left() instanceof ArithmeticOperation) { - return simplifyBinaryComparison(bc); - } - if (bc.left() instanceof Neg) { - return foldNegation(bc); - } - } - return bc; - } - - private Expression simplifyBinaryComparison(BinaryComparison comparison) { - ArithmeticOperation operation = (ArithmeticOperation) comparison.left(); - // Use symbol comp: SQL operations aren't available in this package (as dependencies) - String opSymbol = operation.symbol(); - // Modulo can't be simplified. - if (opSymbol.equals(MOD.symbol())) { - return comparison; - } - OperationSimplifier simplification = null; - if (isMulOrDiv(opSymbol)) { - simplification = new MulDivSimplifier(comparison); - } else if (opSymbol.equals(ADD.symbol()) || opSymbol.equals(SUB.symbol())) { - simplification = new AddSubSimplifier(comparison); - } - - return (simplification == null || simplification.isUnsafe(typesCompatible)) ? comparison : simplification.apply(); - } - - private static boolean isMulOrDiv(String opSymbol) { - return opSymbol.equals(MUL.symbol()) || opSymbol.equals(DIV.symbol()); - } - - private static Expression foldNegation(BinaryComparison bc) { - Literal bcLiteral = (Literal) bc.right(); - Expression literalNeg = tryFolding(new Neg(bcLiteral.source(), bcLiteral)); - return literalNeg == null ? bc : bc.reverse().replaceChildren(asList(((Neg) bc.left()).field(), literalNeg)); - } - - private static Expression tryFolding(Expression expression) { - if (expression.foldable()) { - try { - expression = new Literal(expression.source(), expression.fold(), expression.dataType()); - } catch (ArithmeticException | DateTimeException e) { - // null signals that folding would result in an over-/underflow (such as Long.MAX_VALUE+1); the optimisation is skipped. - expression = null; - } - } - return expression; - } - - private abstract static class OperationSimplifier { - final BinaryComparison comparison; - final Literal bcLiteral; - final ArithmeticOperation operation; - final Expression opLeft; - final Expression opRight; - final Literal opLiteral; - - OperationSimplifier(BinaryComparison comparison) { - this.comparison = comparison; - operation = (ArithmeticOperation) comparison.left(); - bcLiteral = (Literal) comparison.right(); - - opLeft = operation.left(); - opRight = operation.right(); - - if (opLeft instanceof Literal) { - opLiteral = (Literal) opLeft; - } else if (opRight instanceof Literal) { - opLiteral = (Literal) opRight; - } else { - opLiteral = null; - } - } - - // can it be quickly fast-tracked that the operation can't be reduced? - final boolean isUnsafe(BiFunction typesCompatible) { - if (opLiteral == null) { - // one of the arithm. operands must be a literal, otherwise the operation wouldn't simplify anything - return true; - } - - // Only operations on fixed point literals are supported, since optimizing float point operations can also change the - // outcome of the filtering: - // x + 1e18 > 1e18::long will yield different results with a field value in [-2^6, 2^6], optimised vs original; - // x * (1 + 1e-15d) > 1 : same with a field value of (1 - 1e-15d) - // so consequently, int fields optimisation requiring FP arithmetic isn't possible either: (x - 1e-15) * (1 + 1e-15) > 1. - if (opLiteral.dataType().isRational() || bcLiteral.dataType().isRational()) { - return true; - } - - // the Literal will be moved to the right of the comparison, but only if data-compatible with what's there - if (typesCompatible.apply(bcLiteral.dataType(), opLiteral.dataType()) == false) { - return true; - } - - return isOpUnsafe(); - } - - final Expression apply() { - // force float point folding for FlP field - Literal bcl = operation.dataType().isRational() - ? Literal.of(bcLiteral, ((Number) bcLiteral.value()).doubleValue()) - : bcLiteral; - - Expression bcRightExpression = ((BinaryComparisonInversible) operation).binaryComparisonInverse() - .create(bcl.source(), bcl, opRight); - bcRightExpression = tryFolding(bcRightExpression); - return bcRightExpression != null - ? postProcess((BinaryComparison) comparison.replaceChildren(List.of(opLeft, bcRightExpression))) - : comparison; - } - - // operation-specific operations: - // - fast-tracking of simplification unsafety - abstract boolean isOpUnsafe(); - - // - post optimisation adjustments - Expression postProcess(BinaryComparison binaryComparison) { - return binaryComparison; - } - } - - private static class AddSubSimplifier extends OperationSimplifier { - - AddSubSimplifier(BinaryComparison comparison) { - super(comparison); - } - - @Override - boolean isOpUnsafe() { - // no ADD/SUB with floating fields - if (operation.dataType().isRational()) { - return true; - } - - if (operation.symbol().equals(SUB.symbol()) && opRight instanceof Literal == false) { // such as: 1 - x > -MAX - // if next simplification step would fail on overflow anyways, skip the optimisation already - return tryFolding(new Sub(EMPTY, opLeft, bcLiteral)) == null; - } - - return false; - } - } - - private static class MulDivSimplifier extends OperationSimplifier { - - private final boolean isDiv; // and not MUL. - private final int opRightSign; // sign of the right operand in: (left) (op) (right) (comp) (literal) - - MulDivSimplifier(BinaryComparison comparison) { - super(comparison); - isDiv = operation.symbol().equals(DIV.symbol()); - opRightSign = sign(opRight); - } - - @Override - boolean isOpUnsafe() { - // Integer divisions are not safe to optimise: x / 5 > 1 <=/=> x > 5 for x in [6, 9]; same for the `==` comp - if (operation.dataType().isInteger() && isDiv) { - return true; - } - - // If current operation is a multiplication, it's inverse will be a division: safe only if outcome is still integral. - if (isDiv == false && opLeft.dataType().isInteger()) { - long opLiteralValue = ((Number) opLiteral.value()).longValue(); - return opLiteralValue == 0 || ((Number) bcLiteral.value()).longValue() % opLiteralValue != 0; - } - - // can't move a 0 in Mul/Div comparisons - return opRightSign == 0; - } - - @Override - Expression postProcess(BinaryComparison binaryComparison) { - // negative multiplication/division changes the direction of the comparison - return opRightSign < 0 ? binaryComparison.reverse() : binaryComparison; - } - - private static int sign(Object obj) { - int sign = 1; - if (obj instanceof Number) { - sign = (int) signum(((Number) obj).doubleValue()); - } else if (obj instanceof Literal) { - sign = sign(((Literal) obj).value()); - } else if (obj instanceof Neg) { - sign = -sign(((Neg) obj).field()); - } else if (obj instanceof ArithmeticOperation operation) { - if (isMulOrDiv(operation.symbol())) { - sign = sign(operation.left()) * sign(operation.right()); - } - } - return sign; - } - } - } - public abstract static class PruneFilters extends OptimizerRule { @Override diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/tree/Source.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/tree/Source.java index 2129bd8743d0f..e53593e944632 100644 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/tree/Source.java +++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/tree/Source.java @@ -34,6 +34,11 @@ public Source(Location location, String text) { } public static Source readFrom(S in) throws IOException { + /* + * The funny typing dance with `` is required we're in esql-core + * here and the real PlanStreamInput is in esql-proper. And we need PlanStreamInput + * to send the query one time. + */ if (in.readBoolean() == false) { return EMPTY; } diff --git a/x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/type/AbstractEsFieldTypeTests.java b/x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/type/AbstractEsFieldTypeTests.java index 372bd7ec8066a..a415c529894c3 100644 --- a/x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/type/AbstractEsFieldTypeTests.java +++ b/x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/type/AbstractEsFieldTypeTests.java @@ -15,10 +15,22 @@ import java.util.TreeMap; public abstract class AbstractEsFieldTypeTests extends AbstractNamedWriteableTestCase { + public static EsField randomAnyEsField(int maxDepth) { + return switch (between(0, 5)) { + case 0 -> EsFieldTests.randomEsField(maxDepth); + case 1 -> DateEsFieldTests.randomDateEsField(maxDepth); + case 2 -> InvalidMappedFieldTests.randomInvalidMappedField(maxDepth); + case 3 -> KeywordEsFieldTests.randomKeywordEsField(maxDepth); + case 4 -> TextEsFieldTests.randomTextEsField(maxDepth); + case 5 -> UnsupportedEsFieldTests.randomUnsupportedEsField(maxDepth); + default -> throw new IllegalArgumentException(); + }; + } + @Override protected abstract T createTestInstance(); - protected abstract T mutate(T instance) throws IOException; + protected abstract T mutate(T instance); /** * Generate sub-properties. @@ -34,15 +46,7 @@ static Map randomProperties(int maxDepth) { int targetSize = between(1, 5); Map properties = new TreeMap<>(); while (properties.size() < targetSize) { - properties.put(randomAlphaOfLength(properties.size() + 1), switch (between(0, 5)) { - case 0 -> EsFieldTests.randomEsField(maxDepth - 1); - case 1 -> DateEsFieldTests.randomDateEsField(maxDepth - 1); - case 2 -> InvalidMappedFieldTests.randomInvalidMappedField(maxDepth - 1); - case 3 -> KeywordEsFieldTests.randomKeywordEsField(maxDepth - 1); - case 4 -> TextEsFieldTests.randomTextEsField(maxDepth - 1); - case 5 -> UnsupportedEsFieldTests.randomUnsupportedEsField(maxDepth - 1); - default -> throw new IllegalArgumentException(); - }); + properties.put(randomAlphaOfLength(properties.size() + 1), randomAnyEsField(maxDepth - 1)); } return properties; } diff --git a/x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/type/DateEsFieldTests.java b/x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/type/DateEsFieldTests.java index c6428034eaae1..dea03ee8a8cdf 100644 --- a/x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/type/DateEsFieldTests.java +++ b/x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/type/DateEsFieldTests.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.esql.core.type; -import java.io.IOException; import java.util.Map; public class DateEsFieldTests extends AbstractEsFieldTypeTests { @@ -21,7 +20,7 @@ protected DateEsField createTestInstance() { } @Override - protected DateEsField mutate(DateEsField instance) throws IOException { + protected DateEsField mutate(DateEsField instance) { String name = instance.getName(); Map properties = instance.getProperties(); boolean aggregatable = instance.isAggregatable(); diff --git a/x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/type/EsFieldTests.java b/x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/type/EsFieldTests.java index 2d75def3296d3..75921778d5970 100644 --- a/x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/type/EsFieldTests.java +++ b/x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/type/EsFieldTests.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.esql.core.type; -import java.io.IOException; import java.util.Map; public class EsFieldTests extends AbstractEsFieldTypeTests { @@ -26,7 +25,7 @@ protected EsField createTestInstance() { } @Override - protected EsField mutate(EsField instance) throws IOException { + protected EsField mutate(EsField instance) { String name = instance.getName(); DataType esDataType = instance.getDataType(); Map properties = instance.getProperties(); diff --git a/x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/type/InvalidMappedFieldTests.java b/x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/type/InvalidMappedFieldTests.java index 5f96b4a720381..47a99329d0222 100644 --- a/x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/type/InvalidMappedFieldTests.java +++ b/x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/type/InvalidMappedFieldTests.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.esql.core.type; -import java.io.IOException; import java.util.Map; public class InvalidMappedFieldTests extends AbstractEsFieldTypeTests { @@ -24,7 +23,7 @@ protected InvalidMappedField createTestInstance() { } @Override - protected InvalidMappedField mutate(InvalidMappedField instance) throws IOException { + protected InvalidMappedField mutate(InvalidMappedField instance) { String name = instance.getName(); String errorMessage = instance.errorMessage(); Map properties = instance.getProperties(); diff --git a/x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/type/KeywordEsFieldTests.java b/x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/type/KeywordEsFieldTests.java index 0b1c866114737..a5d3b8329b2df 100644 --- a/x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/type/KeywordEsFieldTests.java +++ b/x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/type/KeywordEsFieldTests.java @@ -9,7 +9,6 @@ import org.elasticsearch.test.ESTestCase; -import java.io.IOException; import java.util.Map; public class KeywordEsFieldTests extends AbstractEsFieldTypeTests { @@ -29,7 +28,7 @@ protected KeywordEsField createTestInstance() { } @Override - protected KeywordEsField mutate(KeywordEsField instance) throws IOException { + protected KeywordEsField mutate(KeywordEsField instance) { String name = instance.getName(); Map properties = instance.getProperties(); boolean hasDocValues = instance.isAggregatable(); diff --git a/x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/type/TextEsFieldTests.java b/x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/type/TextEsFieldTests.java index df00bf27c32b9..817dd7cd27094 100644 --- a/x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/type/TextEsFieldTests.java +++ b/x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/type/TextEsFieldTests.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.esql.core.type; -import java.io.IOException; import java.util.Map; public class TextEsFieldTests extends AbstractEsFieldTypeTests { @@ -25,7 +24,7 @@ protected TextEsField createTestInstance() { } @Override - protected TextEsField mutate(TextEsField instance) throws IOException { + protected TextEsField mutate(TextEsField instance) { String name = instance.getName(); Map properties = instance.getProperties(); boolean hasDocValues = instance.isAggregatable(); diff --git a/x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/type/UnsupportedEsFieldTests.java b/x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/type/UnsupportedEsFieldTests.java index fb1ca014a0ee0..e05d8ca10425e 100644 --- a/x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/type/UnsupportedEsFieldTests.java +++ b/x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/type/UnsupportedEsFieldTests.java @@ -7,11 +7,10 @@ package org.elasticsearch.xpack.esql.core.type; -import java.io.IOException; import java.util.Map; public class UnsupportedEsFieldTests extends AbstractEsFieldTypeTests { - static UnsupportedEsField randomUnsupportedEsField(int maxPropertiesDepth) { + public static UnsupportedEsField randomUnsupportedEsField(int maxPropertiesDepth) { String name = randomAlphaOfLength(4); String originalType = randomAlphaOfLength(5); String inherited = randomBoolean() ? null : randomAlphaOfLength(5); @@ -25,7 +24,7 @@ protected UnsupportedEsField createTestInstance() { } @Override - protected UnsupportedEsField mutate(UnsupportedEsField instance) throws IOException { + protected UnsupportedEsField mutate(UnsupportedEsField instance) { String name = instance.getName(); String originalType = randomAlphaOfLength(5); String inherited = randomBoolean() ? null : randomAlphaOfLength(5); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/UnsupportedAttribute.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/UnsupportedAttribute.java index 9c37dab90db54..fe6db916f7a0d 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/UnsupportedAttribute.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/function/UnsupportedAttribute.java @@ -7,6 +7,9 @@ package org.elasticsearch.xpack.esql.expression.function; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; import org.elasticsearch.xpack.esql.core.capabilities.Unresolvable; import org.elasticsearch.xpack.esql.core.expression.Attribute; import org.elasticsearch.xpack.esql.core.expression.FieldAttribute; @@ -16,7 +19,9 @@ import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; import org.elasticsearch.xpack.esql.core.type.UnsupportedEsField; +import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; +import java.io.IOException; import java.util.Objects; /** @@ -26,9 +31,14 @@ * As such the field is marked as unresolved (so the verifier can pick up its usage outside project). */ public final class UnsupportedAttribute extends FieldAttribute implements Unresolvable { + public static final NamedWriteableRegistry.Entry ENTRY = new NamedWriteableRegistry.Entry( + Attribute.class, + "UnsupportedAttribute", + UnsupportedAttribute::new + ); private final String message; - private final boolean hasCustomMessage; + private final boolean hasCustomMessage; // TODO remove me and just use message != null? private static String errorMessage(String name, UnsupportedEsField field) { return "Cannot use field [" + name + "] with unsupported type [" + field.getOriginalType() + "]"; @@ -48,6 +58,30 @@ public UnsupportedAttribute(Source source, String name, UnsupportedEsField field this.message = customMessage == null ? errorMessage(qualifiedName(), field) : customMessage; } + public UnsupportedAttribute(StreamInput in) throws IOException { + this( + Source.readFrom((PlanStreamInput) in), + in.readString(), + new UnsupportedEsField(in), + in.readOptionalString(), + NameId.readFrom((PlanStreamInput) in) + ); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + Source.EMPTY.writeTo(out); + out.writeString(name()); + field().writeTo(out); + out.writeOptionalString(hasCustomMessage ? message : null); + id().writeTo(out); + } + + @Override + public String getWriteableName() { + return ENTRY.name; + } + @Override public boolean resolved() { return false; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/io/stream/PlanNamedTypes.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/io/stream/PlanNamedTypes.java index c0859782937bc..f605f898366e1 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/io/stream/PlanNamedTypes.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/io/stream/PlanNamedTypes.java @@ -29,7 +29,6 @@ import org.elasticsearch.xpack.esql.core.expression.MetadataAttribute; import org.elasticsearch.xpack.esql.core.expression.NameId; import org.elasticsearch.xpack.esql.core.expression.NamedExpression; -import org.elasticsearch.xpack.esql.core.expression.Nullability; import org.elasticsearch.xpack.esql.core.expression.Order; import org.elasticsearch.xpack.esql.core.expression.ReferenceAttribute; import org.elasticsearch.xpack.esql.core.expression.function.scalar.ScalarFunction; @@ -51,7 +50,6 @@ import org.elasticsearch.xpack.esql.core.tree.Source; import org.elasticsearch.xpack.esql.core.type.DataType; import org.elasticsearch.xpack.esql.core.type.EsField; -import org.elasticsearch.xpack.esql.core.type.UnsupportedEsField; import org.elasticsearch.xpack.esql.expression.function.UnsupportedAttribute; import org.elasticsearch.xpack.esql.expression.function.aggregate.AggregateFunction; import org.elasticsearch.xpack.esql.expression.function.aggregate.Avg; @@ -299,10 +297,10 @@ public static List namedTypeEntries() { of(LogicalPlan.class, Project.class, PlanNamedTypes::writeProject, PlanNamedTypes::readProject), of(LogicalPlan.class, TopN.class, PlanNamedTypes::writeTopN, PlanNamedTypes::readTopN), // Attributes - of(Attribute.class, FieldAttribute.class, PlanNamedTypes::writeFieldAttribute, PlanNamedTypes::readFieldAttribute), - of(Attribute.class, ReferenceAttribute.class, PlanNamedTypes::writeReferenceAttr, PlanNamedTypes::readReferenceAttr), - of(Attribute.class, MetadataAttribute.class, PlanNamedTypes::writeMetadataAttr, PlanNamedTypes::readMetadataAttr), - of(Attribute.class, UnsupportedAttribute.class, PlanNamedTypes::writeUnsupportedAttr, PlanNamedTypes::readUnsupportedAttr), + of(NamedExpression.class, FieldAttribute.class, (o, a) -> a.writeTo(o), FieldAttribute::new), + of(NamedExpression.class, ReferenceAttribute.class, (o, a) -> a.writeTo(o), ReferenceAttribute::new), + of(NamedExpression.class, MetadataAttribute.class, (o, a) -> a.writeTo(o), MetadataAttribute::new), + of(NamedExpression.class, UnsupportedAttribute.class, (o, a) -> a.writeTo(o), UnsupportedAttribute::new), // NamedExpressions of(NamedExpression.class, Alias.class, PlanNamedTypes::writeAlias, PlanNamedTypes::readAlias), // BinaryComparison @@ -473,7 +471,7 @@ static DissectExec readDissectExec(PlanStreamInput in) throws IOException { in.readPhysicalPlanNode(), in.readExpression(), readDissectParser(in), - readAttributes(in) + in.readNamedWriteableCollectionAsList(Attribute.class) ); } @@ -482,7 +480,7 @@ static void writeDissectExec(PlanStreamOutput out, DissectExec dissectExec) thro out.writePhysicalPlanNode(dissectExec.child()); out.writeExpression(dissectExec.inputExpression()); writeDissectParser(out, dissectExec.parser()); - writeAttributes(out, dissectExec.extractedFields()); + out.writeNamedWriteableCollection(dissectExec.extractedFields()); } static EsQueryExec readEsQueryExec(PlanStreamInput in) throws IOException { @@ -490,7 +488,7 @@ static EsQueryExec readEsQueryExec(PlanStreamInput in) throws IOException { Source.readFrom(in), readEsIndex(in), readIndexMode(in), - readAttributes(in), + in.readNamedWriteableCollectionAsList(Attribute.class), in.readOptionalNamedWriteable(QueryBuilder.class), in.readOptionalNamed(Expression.class), in.readOptionalCollectionAsList(readerFromPlanReader(PlanNamedTypes::readFieldSort)), @@ -503,7 +501,7 @@ static void writeEsQueryExec(PlanStreamOutput out, EsQueryExec esQueryExec) thro Source.EMPTY.writeTo(out); writeEsIndex(out, esQueryExec.index()); writeIndexMode(out, esQueryExec.indexMode()); - writeAttributes(out, esQueryExec.output()); + out.writeNamedWriteableCollection(esQueryExec.output()); out.writeOptionalNamedWriteable(esQueryExec.query()); out.writeOptionalExpression(esQueryExec.limit()); out.writeOptionalCollection(esQueryExec.sorts(), writerFromPlanWriter(PlanNamedTypes::writeFieldSort)); @@ -514,7 +512,7 @@ static EsSourceExec readEsSourceExec(PlanStreamInput in) throws IOException { return new EsSourceExec( Source.readFrom(in), readEsIndex(in), - readAttributes(in), + in.readNamedWriteableCollectionAsList(Attribute.class), in.readOptionalNamedWriteable(QueryBuilder.class), readIndexMode(in) ); @@ -523,7 +521,7 @@ static EsSourceExec readEsSourceExec(PlanStreamInput in) throws IOException { static void writeEsSourceExec(PlanStreamOutput out, EsSourceExec esSourceExec) throws IOException { Source.EMPTY.writeTo(out); writeEsIndex(out, esSourceExec.index()); - writeAttributes(out, esSourceExec.output()); + out.writeNamedWriteableCollection(esSourceExec.output()); out.writeOptionalNamedWriteable(esSourceExec.query()); writeIndexMode(out, esSourceExec.indexMode()); } @@ -613,44 +611,54 @@ static void writeEnrichExec(PlanStreamOutput out, EnrichExec enrich) throws IOEx } static ExchangeExec readExchangeExec(PlanStreamInput in) throws IOException { - return new ExchangeExec(Source.readFrom(in), readAttributes(in), in.readBoolean(), in.readPhysicalPlanNode()); + return new ExchangeExec( + Source.readFrom(in), + in.readNamedWriteableCollectionAsList(Attribute.class), + in.readBoolean(), + in.readPhysicalPlanNode() + ); } static void writeExchangeExec(PlanStreamOutput out, ExchangeExec exchangeExec) throws IOException { Source.EMPTY.writeTo(out); - writeAttributes(out, exchangeExec.output()); + out.writeNamedWriteableCollection(exchangeExec.output()); out.writeBoolean(exchangeExec.isInBetweenAggs()); out.writePhysicalPlanNode(exchangeExec.child()); } static ExchangeSinkExec readExchangeSinkExec(PlanStreamInput in) throws IOException { - return new ExchangeSinkExec(Source.readFrom(in), readAttributes(in), in.readBoolean(), in.readPhysicalPlanNode()); + return new ExchangeSinkExec( + Source.readFrom(in), + in.readNamedWriteableCollectionAsList(Attribute.class), + in.readBoolean(), + in.readPhysicalPlanNode() + ); } static void writeExchangeSinkExec(PlanStreamOutput out, ExchangeSinkExec exchangeSinkExec) throws IOException { Source.EMPTY.writeTo(out); - writeAttributes(out, exchangeSinkExec.output()); + out.writeNamedWriteableCollection(exchangeSinkExec.output()); out.writeBoolean(exchangeSinkExec.isIntermediateAgg()); out.writePhysicalPlanNode(exchangeSinkExec.child()); } static ExchangeSourceExec readExchangeSourceExec(PlanStreamInput in) throws IOException { - return new ExchangeSourceExec(Source.readFrom(in), readAttributes(in), in.readBoolean()); + return new ExchangeSourceExec(Source.readFrom(in), in.readNamedWriteableCollectionAsList(Attribute.class), in.readBoolean()); } static void writeExchangeSourceExec(PlanStreamOutput out, ExchangeSourceExec exchangeSourceExec) throws IOException { - writeAttributes(out, exchangeSourceExec.output()); + out.writeNamedWriteableCollection(exchangeSourceExec.output()); out.writeBoolean(exchangeSourceExec.isIntermediateAgg()); } static FieldExtractExec readFieldExtractExec(PlanStreamInput in) throws IOException { - return new FieldExtractExec(Source.readFrom(in), in.readPhysicalPlanNode(), readAttributes(in)); + return new FieldExtractExec(Source.readFrom(in), in.readPhysicalPlanNode(), in.readNamedWriteableCollectionAsList(Attribute.class)); } static void writeFieldExtractExec(PlanStreamOutput out, FieldExtractExec fieldExtractExec) throws IOException { Source.EMPTY.writeTo(out); out.writePhysicalPlanNode(fieldExtractExec.child()); - writeAttributes(out, fieldExtractExec.attributesToExtract()); + out.writeNamedWriteableCollection(fieldExtractExec.attributesToExtract()); } static FilterExec readFilterExec(PlanStreamInput in) throws IOException { @@ -690,7 +698,7 @@ static GrokExec readGrokExec(PlanStreamInput in) throws IOException { in.readPhysicalPlanNode(), in.readExpression(), Grok.pattern(source, in.readString()), - readAttributes(in) + in.readNamedWriteableCollectionAsList(Attribute.class) ); } @@ -699,7 +707,7 @@ static void writeGrokExec(PlanStreamOutput out, GrokExec grokExec) throws IOExce out.writePhysicalPlanNode(grokExec.child()); out.writeExpression(grokExec.inputExpression()); out.writeString(grokExec.pattern().pattern()); - writeAttributes(out, grokExec.extractedFields()); + out.writeNamedWriteableCollection(grokExec.extractedFields()); } static LimitExec readLimitExec(PlanStreamInput in) throws IOException { @@ -713,14 +721,19 @@ static void writeLimitExec(PlanStreamOutput out, LimitExec limitExec) throws IOE } static MvExpandExec readMvExpandExec(PlanStreamInput in) throws IOException { - return new MvExpandExec(Source.readFrom(in), in.readPhysicalPlanNode(), in.readNamedExpression(), in.readAttribute()); + return new MvExpandExec( + Source.readFrom(in), + in.readPhysicalPlanNode(), + in.readNamedExpression(), + in.readNamedWriteable(Attribute.class) + ); } static void writeMvExpandExec(PlanStreamOutput out, MvExpandExec mvExpandExec) throws IOException { Source.EMPTY.writeTo(out); out.writePhysicalPlanNode(mvExpandExec.child()); out.writeNamedExpression(mvExpandExec.target()); - out.writeAttribute(mvExpandExec.expanded()); + out.writeNamedWriteable(mvExpandExec.expanded()); } static OrderExec readOrderExec(PlanStreamInput in) throws IOException { @@ -759,12 +772,16 @@ static void writeRowExec(PlanStreamOutput out, RowExec rowExec) throws IOExcepti @SuppressWarnings("unchecked") static ShowExec readShowExec(PlanStreamInput in) throws IOException { - return new ShowExec(Source.readFrom(in), readAttributes(in), (List>) in.readGenericValue()); + return new ShowExec( + Source.readFrom(in), + in.readNamedWriteableCollectionAsList(Attribute.class), + (List>) in.readGenericValue() + ); } static void writeShowExec(PlanStreamOutput out, ShowExec showExec) throws IOException { Source.EMPTY.writeTo(out); - writeAttributes(out, showExec.output()); + out.writeNamedWriteableCollection(showExec.output()); out.writeGenericValue(showExec.values()); } @@ -804,7 +821,13 @@ static void writeAggregate(PlanStreamOutput out, Aggregate aggregate) throws IOE } static Dissect readDissect(PlanStreamInput in) throws IOException { - return new Dissect(Source.readFrom(in), in.readLogicalPlanNode(), in.readExpression(), readDissectParser(in), readAttributes(in)); + return new Dissect( + Source.readFrom(in), + in.readLogicalPlanNode(), + in.readExpression(), + readDissectParser(in), + in.readNamedWriteableCollectionAsList(Attribute.class) + ); } static void writeDissect(PlanStreamOutput out, Dissect dissect) throws IOException { @@ -812,13 +835,13 @@ static void writeDissect(PlanStreamOutput out, Dissect dissect) throws IOExcepti out.writeLogicalPlanNode(dissect.child()); out.writeExpression(dissect.input()); writeDissectParser(out, dissect.parser()); - writeAttributes(out, dissect.extractedFields()); + out.writeNamedWriteableCollection(dissect.extractedFields()); } static EsRelation readEsRelation(PlanStreamInput in) throws IOException { Source source = Source.readFrom(in); EsIndex esIndex = readEsIndex(in); - List attributes = readAttributes(in); + List attributes = in.readNamedWriteableCollectionAsList(Attribute.class); if (supportingEsSourceOptions(in.getTransportVersion())) { readEsSourceOptions(in); // consume optional strings sent by remote } @@ -831,7 +854,7 @@ static void writeEsRelation(PlanStreamOutput out, EsRelation relation) throws IO assert relation.children().size() == 0; Source.EMPTY.writeTo(out); writeEsIndex(out, relation.index()); - writeAttributes(out, relation.output()); + out.writeNamedWriteableCollection(relation.output()); if (supportingEsSourceOptions(out.getTransportVersion())) { writeEsSourceOptions(out); // write (null) string fillers expected by remote } @@ -953,7 +976,7 @@ static Grok readGrok(PlanStreamInput in) throws IOException { in.readLogicalPlanNode(), in.readExpression(), Grok.pattern(source, in.readString()), - readAttributes(in) + in.readNamedWriteableCollectionAsList(Attribute.class) ); } @@ -962,7 +985,7 @@ static void writeGrok(PlanStreamOutput out, Grok grok) throws IOException { out.writeLogicalPlanNode(grok.child()); out.writeExpression(grok.input()); out.writeString(grok.parser().pattern()); - writeAttributes(out, grok.extractedFields()); + out.writeNamedWriteableCollection(grok.extractedFields()); } static Limit readLimit(PlanStreamInput in) throws IOException { @@ -976,14 +999,19 @@ static void writeLimit(PlanStreamOutput out, Limit limit) throws IOException { } static MvExpand readMvExpand(PlanStreamInput in) throws IOException { - return new MvExpand(Source.readFrom(in), in.readLogicalPlanNode(), in.readNamedExpression(), in.readAttribute()); + return new MvExpand( + Source.readFrom(in), + in.readLogicalPlanNode(), + in.readNamedExpression(), + in.readNamedWriteable(Attribute.class) + ); } static void writeMvExpand(PlanStreamOutput out, MvExpand mvExpand) throws IOException { Source.EMPTY.writeTo(out); out.writeLogicalPlanNode(mvExpand.child()); out.writeNamedExpression(mvExpand.target()); - out.writeAttribute(mvExpand.expanded()); + out.writeNamedWriteable(mvExpand.expanded()); } static OrderBy readOrderBy(PlanStreamInput in) throws IOException { @@ -1030,14 +1058,6 @@ static void writeTopN(PlanStreamOutput out, TopN topN) throws IOException { // -- Attributes // - private static List readAttributes(PlanStreamInput in) throws IOException { - return in.readCollectionAsList(readerFromPlanReader(PlanStreamInput::readAttribute)); - } - - static void writeAttributes(PlanStreamOutput out, List attributes) throws IOException { - out.writeCollection(attributes, writerFromPlanWriter(PlanStreamOutput::writeAttribute)); - } - private static List readNamedExpressions(PlanStreamInput in) throws IOException { return in.readCollectionAsList(readerFromPlanReader(PlanStreamInput::readNamedExpression)); } @@ -1054,96 +1074,6 @@ static void writeAliases(PlanStreamOutput out, List aliases) throws IOExc out.writeCollection(aliases, writerFromPlanWriter(PlanNamedTypes::writeAlias)); } - static FieldAttribute readFieldAttribute(PlanStreamInput in) throws IOException { - return new FieldAttribute( - Source.readFrom(in), - in.readOptionalWithReader(PlanNamedTypes::readFieldAttribute), - in.readString(), - DataType.readFrom(in), - in.readNamedWriteable(EsField.class), - in.readOptionalString(), - in.readEnum(Nullability.class), - NameId.readFrom(in), - in.readBoolean() - ); - } - - static void writeFieldAttribute(PlanStreamOutput out, FieldAttribute fieldAttribute) throws IOException { - Source.EMPTY.writeTo(out); - out.writeOptionalWriteable(fieldAttribute.parent() == null ? null : o -> writeFieldAttribute(out, fieldAttribute.parent())); - out.writeString(fieldAttribute.name()); - out.writeString(fieldAttribute.dataType().typeName()); - out.writeNamedWriteable(fieldAttribute.field()); - out.writeOptionalString(fieldAttribute.qualifier()); - out.writeEnum(fieldAttribute.nullable()); - fieldAttribute.id().writeTo(out); - out.writeBoolean(fieldAttribute.synthetic()); - } - - static ReferenceAttribute readReferenceAttr(PlanStreamInput in) throws IOException { - return new ReferenceAttribute( - Source.readFrom(in), - in.readString(), - DataType.readFrom(in), - in.readOptionalString(), - in.readEnum(Nullability.class), - NameId.readFrom(in), - in.readBoolean() - ); - } - - static void writeReferenceAttr(PlanStreamOutput out, ReferenceAttribute referenceAttribute) throws IOException { - Source.EMPTY.writeTo(out); - out.writeString(referenceAttribute.name()); - out.writeString(referenceAttribute.dataType().typeName()); - out.writeOptionalString(referenceAttribute.qualifier()); - out.writeEnum(referenceAttribute.nullable()); - referenceAttribute.id().writeTo(out); - out.writeBoolean(referenceAttribute.synthetic()); - } - - static MetadataAttribute readMetadataAttr(PlanStreamInput in) throws IOException { - return new MetadataAttribute( - Source.readFrom(in), - in.readString(), - DataType.readFrom(in), - in.readOptionalString(), - in.readEnum(Nullability.class), - NameId.readFrom(in), - in.readBoolean(), - in.readBoolean() - ); - } - - static void writeMetadataAttr(PlanStreamOutput out, MetadataAttribute metadataAttribute) throws IOException { - Source.EMPTY.writeTo(out); - out.writeString(metadataAttribute.name()); - out.writeString(metadataAttribute.dataType().typeName()); - out.writeOptionalString(metadataAttribute.qualifier()); - out.writeEnum(metadataAttribute.nullable()); - metadataAttribute.id().writeTo(out); - out.writeBoolean(metadataAttribute.synthetic()); - out.writeBoolean(metadataAttribute.searchable()); - } - - static UnsupportedAttribute readUnsupportedAttr(PlanStreamInput in) throws IOException { - return new UnsupportedAttribute( - Source.readFrom(in), - in.readString(), - new UnsupportedEsField(in), - in.readOptionalString(), - NameId.readFrom(in) - ); - } - - static void writeUnsupportedAttr(PlanStreamOutput out, UnsupportedAttribute unsupportedAttribute) throws IOException { - Source.EMPTY.writeTo(out); - out.writeString(unsupportedAttribute.name()); - unsupportedAttribute.field().writeTo(out); - out.writeOptionalString(unsupportedAttribute.hasCustomMessage() ? unsupportedAttribute.unresolvedMessage() : null); - unsupportedAttribute.id().writeTo(out); - } - // -- BinaryComparison static EsqlBinaryComparison readBinComparison(PlanStreamInput in, String name) throws IOException { @@ -1840,14 +1770,14 @@ static void writeOrder(PlanStreamOutput out, Order order) throws IOException { static EsQueryExec.FieldSort readFieldSort(PlanStreamInput in) throws IOException { return new EsQueryExec.FieldSort( - readFieldAttribute(in), + new FieldAttribute(in), in.readEnum(Order.OrderDirection.class), in.readEnum(Order.NullsPosition.class) ); } static void writeFieldSort(PlanStreamOutput out, EsQueryExec.FieldSort fieldSort) throws IOException { - writeFieldAttribute(out, fieldSort.field()); + fieldSort.field().writeTo(out); out.writeEnum(fieldSort.direction()); out.writeEnum(fieldSort.nulls()); } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/io/stream/PlanStreamInput.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/io/stream/PlanStreamInput.java index 73a7ba6549dc8..e7f1fbd6e1460 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/io/stream/PlanStreamInput.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/io/stream/PlanStreamInput.java @@ -12,7 +12,6 @@ import org.elasticsearch.common.io.stream.NamedWriteableAwareStreamInput; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.StreamInput; -import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.util.BigArrays; import org.elasticsearch.compute.data.Block; import org.elasticsearch.compute.data.BlockFactory; @@ -24,8 +23,6 @@ import org.elasticsearch.compute.data.LongBigArrayBlock; import org.elasticsearch.core.Releasables; import org.elasticsearch.xpack.esql.Column; -import org.elasticsearch.xpack.esql.core.expression.Attribute; -import org.elasticsearch.xpack.esql.core.expression.AttributeSet; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.NameId; import org.elasticsearch.xpack.esql.core.expression.NamedExpression; @@ -36,9 +33,7 @@ import org.elasticsearch.xpack.esql.session.EsqlConfiguration; import java.io.IOException; -import java.util.Collection; import java.util.HashMap; -import java.util.HashSet; import java.util.Map; import java.util.function.LongFunction; @@ -106,10 +101,6 @@ public NamedExpression readNamedExpression() throws IOException { return readNamed(NamedExpression.class); } - public Attribute readAttribute() throws IOException { - return readNamed(Attribute.class); - } - public T readNamed(Class type) throws IOException { String name = readString(); @SuppressWarnings("unchecked") @@ -145,18 +136,6 @@ public T readOptionalWithReader(PlanReader reader) throws IOException { } } - public AttributeSet readAttributeSet(Writeable.Reader reader) throws IOException { - int count = readArraySize(); - if (count == 0) { - return new AttributeSet(); - } - Collection builder = new HashSet<>(); - for (int i = 0; i < count; i++) { - builder.add(reader.read(this)); - } - return new AttributeSet(builder); - } - public EsqlConfiguration configuration() throws IOException { return configuration; } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/io/stream/PlanStreamOutput.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/io/stream/PlanStreamOutput.java index 351918699aac4..05dc7ab919868 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/io/stream/PlanStreamOutput.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/io/stream/PlanStreamOutput.java @@ -19,7 +19,6 @@ import org.elasticsearch.compute.data.LongBigArrayBlock; import org.elasticsearch.core.Nullable; import org.elasticsearch.xpack.esql.Column; -import org.elasticsearch.xpack.esql.core.expression.Attribute; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.expression.NamedExpression; import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan; @@ -103,10 +102,6 @@ public void writeNamedExpression(NamedExpression namedExpression) throws IOExcep writeNamed(NamedExpression.class, namedExpression); } - public void writeAttribute(Attribute attribute) throws IOException { - writeNamed(Attribute.class, attribute); - } - public void writeOptionalExpression(Expression expression) throws IOException { if (expression == null) { writeBoolean(false); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizer.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizer.java index d97d54eb884f3..7beb3aca05e74 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizer.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizer.java @@ -38,7 +38,6 @@ import org.elasticsearch.xpack.esql.core.optimizer.OptimizerRules.LiteralsOnTheRight; import org.elasticsearch.xpack.esql.core.optimizer.OptimizerRules.PruneLiteralsInOrderBy; import org.elasticsearch.xpack.esql.core.optimizer.OptimizerRules.SetAsOptimized; -import org.elasticsearch.xpack.esql.core.optimizer.OptimizerRules.SimplifyComparisonsArithmetics; import org.elasticsearch.xpack.esql.core.plan.logical.Filter; import org.elasticsearch.xpack.esql.core.plan.logical.Limit; import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan; @@ -62,6 +61,7 @@ import org.elasticsearch.xpack.esql.expression.function.scalar.spatial.SpatialRelatesFunction; import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.Equals; import org.elasticsearch.xpack.esql.expression.predicate.operator.comparison.In; +import org.elasticsearch.xpack.esql.optimizer.rules.SimplifyComparisonsArithmetics; import org.elasticsearch.xpack.esql.plan.logical.Aggregate; import org.elasticsearch.xpack.esql.plan.logical.Enrich; import org.elasticsearch.xpack.esql.plan.logical.EsRelation; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/SimplifyComparisonsArithmetics.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/SimplifyComparisonsArithmetics.java new file mode 100644 index 0000000000000..9a7ee0a587335 --- /dev/null +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/SimplifyComparisonsArithmetics.java @@ -0,0 +1,244 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.optimizer.rules; + +import org.elasticsearch.xpack.esql.core.expression.Expression; +import org.elasticsearch.xpack.esql.core.expression.Literal; +import org.elasticsearch.xpack.esql.core.expression.predicate.operator.arithmetic.ArithmeticOperation; +import org.elasticsearch.xpack.esql.core.expression.predicate.operator.arithmetic.BinaryComparisonInversible; +import org.elasticsearch.xpack.esql.core.expression.predicate.operator.comparison.BinaryComparison; +import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.core.type.DataTypes; +import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Neg; +import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Sub; + +import java.time.DateTimeException; +import java.util.List; +import java.util.function.BiFunction; + +import static java.lang.Math.signum; +import static java.util.Arrays.asList; +import static org.elasticsearch.xpack.esql.core.expression.predicate.operator.arithmetic.DefaultBinaryArithmeticOperation.ADD; +import static org.elasticsearch.xpack.esql.core.expression.predicate.operator.arithmetic.DefaultBinaryArithmeticOperation.DIV; +import static org.elasticsearch.xpack.esql.core.expression.predicate.operator.arithmetic.DefaultBinaryArithmeticOperation.MOD; +import static org.elasticsearch.xpack.esql.core.expression.predicate.operator.arithmetic.DefaultBinaryArithmeticOperation.MUL; +import static org.elasticsearch.xpack.esql.core.expression.predicate.operator.arithmetic.DefaultBinaryArithmeticOperation.SUB; +import static org.elasticsearch.xpack.esql.core.tree.Source.EMPTY; + +/** + * Simplifies arithmetic expressions with BinaryComparisons and fixed point fields, such as: (int + 2) / 3 > 4 => int > 10 + */ +public final class SimplifyComparisonsArithmetics extends + org.elasticsearch.xpack.esql.core.optimizer.OptimizerRules.OptimizerExpressionRule { + BiFunction typesCompatible; + + public SimplifyComparisonsArithmetics(BiFunction typesCompatible) { + super(org.elasticsearch.xpack.esql.core.optimizer.OptimizerRules.TransformDirection.UP); + this.typesCompatible = typesCompatible; + } + + @Override + protected Expression rule(BinaryComparison bc) { + // optimize only once the expression has a literal on the right side of the binary comparison + if (bc.right() instanceof Literal) { + if (bc.left() instanceof ArithmeticOperation) { + return simplifyBinaryComparison(bc); + } + if (bc.left() instanceof Neg) { + return foldNegation(bc); + } + } + return bc; + } + + private Expression simplifyBinaryComparison(BinaryComparison comparison) { + ArithmeticOperation operation = (ArithmeticOperation) comparison.left(); + // Use symbol comp: SQL operations aren't available in this package (as dependencies) + String opSymbol = operation.symbol(); + // Modulo can't be simplified. + if (opSymbol.equals(MOD.symbol())) { + return comparison; + } + OperationSimplifier simplification = null; + if (isMulOrDiv(opSymbol)) { + simplification = new MulDivSimplifier(comparison); + } else if (opSymbol.equals(ADD.symbol()) || opSymbol.equals(SUB.symbol())) { + simplification = new AddSubSimplifier(comparison); + } + + return (simplification == null || simplification.isUnsafe(typesCompatible)) ? comparison : simplification.apply(); + } + + private static boolean isMulOrDiv(String opSymbol) { + return opSymbol.equals(MUL.symbol()) || opSymbol.equals(DIV.symbol()); + } + + private static Expression foldNegation(BinaryComparison bc) { + Literal bcLiteral = (Literal) bc.right(); + Expression literalNeg = tryFolding(new Neg(bcLiteral.source(), bcLiteral)); + return literalNeg == null ? bc : bc.reverse().replaceChildren(asList(((Neg) bc.left()).field(), literalNeg)); + } + + private static Expression tryFolding(Expression expression) { + if (expression.foldable()) { + try { + expression = new Literal(expression.source(), expression.fold(), expression.dataType()); + } catch (ArithmeticException | DateTimeException e) { + // null signals that folding would result in an over-/underflow (such as Long.MAX_VALUE+1); the optimisation is skipped. + expression = null; + } + } + return expression; + } + + private abstract static class OperationSimplifier { + final BinaryComparison comparison; + final Literal bcLiteral; + final ArithmeticOperation operation; + final Expression opLeft; + final Expression opRight; + final Literal opLiteral; + + OperationSimplifier(BinaryComparison comparison) { + this.comparison = comparison; + operation = (ArithmeticOperation) comparison.left(); + bcLiteral = (Literal) comparison.right(); + + opLeft = operation.left(); + opRight = operation.right(); + + if (opLeft instanceof Literal) { + opLiteral = (Literal) opLeft; + } else if (opRight instanceof Literal) { + opLiteral = (Literal) opRight; + } else { + opLiteral = null; + } + } + + // can it be quickly fast-tracked that the operation can't be reduced? + final boolean isUnsafe(BiFunction typesCompatible) { + if (opLiteral == null) { + // one of the arithm. operands must be a literal, otherwise the operation wouldn't simplify anything + return true; + } + + // Only operations on fixed point literals are supported, since optimizing float point operations can also change the + // outcome of the filtering: + // x + 1e18 > 1e18::long will yield different results with a field value in [-2^6, 2^6], optimised vs original; + // x * (1 + 1e-15d) > 1 : same with a field value of (1 - 1e-15d) + // so consequently, int fields optimisation requiring FP arithmetic isn't possible either: (x - 1e-15) * (1 + 1e-15) > 1. + if (opLiteral.dataType().isRational() || bcLiteral.dataType().isRational()) { + return true; + } + + // the Literal will be moved to the right of the comparison, but only if data-compatible with what's there + if (typesCompatible.apply(bcLiteral.dataType(), opLiteral.dataType()) == false) { + return true; + } + + return isOpUnsafe(); + } + + final Expression apply() { + // force float point folding for FlP field + Literal bcl = operation.dataType().isRational() + ? new Literal(bcLiteral.source(), ((Number) bcLiteral.value()).doubleValue(), DataTypes.DOUBLE) + : bcLiteral; + + Expression bcRightExpression = ((BinaryComparisonInversible) operation).binaryComparisonInverse() + .create(bcl.source(), bcl, opRight); + bcRightExpression = tryFolding(bcRightExpression); + return bcRightExpression != null + ? postProcess((BinaryComparison) comparison.replaceChildren(List.of(opLeft, bcRightExpression))) + : comparison; + } + + // operation-specific operations: + // - fast-tracking of simplification unsafety + abstract boolean isOpUnsafe(); + + // - post optimisation adjustments + Expression postProcess(BinaryComparison binaryComparison) { + return binaryComparison; + } + } + + private static class AddSubSimplifier extends OperationSimplifier { + + AddSubSimplifier(BinaryComparison comparison) { + super(comparison); + } + + @Override + boolean isOpUnsafe() { + // no ADD/SUB with floating fields + if (operation.dataType().isRational()) { + return true; + } + + if (operation.symbol().equals(SUB.symbol()) && opRight instanceof Literal == false) { // such as: 1 - x > -MAX + // if next simplification step would fail on overflow anyways, skip the optimisation already + return tryFolding(new Sub(EMPTY, opLeft, bcLiteral)) == null; + } + + return false; + } + } + + private static class MulDivSimplifier extends OperationSimplifier { + + private final boolean isDiv; // and not MUL. + private final int opRightSign; // sign of the right operand in: (left) (op) (right) (comp) (literal) + + MulDivSimplifier(BinaryComparison comparison) { + super(comparison); + isDiv = operation.symbol().equals(DIV.symbol()); + opRightSign = sign(opRight); + } + + @Override + boolean isOpUnsafe() { + // Integer divisions are not safe to optimise: x / 5 > 1 <=/=> x > 5 for x in [6, 9]; same for the `==` comp + if (operation.dataType().isInteger() && isDiv) { + return true; + } + + // If current operation is a multiplication, it's inverse will be a division: safe only if outcome is still integral. + if (isDiv == false && opLeft.dataType().isInteger()) { + long opLiteralValue = ((Number) opLiteral.value()).longValue(); + return opLiteralValue == 0 || ((Number) bcLiteral.value()).longValue() % opLiteralValue != 0; + } + + // can't move a 0 in Mul/Div comparisons + return opRightSign == 0; + } + + @Override + Expression postProcess(BinaryComparison binaryComparison) { + // negative multiplication/division changes the direction of the comparison + return opRightSign < 0 ? binaryComparison.reverse() : binaryComparison; + } + + private static int sign(Object obj) { + int sign = 1; + if (obj instanceof Number) { + sign = (int) signum(((Number) obj).doubleValue()); + } else if (obj instanceof Literal) { + sign = sign(((Literal) obj).value()); + } else if (obj instanceof Neg) { + sign = -sign(((Neg) obj).field()); + } else if (obj instanceof ArithmeticOperation operation) { + if (isMulOrDiv(operation.symbol())) { + sign = sign(operation.left()) * sign(operation.right()); + } + } + return sign; + } + } +} diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/EsqlPlugin.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/EsqlPlugin.java index f3b06f5629524..6059b61031d1e 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/EsqlPlugin.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/EsqlPlugin.java @@ -55,10 +55,12 @@ import org.elasticsearch.xpack.esql.action.RestEsqlDeleteAsyncResultAction; import org.elasticsearch.xpack.esql.action.RestEsqlGetAsyncResultAction; import org.elasticsearch.xpack.esql.action.RestEsqlQueryAction; +import org.elasticsearch.xpack.esql.core.expression.Attribute; import org.elasticsearch.xpack.esql.core.index.IndexResolver; import org.elasticsearch.xpack.esql.core.type.EsField; import org.elasticsearch.xpack.esql.enrich.EnrichLookupOperator; import org.elasticsearch.xpack.esql.execution.PlanExecutor; +import org.elasticsearch.xpack.esql.expression.function.UnsupportedAttribute; import org.elasticsearch.xpack.esql.querydsl.query.SingleValueQuery; import org.elasticsearch.xpack.esql.session.EsqlIndexResolver; import org.elasticsearch.xpack.esql.type.EsqlDataTypeRegistry; @@ -191,6 +193,8 @@ public List getNamedWriteables() { entries.add(EnrichLookupOperator.Status.ENTRY); entries.addAll(Block.getNamedWriteables()); entries.addAll(EsField.getNamedWriteables()); + entries.addAll(Attribute.getNamedWriteables()); + entries.add(UnsupportedAttribute.ENTRY); // TODO combine with above once these are in the same project return entries; } diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/SerializationTestUtils.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/SerializationTestUtils.java index 3828b1b290c8f..6ef33b7ae5eb8 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/SerializationTestUtils.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/SerializationTestUtils.java @@ -23,9 +23,11 @@ import org.elasticsearch.index.query.TermsQueryBuilder; import org.elasticsearch.index.query.WildcardQueryBuilder; import org.elasticsearch.test.EqualsHashCodeTestUtils; +import org.elasticsearch.xpack.esql.core.expression.Attribute; import org.elasticsearch.xpack.esql.core.expression.Expression; import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.core.type.EsField; +import org.elasticsearch.xpack.esql.expression.function.UnsupportedAttribute; import org.elasticsearch.xpack.esql.io.stream.PlanNameRegistry; import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; import org.elasticsearch.xpack.esql.io.stream.PlanStreamOutput; @@ -114,6 +116,8 @@ public static NamedWriteableRegistry writableRegistry() { entries.add(new NamedWriteableRegistry.Entry(QueryBuilder.class, ExistsQueryBuilder.NAME, ExistsQueryBuilder::new)); entries.add(SingleValueQuery.ENTRY); entries.addAll(EsField.getNamedWriteables()); + entries.addAll(Attribute.getNamedWriteables()); + entries.add(UnsupportedAttribute.ENTRY); return new NamedWriteableRegistry(entries); } } diff --git a/x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/expression/TyperResolutionTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/TyperResolutionTests.java similarity index 80% rename from x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/expression/TyperResolutionTests.java rename to x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/TyperResolutionTests.java index 213c29040a4b2..c2031cb0a8efa 100644 --- a/x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/expression/TyperResolutionTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/TyperResolutionTests.java @@ -5,12 +5,13 @@ * 2.0. */ -package org.elasticsearch.xpack.esql.core.expression; +package org.elasticsearch.xpack.esql.expression; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.esql.core.TestUtils; import org.elasticsearch.xpack.esql.core.expression.Expression.TypeResolution; -import org.elasticsearch.xpack.esql.core.expression.predicate.operator.arithmetic.Mul; +import org.elasticsearch.xpack.esql.core.expression.Literal; +import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Mul; import static org.elasticsearch.xpack.esql.core.tree.Source.EMPTY; diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractAttributeTestCase.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractAttributeTestCase.java new file mode 100644 index 0000000000000..17dcab2048eb1 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/AbstractAttributeTestCase.java @@ -0,0 +1,123 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function; + +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xpack.esql.core.expression.Attribute; +import org.elasticsearch.xpack.esql.core.expression.FieldAttribute; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.core.type.EsField; +import org.elasticsearch.xpack.esql.io.stream.PlanNameRegistry; +import org.elasticsearch.xpack.esql.io.stream.PlanStreamInput; +import org.elasticsearch.xpack.esql.session.EsqlConfigurationSerializationTests; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +import static org.hamcrest.Matchers.sameInstance; + +public abstract class AbstractAttributeTestCase extends AbstractWireSerializingTestCase< + AbstractAttributeTestCase.ExtraAttribute> { + protected abstract T create(); + + protected abstract T mutate(T instance); + + @Override + protected final ExtraAttribute createTestInstance() { + return new ExtraAttribute(create()); + } + + @Override + @SuppressWarnings("unchecked") + protected final ExtraAttribute mutateInstance(ExtraAttribute instance) { + return new ExtraAttribute(mutate((T) instance.a)); + } + + @Override + protected final NamedWriteableRegistry getNamedWriteableRegistry() { + List entries = new ArrayList<>(Attribute.getNamedWriteables()); + entries.add(UnsupportedAttribute.ENTRY); + entries.addAll(EsField.getNamedWriteables()); + return new NamedWriteableRegistry(entries); + } + + @Override + protected final Writeable.Reader instanceReader() { + return ExtraAttribute::new; + } + + /** + * Adds extra equality comparisons needed for testing round trips of {@link Attribute}. + */ + public static class ExtraAttribute implements Writeable { + private final Attribute a; + + ExtraAttribute(Attribute a) { + this.a = a; + assertThat(a.source(), sameInstance(Source.EMPTY)); + } + + ExtraAttribute(StreamInput in) throws IOException { + PlanStreamInput ps = new PlanStreamInput( + in, + PlanNameRegistry.INSTANCE, + in.namedWriteableRegistry(), + EsqlConfigurationSerializationTests.randomConfiguration("", Map.of()) + ); + ps.setTransportVersion(in.getTransportVersion()); + a = ps.readNamedWriteable(Attribute.class); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeNamedWriteable(a); + } + + @Override + public boolean equals(Object obj) { + if (obj == null) { + return a.equals(null); + } + if (obj.getClass() != getClass()) { + return a.equals(obj); + } + ExtraAttribute other = (ExtraAttribute) obj; + if (false == a.equals(other.a)) { + return false; + } + if (a instanceof FieldAttribute fa && false == fa.field().equals(((FieldAttribute) other.a).field())) { + return false; + } + return a.source() == Source.EMPTY; + } + + @Override + public int hashCode() { + if (a instanceof FieldAttribute fa) { + return Objects.hash(a, a.source(), fa.field()); + } + return Objects.hash(a, a.source()); + } + + @Override + public String toString() { + StringBuilder b = new StringBuilder(a.toString()); + if (a instanceof FieldAttribute fa) { + b.append(", field=").append(fa.field()); + } + return b.toString(); + } + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/FieldAttributeTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/FieldAttributeTests.java new file mode 100644 index 0000000000000..ee542232aa30b --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/FieldAttributeTests.java @@ -0,0 +1,58 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function; + +import org.elasticsearch.xpack.esql.core.expression.FieldAttribute; +import org.elasticsearch.xpack.esql.core.expression.NameId; +import org.elasticsearch.xpack.esql.core.expression.Nullability; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.core.type.AbstractEsFieldTypeTests; +import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.core.type.DataTypes; +import org.elasticsearch.xpack.esql.core.type.EsField; + +public class FieldAttributeTests extends AbstractAttributeTestCase { + static FieldAttribute createFieldAttribute(int maxDepth) { + Source source = Source.EMPTY; + FieldAttribute parent = maxDepth == 0 || randomBoolean() ? null : createFieldAttribute(maxDepth - 1); + String name = randomAlphaOfLength(5); + DataType type = randomFrom(DataTypes.types()); + EsField field = AbstractEsFieldTypeTests.randomAnyEsField(maxDepth); + String qualifier = randomBoolean() ? null : randomAlphaOfLength(3); + Nullability nullability = randomFrom(Nullability.values()); + boolean synthetic = randomBoolean(); + return new FieldAttribute(source, parent, name, type, field, qualifier, nullability, new NameId(), synthetic); + } + + @Override + protected FieldAttribute create() { + return createFieldAttribute(3); + } + + @Override + protected FieldAttribute mutate(FieldAttribute instance) { + Source source = instance.source(); + FieldAttribute parent = instance.parent(); + String name = instance.name(); + DataType type = instance.dataType(); + EsField field = instance.field(); + String qualifier = instance.qualifier(); + Nullability nullability = instance.nullable(); + boolean synthetic = instance.synthetic(); + switch (between(0, 6)) { + case 0 -> parent = randomValueOtherThan(parent, () -> randomBoolean() ? null : createFieldAttribute(2)); + case 1 -> name = randomAlphaOfLength(name.length() + 1); + case 2 -> type = randomValueOtherThan(type, () -> randomFrom(DataTypes.types())); + case 3 -> field = randomValueOtherThan(field, () -> AbstractEsFieldTypeTests.randomAnyEsField(3)); + case 4 -> qualifier = randomValueOtherThan(qualifier, () -> randomBoolean() ? null : randomAlphaOfLength(3)); + case 5 -> nullability = randomValueOtherThan(nullability, () -> randomFrom(Nullability.values())); + case 6 -> synthetic = false == synthetic; + } + return new FieldAttribute(source, parent, name, type, field, qualifier, nullability, new NameId(), synthetic); + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/MetadataAttributeTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/MetadataAttributeTests.java new file mode 100644 index 0000000000000..16a83b42d10ab --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/MetadataAttributeTests.java @@ -0,0 +1,49 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function; + +import org.elasticsearch.xpack.esql.core.expression.MetadataAttribute; +import org.elasticsearch.xpack.esql.core.expression.NameId; +import org.elasticsearch.xpack.esql.core.expression.Nullability; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.core.type.DataTypes; + +public class MetadataAttributeTests extends AbstractAttributeTestCase { + @Override + protected MetadataAttribute create() { + Source source = Source.EMPTY; + String name = randomAlphaOfLength(5); + DataType type = randomFrom(DataTypes.types()); + String qualifier = randomBoolean() ? null : randomAlphaOfLength(3); + Nullability nullability = randomFrom(Nullability.values()); + boolean synthetic = randomBoolean(); + boolean searchable = randomBoolean(); + return new MetadataAttribute(source, name, type, qualifier, nullability, new NameId(), synthetic, searchable); + } + + @Override + protected MetadataAttribute mutate(MetadataAttribute instance) { + Source source = instance.source(); + String name = instance.name(); + DataType type = instance.dataType(); + String qualifier = instance.qualifier(); + Nullability nullability = instance.nullable(); + boolean synthetic = instance.synthetic(); + boolean searchable = instance.searchable(); + switch (between(0, 5)) { + case 0 -> name = randomAlphaOfLength(name.length() + 1); + case 1 -> type = randomValueOtherThan(type, () -> randomFrom(DataTypes.types())); + case 2 -> qualifier = randomValueOtherThan(qualifier, () -> randomBoolean() ? null : randomAlphaOfLength(3)); + case 3 -> nullability = randomValueOtherThan(nullability, () -> randomFrom(Nullability.values())); + case 4 -> synthetic = false == synthetic; + case 5 -> searchable = false == searchable; + } + return new MetadataAttribute(source, name, type, qualifier, nullability, new NameId(), synthetic, searchable); + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/ReferenceAttributeTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/ReferenceAttributeTests.java new file mode 100644 index 0000000000000..e248b741ff48d --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/ReferenceAttributeTests.java @@ -0,0 +1,46 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function; + +import org.elasticsearch.xpack.esql.core.expression.NameId; +import org.elasticsearch.xpack.esql.core.expression.Nullability; +import org.elasticsearch.xpack.esql.core.expression.ReferenceAttribute; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.core.type.DataType; +import org.elasticsearch.xpack.esql.core.type.DataTypes; + +public class ReferenceAttributeTests extends AbstractAttributeTestCase { + @Override + protected ReferenceAttribute create() { + Source source = Source.EMPTY; + String name = randomAlphaOfLength(5); + DataType type = randomFrom(DataTypes.types()); + String qualifier = randomBoolean() ? null : randomAlphaOfLength(3); + Nullability nullability = randomFrom(Nullability.values()); + boolean synthetic = randomBoolean(); + return new ReferenceAttribute(source, name, type, qualifier, nullability, new NameId(), synthetic); + } + + @Override + protected ReferenceAttribute mutate(ReferenceAttribute instance) { + Source source = instance.source(); + String name = instance.name(); + DataType type = instance.dataType(); + String qualifier = instance.qualifier(); + Nullability nullability = instance.nullable(); + boolean synthetic = instance.synthetic(); + switch (between(0, 4)) { + case 0 -> name = randomAlphaOfLength(name.length() + 1); + case 1 -> type = randomValueOtherThan(type, () -> randomFrom(DataTypes.types())); + case 2 -> qualifier = randomValueOtherThan(qualifier, () -> randomBoolean() ? null : randomAlphaOfLength(3)); + case 3 -> nullability = randomValueOtherThan(nullability, () -> randomFrom(Nullability.values())); + case 4 -> synthetic = false == synthetic; + } + return new ReferenceAttribute(source, name, type, qualifier, nullability, new NameId(), synthetic); + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/UnsupportedAttributeTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/UnsupportedAttributeTests.java new file mode 100644 index 0000000000000..e195f31664774 --- /dev/null +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/expression/function/UnsupportedAttributeTests.java @@ -0,0 +1,39 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.esql.expression.function; + +import org.elasticsearch.xpack.esql.core.expression.NameId; +import org.elasticsearch.xpack.esql.core.tree.Source; +import org.elasticsearch.xpack.esql.core.type.UnsupportedEsField; +import org.elasticsearch.xpack.esql.core.type.UnsupportedEsFieldTests; + +public class UnsupportedAttributeTests extends AbstractAttributeTestCase { + @Override + protected UnsupportedAttribute create() { + String name = randomAlphaOfLength(5); + UnsupportedEsField field = UnsupportedEsFieldTests.randomUnsupportedEsField(4); + String customMessage = randomBoolean() ? null : randomAlphaOfLength(9); + NameId id = new NameId(); + return new UnsupportedAttribute(Source.EMPTY, name, field, customMessage, id); + } + + @Override + protected UnsupportedAttribute mutate(UnsupportedAttribute instance) { + Source source = instance.source(); + String name = instance.name(); + UnsupportedEsField field = instance.field(); + String customMessage = instance.hasCustomMessage() ? instance.unresolvedMessage() : null; + switch (between(0, 2)) { + case 0 -> name = randomAlphaOfLength(name.length() + 1); + case 1 -> field = randomValueOtherThan(field, () -> UnsupportedEsFieldTests.randomUnsupportedEsField(4)); + case 2 -> customMessage = randomValueOtherThan(customMessage, () -> randomBoolean() ? null : randomAlphaOfLength(9)); + default -> throw new IllegalArgumentException(); + } + return new UnsupportedAttribute(source, name, field, customMessage, new NameId()); + } +} diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/io/stream/PlanNamedTypesTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/io/stream/PlanNamedTypesTests.java index c1a7f1219c5f7..9cb4b6cff3fc0 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/io/stream/PlanNamedTypesTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/io/stream/PlanNamedTypesTests.java @@ -22,7 +22,6 @@ import org.elasticsearch.xpack.esql.core.expression.FieldAttribute; import org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.expression.NameId; -import org.elasticsearch.xpack.esql.core.expression.NamedExpression; import org.elasticsearch.xpack.esql.core.expression.Nullability; import org.elasticsearch.xpack.esql.core.expression.function.Function; import org.elasticsearch.xpack.esql.core.expression.predicate.operator.arithmetic.ArithmeticOperation; @@ -219,52 +218,6 @@ public void testWrappedStreamSimple() throws IOException { assertThat(in.readVInt(), equalTo(11_345)); } - public void testUnsupportedAttributeSimple() throws IOException { - var orig = new UnsupportedAttribute( - Source.EMPTY, - "foo", - new UnsupportedEsField("foo", "keyword"), - "field not supported", - new NameId() - ); - BytesStreamOutput bso = new BytesStreamOutput(); - PlanStreamOutput out = new PlanStreamOutput(bso, planNameRegistry, null); - PlanNamedTypes.writeUnsupportedAttr(out, orig); - var in = planStreamInput(bso); - var deser = PlanNamedTypes.readUnsupportedAttr(in); - EqualsHashCodeTestUtils.checkEqualsAndHashCode(orig, unused -> deser); - assertThat(deser.id(), equalTo(in.mapNameId(Long.parseLong(orig.id().toString())))); - } - - public void testUnsupportedAttribute() { - Stream.generate(PlanNamedTypesTests::randomUnsupportedAttribute).limit(100).forEach(PlanNamedTypesTests::assertNamedExpression); - } - - public void testFieldAttributeSimple() throws IOException { - var orig = new FieldAttribute( - Source.EMPTY, - null, // parent, can be null - "bar", // name - DataTypes.KEYWORD, - randomEsField(), - null, // qualifier, can be null - Nullability.TRUE, - new NameId(), - true // synthetic - ); - BytesStreamOutput bso = new BytesStreamOutput(); - PlanStreamOutput out = new PlanStreamOutput(bso, planNameRegistry, null); - PlanNamedTypes.writeFieldAttribute(out, orig); - var in = planStreamInput(bso); - var deser = PlanNamedTypes.readFieldAttribute(in); - EqualsHashCodeTestUtils.checkEqualsAndHashCode(orig, unused -> deser); - assertThat(deser.id(), equalTo(in.mapNameId(Long.parseLong(orig.id().toString())))); - } - - public void testFieldAttribute() { - Stream.generate(PlanNamedTypesTests::randomFieldAttribute).limit(100).forEach(PlanNamedTypesTests::assertNamedExpression); - } - public void testBinComparisonSimple() throws IOException { var orig = new Equals(Source.EMPTY, field("foo", DataTypes.DOUBLE), field("bar", DataTypes.DOUBLE)); BytesStreamOutput bso = new BytesStreamOutput(); @@ -445,11 +398,6 @@ public void testMvExpand() throws IOException { EqualsHashCodeTestUtils.checkEqualsAndHashCode(orig, unused -> deser); } - private static void assertNamedExpression(NamedExpression origObj) { - var deserObj = serializeDeserialize(origObj, PlanStreamOutput::writeExpression, PlanStreamInput::readNamedExpression); - EqualsHashCodeTestUtils.checkEqualsAndHashCode(origObj, unused -> deserObj); - } - private static void assertNamedType(Class type, T origObj) { var deserObj = serializeDeserialize(origObj, (o, v) -> o.writeNamed(type, origObj), i -> i.readNamed(type)); EqualsHashCodeTestUtils.checkEqualsAndHashCode(origObj, unused -> deserObj); diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java index cfe2535519557..4bb797faff04c 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java @@ -4554,7 +4554,6 @@ public void testSimplifyComparisonArithmeticCommutativeVsNonCommutativeOps() { } - @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/108388") public void testSimplifyComparisonArithmeticsWithFloatingPoints() { doTestSimplifyComparisonArithmetics("float / 2 > 4", "float", GT, 8d); } @@ -4578,17 +4577,14 @@ public void testSimplifyComparisonArithmeticWithMultipleOps() { doTestSimplifyComparisonArithmetics("((integer + 1) * 2 - 4) * 4 >= 16", "integer", GTE, 3); } - @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/108743") public void testSimplifyComparisonArithmeticWithFieldNegation() { doTestSimplifyComparisonArithmetics("12 * (-integer - 5) >= -120", "integer", LTE, 5); } - @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/108743") public void testSimplifyComparisonArithmeticWithFieldDoubleNegation() { doTestSimplifyComparisonArithmetics("12 * -(-integer - 5) <= 120", "integer", LTE, 5); } - @AwaitsFix(bugUrl = "https://github.com/elastic/elasticsearch/issues/108743") public void testSimplifyComparisonArithmeticWithConjunction() { doTestSimplifyComparisonArithmetics("12 * (-integer - 5) == -120 AND integer < 6 ", "integer", EQ, 5); } diff --git a/x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/plan/QueryPlanTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/QueryPlanTests.java similarity index 96% rename from x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/plan/QueryPlanTests.java rename to x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/QueryPlanTests.java index 747823795d408..11c2d9532ff16 100644 --- a/x-pack/plugin/esql-core/src/test/java/org/elasticsearch/xpack/esql/core/plan/QueryPlanTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plan/QueryPlanTests.java @@ -5,7 +5,7 @@ * 2.0. */ -package org.elasticsearch.xpack.esql.core.plan; +package org.elasticsearch.xpack.esql.plan; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.esql.core.expression.Alias; @@ -13,12 +13,12 @@ import org.elasticsearch.xpack.esql.core.expression.FieldAttribute; import org.elasticsearch.xpack.esql.core.expression.Literal; import org.elasticsearch.xpack.esql.core.expression.NamedExpression; -import org.elasticsearch.xpack.esql.core.expression.predicate.operator.arithmetic.Add; import org.elasticsearch.xpack.esql.core.plan.logical.Filter; import org.elasticsearch.xpack.esql.core.plan.logical.Limit; import org.elasticsearch.xpack.esql.core.plan.logical.LogicalPlan; import org.elasticsearch.xpack.esql.core.plan.logical.OrderBy; -import org.elasticsearch.xpack.esql.core.plan.logical.Project; +import org.elasticsearch.xpack.esql.expression.predicate.operator.arithmetic.Add; +import org.elasticsearch.xpack.esql.plan.logical.Project; import java.util.ArrayList; import java.util.List; diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plugin/DataNodeRequestTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plugin/DataNodeRequestTests.java index 2b2a4e76c5cc8..dde39b66664de 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plugin/DataNodeRequestTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/plugin/DataNodeRequestTests.java @@ -10,7 +10,6 @@ import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.io.stream.Writeable; import org.elasticsearch.common.settings.Settings; -import org.elasticsearch.compute.data.Block; import org.elasticsearch.index.Index; import org.elasticsearch.index.query.TermQueryBuilder; import org.elasticsearch.index.shard.ShardId; @@ -58,8 +57,7 @@ protected Writeable.Reader instanceReader() { protected NamedWriteableRegistry getNamedWriteableRegistry() { List writeables = new ArrayList<>(); writeables.addAll(new SearchModule(Settings.EMPTY, List.of()).getNamedWriteables()); - writeables.addAll(Block.getNamedWriteables()); - writeables.addAll(EsField.getNamedWriteables()); + writeables.addAll(new EsqlPlugin().getNamedWriteables()); return new NamedWriteableRegistry(writeables); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java index edea0104ded16..bff7ecdcc4a07 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java @@ -45,9 +45,10 @@ import org.elasticsearch.xpack.inference.services.elser.ElserInternalServiceSettings; import org.elasticsearch.xpack.inference.services.elser.ElserMlNodeTaskSettings; import org.elasticsearch.xpack.inference.services.googleaistudio.completion.GoogleAiStudioCompletionServiceSettings; +import org.elasticsearch.xpack.inference.services.googleaistudio.embeddings.GoogleAiStudioEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceServiceSettings; -import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserSecretSettings; import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserServiceSettings; +import org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingsServiceSettings; import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionServiceSettings; import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionTaskSettings; import org.elasticsearch.xpack.inference.services.openai.embeddings.OpenAiEmbeddingsServiceSettings; @@ -108,10 +109,23 @@ public static List getNamedWriteables() { addAzureOpenAiNamedWriteables(namedWriteables); addAzureAiStudioNamedWriteables(namedWriteables); addGoogleAiStudioNamedWritables(namedWriteables); + addMistralNamedWriteables(namedWriteables); return namedWriteables; } + private static void addMistralNamedWriteables(List namedWriteables) { + namedWriteables.add( + new NamedWriteableRegistry.Entry( + ServiceSettings.class, + MistralEmbeddingsServiceSettings.NAME, + MistralEmbeddingsServiceSettings::new + ) + ); + + // note - no task settings for Mistral embeddings... + } + private static void addAzureAiStudioNamedWriteables(List namedWriteables) { namedWriteables.add( new NamedWriteableRegistry.Entry( @@ -251,9 +265,6 @@ private static void addHuggingFaceNamedWriteables(List namedWriteables) { @@ -264,6 +275,13 @@ private static void addGoogleAiStudioNamedWritables(List namedWriteables) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index 9a9b11fd1400e..1e0f715e3f3e9 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -74,6 +74,7 @@ import org.elasticsearch.xpack.inference.services.googleaistudio.GoogleAiStudioService; import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceService; import org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserService; +import org.elasticsearch.xpack.inference.services.mistral.MistralService; import org.elasticsearch.xpack.inference.services.openai.OpenAiService; import java.util.ArrayList; @@ -198,6 +199,7 @@ public List getInferenceServiceFactories() { context -> new AzureOpenAiService(httpFactory.get(), serviceComponents.get()), context -> new AzureAiStudioService(httpFactory.get(), serviceComponents.get()), context -> new GoogleAiStudioService(httpFactory.get(), serviceComponents.get()), + context -> new MistralService(httpFactory.get(), serviceComponents.get()), ElasticsearchInternalService::new ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/googleaistudio/GoogleAiStudioActionCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/googleaistudio/GoogleAiStudioActionCreator.java index 51a8cc7a0bd56..86154faefabc5 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/googleaistudio/GoogleAiStudioActionCreator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/googleaistudio/GoogleAiStudioActionCreator.java @@ -11,6 +11,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.Sender; import org.elasticsearch.xpack.inference.services.ServiceComponents; import org.elasticsearch.xpack.inference.services.googleaistudio.completion.GoogleAiStudioCompletionModel; +import org.elasticsearch.xpack.inference.services.googleaistudio.embeddings.GoogleAiStudioEmbeddingsModel; import java.util.Map; import java.util.Objects; @@ -31,4 +32,9 @@ public ExecutableAction create(GoogleAiStudioCompletionModel model, Map taskSettings) { + return new GoogleAiStudioEmbeddingsAction(sender, model, serviceComponents); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/googleaistudio/GoogleAiStudioActionVisitor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/googleaistudio/GoogleAiStudioActionVisitor.java index 090d3f9a69710..2e89200cce53b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/googleaistudio/GoogleAiStudioActionVisitor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/googleaistudio/GoogleAiStudioActionVisitor.java @@ -9,6 +9,7 @@ import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import org.elasticsearch.xpack.inference.services.googleaistudio.completion.GoogleAiStudioCompletionModel; +import org.elasticsearch.xpack.inference.services.googleaistudio.embeddings.GoogleAiStudioEmbeddingsModel; import java.util.Map; @@ -16,4 +17,5 @@ public interface GoogleAiStudioActionVisitor { ExecutableAction create(GoogleAiStudioCompletionModel model, Map taskSettings); + ExecutableAction create(GoogleAiStudioEmbeddingsModel model, Map taskSettings); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/googleaistudio/GoogleAiStudioEmbeddingsAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/googleaistudio/GoogleAiStudioEmbeddingsAction.java new file mode 100644 index 0000000000000..5ce780193c789 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/googleaistudio/GoogleAiStudioEmbeddingsAction.java @@ -0,0 +1,62 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.action.googleaistudio; + +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.external.http.sender.GoogleAiStudioEmbeddingsRequestManager; +import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; +import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.services.ServiceComponents; +import org.elasticsearch.xpack.inference.services.googleaistudio.embeddings.GoogleAiStudioEmbeddingsModel; + +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage; +import static org.elasticsearch.xpack.inference.external.action.ActionUtils.createInternalServerError; +import static org.elasticsearch.xpack.inference.external.action.ActionUtils.wrapFailuresInElasticsearchException; + +public class GoogleAiStudioEmbeddingsAction implements ExecutableAction { + + private final String failedToSendRequestErrorMessage; + + private final GoogleAiStudioEmbeddingsRequestManager requestManager; + + private final Sender sender; + + public GoogleAiStudioEmbeddingsAction(Sender sender, GoogleAiStudioEmbeddingsModel model, ServiceComponents serviceComponents) { + Objects.requireNonNull(serviceComponents); + Objects.requireNonNull(model); + this.sender = Objects.requireNonNull(sender); + this.requestManager = new GoogleAiStudioEmbeddingsRequestManager( + model, + serviceComponents.truncator(), + serviceComponents.threadPool() + ); + this.failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(model.uri(), "Google AI Studio embeddings"); + } + + @Override + public void execute(InferenceInputs inferenceInputs, TimeValue timeout, ActionListener listener) { + try { + ActionListener wrappedListener = wrapFailuresInElasticsearchException( + failedToSendRequestErrorMessage, + listener + ); + + sender.send(requestManager, inferenceInputs, timeout, wrappedListener); + } catch (ElasticsearchException e) { + listener.onFailure(e); + } catch (Exception e) { + listener.onFailure(createInternalServerError(e, failedToSendRequestErrorMessage)); + } + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/mistral/MistralAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/mistral/MistralAction.java new file mode 100644 index 0000000000000..f7b51e80a04b3 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/mistral/MistralAction.java @@ -0,0 +1,45 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.action.mistral; + +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; +import org.elasticsearch.xpack.inference.external.http.sender.MistralEmbeddingsRequestManager; +import org.elasticsearch.xpack.inference.external.http.sender.Sender; + +import static org.elasticsearch.xpack.inference.external.action.ActionUtils.createInternalServerError; +import static org.elasticsearch.xpack.inference.external.action.ActionUtils.wrapFailuresInElasticsearchException; + +public class MistralAction implements ExecutableAction { + protected final Sender sender; + protected final MistralEmbeddingsRequestManager requestCreator; + protected final String errorMessage; + + protected MistralAction(Sender sender, MistralEmbeddingsRequestManager requestCreator, String errorMessage) { + this.sender = sender; + this.requestCreator = requestCreator; + this.errorMessage = errorMessage; + } + + @Override + public void execute(InferenceInputs inferenceInputs, TimeValue timeout, ActionListener listener) { + try { + ActionListener wrappedListener = wrapFailuresInElasticsearchException(errorMessage, listener); + + sender.send(requestCreator, inferenceInputs, timeout, wrappedListener); + } catch (ElasticsearchException e) { + listener.onFailure(e); + } catch (Exception e) { + listener.onFailure(createInternalServerError(e, errorMessage)); + } + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/mistral/MistralActionCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/mistral/MistralActionCreator.java new file mode 100644 index 0000000000000..a023973ea6aa5 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/mistral/MistralActionCreator.java @@ -0,0 +1,40 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.action.mistral; + +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.external.http.sender.MistralEmbeddingsRequestManager; +import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.services.ServiceComponents; +import org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingsModel; + +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage; + +public class MistralActionCreator implements MistralActionVisitor { + private final Sender sender; + private final ServiceComponents serviceComponents; + + public MistralActionCreator(Sender sender, ServiceComponents serviceComponents) { + this.sender = Objects.requireNonNull(sender); + this.serviceComponents = Objects.requireNonNull(serviceComponents); + } + + @Override + public ExecutableAction create(MistralEmbeddingsModel embeddingsModel, Map taskSettings) { + var requestManager = new MistralEmbeddingsRequestManager( + embeddingsModel, + serviceComponents.truncator(), + serviceComponents.threadPool() + ); + var errorMessage = constructFailedToSendRequestMessage(embeddingsModel.uri(), "Mistral embeddings"); + return new MistralAction(sender, requestManager, errorMessage); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/mistral/MistralActionVisitor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/mistral/MistralActionVisitor.java new file mode 100644 index 0000000000000..3764efeb0f6c8 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/mistral/MistralActionVisitor.java @@ -0,0 +1,17 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.action.mistral; + +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingsModel; + +import java.util.Map; + +public interface MistralActionVisitor { + ExecutableAction create(MistralEmbeddingsModel embeddingsModel, Map taskSettings); +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureAiStudioChatCompletionRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureAiStudioChatCompletionRequestManager.java index 76ef37592d88e..deff410aebaa8 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureAiStudioChatCompletionRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureAiStudioChatCompletionRequestManager.java @@ -16,8 +16,8 @@ import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; import org.elasticsearch.xpack.inference.external.request.azureaistudio.AzureAiStudioChatCompletionRequest; -import org.elasticsearch.xpack.inference.external.response.AzureAndOpenAiErrorResponseEntity; -import org.elasticsearch.xpack.inference.external.response.AzureAndOpenAiExternalResponseHandler; +import org.elasticsearch.xpack.inference.external.response.AzureMistralOpenAiErrorResponseEntity; +import org.elasticsearch.xpack.inference.external.response.AzureMistralOpenAiExternalResponseHandler; import org.elasticsearch.xpack.inference.external.response.azureaistudio.AzureAiStudioChatCompletionResponseEntity; import org.elasticsearch.xpack.inference.services.azureaistudio.completion.AzureAiStudioChatCompletionModel; @@ -51,10 +51,10 @@ public Runnable create( } private static ResponseHandler createCompletionHandler() { - return new AzureAndOpenAiExternalResponseHandler( + return new AzureMistralOpenAiExternalResponseHandler( "azure ai studio completion", new AzureAiStudioChatCompletionResponseEntity(), - AzureAndOpenAiErrorResponseEntity::fromResponse + AzureMistralOpenAiErrorResponseEntity::fromResponse ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureAiStudioEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureAiStudioEmbeddingsRequestManager.java index c2edc79dfe937..a2b363151a417 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureAiStudioEmbeddingsRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureAiStudioEmbeddingsRequestManager.java @@ -17,8 +17,8 @@ import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; import org.elasticsearch.xpack.inference.external.request.azureaistudio.AzureAiStudioEmbeddingsRequest; -import org.elasticsearch.xpack.inference.external.response.AzureAndOpenAiErrorResponseEntity; -import org.elasticsearch.xpack.inference.external.response.AzureAndOpenAiExternalResponseHandler; +import org.elasticsearch.xpack.inference.external.response.AzureMistralOpenAiErrorResponseEntity; +import org.elasticsearch.xpack.inference.external.response.AzureMistralOpenAiExternalResponseHandler; import org.elasticsearch.xpack.inference.external.response.azureaistudio.AzureAiStudioEmbeddingsResponseEntity; import org.elasticsearch.xpack.inference.services.azureaistudio.embeddings.AzureAiStudioEmbeddingsModel; @@ -55,10 +55,10 @@ public Runnable create( } private static ResponseHandler createEmbeddingsHandler() { - return new AzureAndOpenAiExternalResponseHandler( + return new AzureMistralOpenAiExternalResponseHandler( "azure ai studio text embedding", new AzureAiStudioEmbeddingsResponseEntity(), - AzureAndOpenAiErrorResponseEntity::fromResponse + AzureMistralOpenAiErrorResponseEntity::fromResponse ); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleAiStudioEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleAiStudioEmbeddingsRequestManager.java new file mode 100644 index 0000000000000..15c2825e7d043 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleAiStudioEmbeddingsRequestManager.java @@ -0,0 +1,64 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.http.sender; + +import org.apache.http.client.protocol.HttpClientContext; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.inference.common.Truncator; +import org.elasticsearch.xpack.inference.external.googleaistudio.GoogleAiStudioResponseHandler; +import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; +import org.elasticsearch.xpack.inference.external.request.googleaistudio.GoogleAiStudioEmbeddingsRequest; +import org.elasticsearch.xpack.inference.external.response.googleaistudio.GoogleAiStudioEmbeddingsResponseEntity; +import org.elasticsearch.xpack.inference.services.googleaistudio.embeddings.GoogleAiStudioEmbeddingsModel; + +import java.util.List; +import java.util.Objects; +import java.util.function.Supplier; + +import static org.elasticsearch.xpack.inference.common.Truncator.truncate; + +public class GoogleAiStudioEmbeddingsRequestManager extends GoogleAiStudioRequestManager { + + private static final Logger logger = LogManager.getLogger(GoogleAiStudioEmbeddingsRequestManager.class); + + private static final ResponseHandler HANDLER = createEmbeddingsHandler(); + + private static ResponseHandler createEmbeddingsHandler() { + return new GoogleAiStudioResponseHandler("google ai studio embeddings", GoogleAiStudioEmbeddingsResponseEntity::fromResponse); + } + + private final GoogleAiStudioEmbeddingsModel model; + + private final Truncator truncator; + + public GoogleAiStudioEmbeddingsRequestManager(GoogleAiStudioEmbeddingsModel model, Truncator truncator, ThreadPool threadPool) { + super(threadPool, model); + this.model = Objects.requireNonNull(model); + this.truncator = Objects.requireNonNull(truncator); + } + + @Override + public Runnable create( + String query, + List input, + RequestSender requestSender, + Supplier hasRequestCompletedFunction, + HttpClientContext context, + ActionListener listener + ) { + var truncatedInput = truncate(input, model.getServiceSettings().maxInputTokens()); + GoogleAiStudioEmbeddingsRequest request = new GoogleAiStudioEmbeddingsRequest(truncator, truncatedInput, model); + + return new ExecutableInferenceRequest(requestSender, logger, request, context, HANDLER, hasRequestCompletedFunction, listener); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/MistralEmbeddingsRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/MistralEmbeddingsRequestManager.java new file mode 100644 index 0000000000000..f31a633581705 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/MistralEmbeddingsRequestManager.java @@ -0,0 +1,75 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.http.sender; + +import org.apache.http.client.protocol.HttpClientContext; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.inference.common.Truncator; +import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; +import org.elasticsearch.xpack.inference.external.request.mistral.MistralEmbeddingsRequest; +import org.elasticsearch.xpack.inference.external.response.AzureMistralOpenAiErrorResponseEntity; +import org.elasticsearch.xpack.inference.external.response.AzureMistralOpenAiExternalResponseHandler; +import org.elasticsearch.xpack.inference.external.response.mistral.MistralEmbeddingsResponseEntity; +import org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingsModel; + +import java.util.List; +import java.util.Objects; +import java.util.function.Supplier; + +import static org.elasticsearch.xpack.inference.common.Truncator.truncate; + +public class MistralEmbeddingsRequestManager extends BaseRequestManager { + private static final Logger logger = LogManager.getLogger(AzureOpenAiEmbeddingsRequestManager.class); + private static final ResponseHandler HANDLER = createEmbeddingsHandler(); + + private final Truncator truncator; + private final MistralEmbeddingsModel model; + + private static ResponseHandler createEmbeddingsHandler() { + return new AzureMistralOpenAiExternalResponseHandler( + "mistral text embedding", + new MistralEmbeddingsResponseEntity(), + AzureMistralOpenAiErrorResponseEntity::fromResponse + ); + } + + public MistralEmbeddingsRequestManager(MistralEmbeddingsModel model, Truncator truncator, ThreadPool threadPool) { + super(threadPool, model.getInferenceEntityId(), RateLimitGrouping.of(model), model.rateLimitSettings()); + this.model = Objects.requireNonNull(model); + this.truncator = Objects.requireNonNull(truncator); + + } + + @Override + public Runnable create( + String query, + List input, + RequestSender requestSender, + Supplier hasRequestCompletedFunction, + HttpClientContext context, + ActionListener listener + ) { + var truncatedInput = truncate(input, model.getServiceSettings().maxInputTokens()); + MistralEmbeddingsRequest request = new MistralEmbeddingsRequest(truncator, truncatedInput, model); + + return new ExecutableInferenceRequest(requestSender, logger, request, context, HANDLER, hasRequestCompletedFunction, listener); + } + + record RateLimitGrouping(int keyHashCode) { + public static RateLimitGrouping of(MistralEmbeddingsModel model) { + Objects.requireNonNull(model); + + return new RateLimitGrouping(model.getSecretSettings().apiKey().hashCode()); + } + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/GoogleAiStudioEmbeddingsRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/GoogleAiStudioEmbeddingsRequest.java new file mode 100644 index 0000000000000..a96cbf2afb27a --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/GoogleAiStudioEmbeddingsRequest.java @@ -0,0 +1,81 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.googleaistudio; + +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpPost; +import org.apache.http.entity.ByteArrayEntity; +import org.elasticsearch.common.Strings; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.common.Truncator; +import org.elasticsearch.xpack.inference.external.request.HttpRequest; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.services.googleaistudio.embeddings.GoogleAiStudioEmbeddingsModel; + +import java.net.URI; +import java.nio.charset.StandardCharsets; +import java.util.Objects; + +public class GoogleAiStudioEmbeddingsRequest implements GoogleAiStudioRequest { + + private final Truncator truncator; + + private final Truncator.TruncationResult truncationResult; + + private final GoogleAiStudioEmbeddingsModel model; + + public GoogleAiStudioEmbeddingsRequest(Truncator truncator, Truncator.TruncationResult input, GoogleAiStudioEmbeddingsModel model) { + this.truncator = Objects.requireNonNull(truncator); + this.truncationResult = Objects.requireNonNull(input); + this.model = Objects.requireNonNull(model); + } + + @Override + public HttpRequest createHttpRequest() { + HttpPost httpPost = new HttpPost(model.uri()); + + ByteArrayEntity byteEntity = new ByteArrayEntity( + Strings.toString( + new GoogleAiStudioEmbeddingsRequestEntity( + truncationResult.input(), + model.getServiceSettings().modelId(), + model.getServiceSettings().dimensions() + ) + ).getBytes(StandardCharsets.UTF_8) + ); + + httpPost.setEntity(byteEntity); + httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType()); + + GoogleAiStudioRequest.decorateWithApiKeyParameter(httpPost, model.getSecretSettings()); + + return new HttpRequest(httpPost, getInferenceEntityId()); + } + + @Override + public String getInferenceEntityId() { + return model.getInferenceEntityId(); + } + + @Override + public URI getURI() { + return model.uri(); + } + + @Override + public Request truncate() { + var truncatedInput = truncator.truncate(truncationResult.input()); + + return new GoogleAiStudioEmbeddingsRequest(truncator, truncatedInput, model); + } + + @Override + public boolean[] getTruncationInfo() { + return truncationResult.truncated().clone(); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/GoogleAiStudioEmbeddingsRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/GoogleAiStudioEmbeddingsRequestEntity.java new file mode 100644 index 0000000000000..9d40f1cf097ec --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/GoogleAiStudioEmbeddingsRequestEntity.java @@ -0,0 +1,78 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.googleaistudio; + +import org.elasticsearch.core.Nullable; +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; + +import static org.elasticsearch.core.Strings.format; + +public record GoogleAiStudioEmbeddingsRequestEntity(List inputs, String model, @Nullable Integer dimensions) + implements + ToXContentObject { + + private static final String REQUESTS_FIELD = "requests"; + private static final String MODEL_FIELD = "model"; + + private static final String MODELS_PREFIX = "models"; + private static final String CONTENT_FIELD = "content"; + private static final String PARTS_FIELD = "parts"; + private static final String TEXT_FIELD = "text"; + + private static final String OUTPUT_DIMENSIONALITY_FIELD = "outputDimensionality"; + + public GoogleAiStudioEmbeddingsRequestEntity { + Objects.requireNonNull(inputs); + Objects.requireNonNull(model); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.startArray(REQUESTS_FIELD); + + for (String input : inputs) { + builder.startObject(); + builder.field(MODEL_FIELD, format("%s/%s", MODELS_PREFIX, model)); + + { + builder.startObject(CONTENT_FIELD); + + { + builder.startArray(PARTS_FIELD); + + { + builder.startObject(); + builder.field(TEXT_FIELD, input); + builder.endObject(); + } + + builder.endArray(); + } + + builder.endObject(); + } + + if (dimensions != null) { + builder.field(OUTPUT_DIMENSIONALITY_FIELD, dimensions); + } + + builder.endObject(); + } + + builder.endArray(); + builder.endObject(); + + return builder; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/GoogleAiStudioRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/GoogleAiStudioRequest.java index ede9c6193aa21..fb99deabc9c5e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/GoogleAiStudioRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/GoogleAiStudioRequest.java @@ -11,13 +11,13 @@ import org.apache.http.client.utils.URIBuilder; import org.elasticsearch.common.ValidationException; import org.elasticsearch.xpack.inference.external.request.Request; -import org.elasticsearch.xpack.inference.services.googleaistudio.GoogleAiStudioSecretSettings; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; public interface GoogleAiStudioRequest extends Request { String API_KEY_PARAMETER = "key"; - static void decorateWithApiKeyParameter(HttpPost httpPost, GoogleAiStudioSecretSettings secretSettings) { + static void decorateWithApiKeyParameter(HttpPost httpPost, DefaultSecretSettings secretSettings) { try { var uri = httpPost.getURI(); var uriWithApiKey = new URIBuilder().setScheme(uri.getScheme()) diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/GoogleAiStudioUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/GoogleAiStudioUtils.java index d63a0bbe2af91..81ad5b6203682 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/GoogleAiStudioUtils.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/GoogleAiStudioUtils.java @@ -17,6 +17,8 @@ public class GoogleAiStudioUtils { public static final String GENERATE_CONTENT_ACTION = "generateContent"; + public static final String BATCH_EMBED_CONTENTS_ACTION = "batchEmbedContents"; + private GoogleAiStudioUtils() {} } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/mistral/MistralEmbeddingsRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/mistral/MistralEmbeddingsRequest.java new file mode 100644 index 0000000000000..e1c90c08f643f --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/mistral/MistralEmbeddingsRequest.java @@ -0,0 +1,76 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.mistral; + +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpPost; +import org.apache.http.entity.ByteArrayEntity; +import org.elasticsearch.common.Strings; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.common.Truncator; +import org.elasticsearch.xpack.inference.external.request.HttpRequest; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingsModel; + +import java.net.URI; +import java.nio.charset.StandardCharsets; + +import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader; + +public class MistralEmbeddingsRequest implements Request { + private final URI uri; + private final MistralEmbeddingsModel embeddingsModel; + private final String inferenceEntityId; + private final Truncator.TruncationResult truncationResult; + private final Truncator truncator; + + public MistralEmbeddingsRequest(Truncator truncator, Truncator.TruncationResult input, MistralEmbeddingsModel model) { + this.uri = model.uri(); + this.embeddingsModel = model; + this.inferenceEntityId = model.getInferenceEntityId(); + this.truncator = truncator; + this.truncationResult = input; + } + + @Override + public HttpRequest createHttpRequest() { + HttpPost httpPost = new HttpPost(this.uri); + + ByteArrayEntity byteEntity = new ByteArrayEntity( + Strings.toString(new MistralEmbeddingsRequestEntity(embeddingsModel.model(), truncationResult.input())) + .getBytes(StandardCharsets.UTF_8) + ); + httpPost.setEntity(byteEntity); + + httpPost.setHeader(HttpHeaders.CONTENT_TYPE, XContentType.JSON.mediaType()); + httpPost.setHeader(createAuthBearerHeader(embeddingsModel.getSecretSettings().apiKey())); + + return new HttpRequest(httpPost, getInferenceEntityId()); + } + + @Override + public URI getURI() { + return uri; + } + + @Override + public Request truncate() { + var truncatedInput = truncator.truncate(truncationResult.input()); + return new MistralEmbeddingsRequest(truncator, truncatedInput, embeddingsModel); + } + + @Override + public boolean[] getTruncationInfo() { + return truncationResult.truncated().clone(); + } + + @Override + public String getInferenceEntityId() { + return inferenceEntityId; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/mistral/MistralEmbeddingsRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/mistral/MistralEmbeddingsRequestEntity.java new file mode 100644 index 0000000000000..d852e9ee34046 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/mistral/MistralEmbeddingsRequestEntity.java @@ -0,0 +1,39 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.mistral; + +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.List; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.mistral.MistralConstants.ENCODING_FORMAT_FIELD; +import static org.elasticsearch.xpack.inference.services.mistral.MistralConstants.INPUT_FIELD; +import static org.elasticsearch.xpack.inference.services.mistral.MistralConstants.MODEL_FIELD; + +public record MistralEmbeddingsRequestEntity(String model, List input) implements ToXContentObject { + public MistralEmbeddingsRequestEntity { + Objects.requireNonNull(model); + Objects.requireNonNull(input); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + + builder.field(MODEL_FIELD, model); + builder.field(INPUT_FIELD, input); + builder.field(ENCODING_FORMAT_FIELD, "float"); + + builder.endObject(); + + return builder; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/AzureAndOpenAiErrorResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/AzureMistralOpenAiErrorResponseEntity.java similarity index 90% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/AzureAndOpenAiErrorResponseEntity.java rename to x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/AzureMistralOpenAiErrorResponseEntity.java index 4ac77d6df3c33..83ea7801dfd58 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/AzureAndOpenAiErrorResponseEntity.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/AzureMistralOpenAiErrorResponseEntity.java @@ -31,10 +31,10 @@ * This currently covers error handling for Azure AI Studio, however this pattern * can be used to simplify and refactor handling for Azure OpenAI and OpenAI responses. */ -public class AzureAndOpenAiErrorResponseEntity implements ErrorMessage { +public class AzureMistralOpenAiErrorResponseEntity implements ErrorMessage { protected String errorMessage; - public AzureAndOpenAiErrorResponseEntity(String errorMessage) { + public AzureMistralOpenAiErrorResponseEntity(String errorMessage) { this.errorMessage = errorMessage; } @@ -62,7 +62,7 @@ public static ErrorMessage fromResponse(HttpResult response) { if (error != null) { var message = (String) error.get("message"); if (message != null) { - return new AzureAndOpenAiErrorResponseEntity(message); + return new AzureMistralOpenAiErrorResponseEntity(message); } } } catch (Exception e) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/AzureAndOpenAiExternalResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/AzureMistralOpenAiExternalResponseHandler.java similarity index 96% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/AzureAndOpenAiExternalResponseHandler.java rename to x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/AzureMistralOpenAiExternalResponseHandler.java index 5f803ad6fe74e..dfdb6712d5e45 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/AzureAndOpenAiExternalResponseHandler.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/AzureMistralOpenAiExternalResponseHandler.java @@ -29,7 +29,7 @@ * This currently covers response handling for Azure AI Studio, however this pattern * can be used to simplify and refactor handling for Azure OpenAI and OpenAI responses. */ -public class AzureAndOpenAiExternalResponseHandler extends BaseResponseHandler { +public class AzureMistralOpenAiExternalResponseHandler extends BaseResponseHandler { // The maximum number of requests that are permitted before exhausting the rate limit. static final String REQUESTS_LIMIT = "x-ratelimit-limit-requests"; @@ -43,7 +43,7 @@ public class AzureAndOpenAiExternalResponseHandler extends BaseResponseHandler { static final String CONTENT_TOO_LARGE_MESSAGE = "Please reduce your prompt; or completion length."; static final String SERVER_BUSY_ERROR = "Received a server busy error status code"; - public AzureAndOpenAiExternalResponseHandler( + public AzureMistralOpenAiExternalResponseHandler( String requestType, ResponseParser parseFunction, Function errorParseFunction @@ -116,7 +116,7 @@ public static boolean isContentTooLarge(HttpResult result) { } if (statusCode == 400) { - var errorEntity = AzureAndOpenAiErrorResponseEntity.fromResponse(result); + var errorEntity = AzureMistralOpenAiErrorResponseEntity.fromResponse(result); return errorEntity != null && errorEntity.getErrorMessage().contains(CONTENT_TOO_LARGE_MESSAGE); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/googleaistudio/GoogleAiStudioEmbeddingsResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/googleaistudio/GoogleAiStudioEmbeddingsResponseEntity.java new file mode 100644 index 0000000000000..204738f2a2552 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/googleaistudio/GoogleAiStudioEmbeddingsResponseEntity.java @@ -0,0 +1,110 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.response.googleaistudio; + +import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; +import org.elasticsearch.common.xcontent.XContentParserUtils; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.request.Request; + +import java.io.IOException; +import java.util.List; + +import static org.elasticsearch.xpack.inference.external.response.XContentUtils.consumeUntilObjectEnd; +import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken; +import static org.elasticsearch.xpack.inference.external.response.XContentUtils.positionParserAtTokenAfterField; + +public class GoogleAiStudioEmbeddingsResponseEntity { + + private static final String FAILED_TO_FIND_FIELD_TEMPLATE = + "Failed to find required field [%s] in Google AI Studio embeddings response"; + + /** + * Parses the Google AI Studio batch embeddings response (will be used for single and batch embeddings). + * For a request like: + * + *
+     *     
+     *         {
+     *             "inputs": ["Embed this", "Embed this, too"]
+     *         }
+     *     
+     * 
+ * + * The response would look like: + * + *
+     *     
+     *  {
+     *     "embeddings": [
+     *         {
+     *             "values": [
+     *                 -0.00606332,
+     *                 0.058092743,
+     *                 -0.06390548
+     *             ]
+     *         },
+     *         {
+     *             "values": [
+     *               -0.00606332,
+     *               -0.06390548,
+     *                0.058092743
+     *             ]
+     *         }
+     *     ]
+     *  }
+     *
+     *     
+     * 
+ */ + + public static TextEmbeddingResults fromResponse(Request request, HttpResult response) throws IOException { + var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE); + + try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, response.body())) { + moveToFirstToken(jsonParser); + + XContentParser.Token token = jsonParser.currentToken(); + XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, token, jsonParser); + + positionParserAtTokenAfterField(jsonParser, "embeddings", FAILED_TO_FIND_FIELD_TEMPLATE); + + List embeddingList = XContentParserUtils.parseList( + jsonParser, + GoogleAiStudioEmbeddingsResponseEntity::parseEmbeddingObject + ); + + return new TextEmbeddingResults(embeddingList); + } + } + + private static TextEmbeddingResults.Embedding parseEmbeddingObject(XContentParser parser) throws IOException { + XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser); + + positionParserAtTokenAfterField(parser, "values", FAILED_TO_FIND_FIELD_TEMPLATE); + + List embeddingValuesList = XContentParserUtils.parseList(parser, GoogleAiStudioEmbeddingsResponseEntity::parseEmbeddingList); + // parse and discard the rest of the object + consumeUntilObjectEnd(parser); + + return TextEmbeddingResults.Embedding.of(embeddingValuesList); + } + + private static float parseEmbeddingList(XContentParser parser) throws IOException { + XContentParser.Token token = parser.currentToken(); + XContentParserUtils.ensureExpectedToken(XContentParser.Token.VALUE_NUMBER, token, parser); + return parser.floatValue(); + } + + private GoogleAiStudioEmbeddingsResponseEntity() {} +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/mistral/MistralEmbeddingsResponseEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/mistral/MistralEmbeddingsResponseEntity.java new file mode 100644 index 0000000000000..01de92c20a3f1 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/response/mistral/MistralEmbeddingsResponseEntity.java @@ -0,0 +1,24 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.response.mistral; + +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.external.response.BaseResponseEntity; +import org.elasticsearch.xpack.inference.external.response.openai.OpenAiEmbeddingsResponseEntity; + +import java.io.IOException; + +public class MistralEmbeddingsResponseEntity extends BaseResponseEntity { + @Override + protected InferenceServiceResults fromResponse(Request request, HttpResult response) throws IOException { + // expected response type is the same as the Open AI Embeddings + return OpenAiEmbeddingsResponseEntity.fromResponse(request, response); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioModel.java index 4ddffd0bae615..d817a3bbb73ef 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioModel.java @@ -11,6 +11,7 @@ import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.ServiceSettings; import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import org.elasticsearch.xpack.inference.external.action.googleaistudio.GoogleAiStudioActionVisitor; @@ -31,6 +32,12 @@ public GoogleAiStudioModel( this.rateLimitServiceSettings = Objects.requireNonNull(rateLimitServiceSettings); } + public GoogleAiStudioModel(GoogleAiStudioModel model, ServiceSettings serviceSettings) { + super(model, serviceSettings); + + rateLimitServiceSettings = model.rateLimitServiceSettings(); + } + public abstract ExecutableAction accept(GoogleAiStudioActionVisitor creator, Map taskSettings, InputType inputType); public GoogleAiStudioRateLimitServiceSettings rateLimitServiceSettings() { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioSecretSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioSecretSettings.java deleted file mode 100644 index bf702d010e2a8..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioSecretSettings.java +++ /dev/null @@ -1,106 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.services.googleaistudio; - -import org.elasticsearch.TransportVersion; -import org.elasticsearch.TransportVersions; -import org.elasticsearch.common.ValidationException; -import org.elasticsearch.common.io.stream.StreamInput; -import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.common.settings.SecureString; -import org.elasticsearch.core.Nullable; -import org.elasticsearch.inference.ModelSecrets; -import org.elasticsearch.inference.SecretSettings; -import org.elasticsearch.xcontent.XContentBuilder; - -import java.io.IOException; -import java.util.Map; -import java.util.Objects; - -import static org.elasticsearch.core.Strings.format; -import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalSecureString; - -public class GoogleAiStudioSecretSettings implements SecretSettings { - - public static final String NAME = "google_ai_studio_secret_settings"; - public static final String API_KEY = "api_key"; - - private final SecureString apiKey; - - public static GoogleAiStudioSecretSettings fromMap(@Nullable Map map) { - if (map == null) { - return null; - } - - ValidationException validationException = new ValidationException(); - SecureString secureApiKey = extractOptionalSecureString(map, API_KEY, ModelSecrets.SECRET_SETTINGS, validationException); - - if (secureApiKey == null) { - validationException.addValidationError(format("[secret_settings] must have [%s] set", API_KEY)); - } - - if (validationException.validationErrors().isEmpty() == false) { - throw validationException; - } - - return new GoogleAiStudioSecretSettings(secureApiKey); - } - - public GoogleAiStudioSecretSettings(SecureString apiKey) { - Objects.requireNonNull(apiKey); - this.apiKey = apiKey; - } - - public GoogleAiStudioSecretSettings(StreamInput in) throws IOException { - this(in.readOptionalSecureString()); - } - - public SecureString apiKey() { - return apiKey; - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - - if (apiKey != null) { - builder.field(API_KEY, apiKey.toString()); - } - - builder.endObject(); - return builder; - } - - @Override - public String getWriteableName() { - return NAME; - } - - @Override - public TransportVersion getMinimalSupportedVersion() { - return TransportVersions.ML_INFERENCE_GOOGLE_AI_STUDIO_COMPLETION_ADDED; - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeOptionalSecureString(apiKey); - } - - @Override - public boolean equals(Object object) { - if (this == object) return true; - if (object == null || getClass() != object.getClass()) return false; - GoogleAiStudioSecretSettings that = (GoogleAiStudioSecretSettings) object; - return Objects.equals(apiKey, that.apiKey); - } - - @Override - public int hashCode() { - return Objects.hash(apiKey); - } -} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java index 5d2654c072a0c..f8720448b0f4f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java @@ -20,15 +20,20 @@ import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xpack.inference.common.EmbeddingRequestChunker; import org.elasticsearch.xpack.inference.external.action.googleaistudio.GoogleAiStudioActionCreator; import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.SenderService; import org.elasticsearch.xpack.inference.services.ServiceComponents; +import org.elasticsearch.xpack.inference.services.ServiceUtils; import org.elasticsearch.xpack.inference.services.googleaistudio.completion.GoogleAiStudioCompletionModel; +import org.elasticsearch.xpack.inference.services.googleaistudio.embeddings.GoogleAiStudioEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.googleaistudio.embeddings.GoogleAiStudioEmbeddingsServiceSettings; import java.util.List; import java.util.Map; @@ -39,6 +44,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; +import static org.elasticsearch.xpack.inference.services.googleaistudio.GoogleAiStudioServiceFields.EMBEDDING_MAX_BATCH_SIZE; public class GoogleAiStudioService extends SenderService { @@ -104,6 +110,14 @@ private static GoogleAiStudioModel createModel( taskSettings, secretSettings ); + case TEXT_EMBEDDING -> new GoogleAiStudioEmbeddingsModel( + inferenceEntityId, + taskType, + NAME, + serviceSettings, + taskSettings, + secretSettings + ); default -> throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST); }; } @@ -168,6 +182,34 @@ public TransportVersion getMinimalSupportedVersion() { return TransportVersions.ML_INFERENCE_GOOGLE_AI_STUDIO_COMPLETION_ADDED; } + @Override + public void checkModelConfig(Model model, ActionListener listener) { + if (model instanceof GoogleAiStudioEmbeddingsModel embeddingsModel) { + ServiceUtils.getEmbeddingSize( + model, + this, + listener.delegateFailureAndWrap((l, size) -> l.onResponse(updateModelWithEmbeddingDetails(embeddingsModel, size))) + ); + } else { + listener.onResponse(model); + } + } + + private GoogleAiStudioEmbeddingsModel updateModelWithEmbeddingDetails(GoogleAiStudioEmbeddingsModel model, int embeddingSize) { + var similarityFromModel = model.getServiceSettings().similarity(); + var similarityToUse = similarityFromModel == null ? SimilarityMeasure.DOT_PRODUCT : similarityFromModel; + + GoogleAiStudioEmbeddingsServiceSettings serviceSettings = new GoogleAiStudioEmbeddingsServiceSettings( + model.getServiceSettings().modelId(), + model.getServiceSettings().maxInputTokens(), + embeddingSize, + similarityToUse, + model.getServiceSettings().rateLimitSettings() + ); + + return new GoogleAiStudioEmbeddingsModel(model, serviceSettings); + } + @Override protected void doInfer( Model model, @@ -213,6 +255,13 @@ protected void doChunkedInfer( TimeValue timeout, ActionListener> listener ) { - throw new UnsupportedOperationException("Chunked inference not supported yet for Google AI Studio"); + GoogleAiStudioModel googleAiStudioModel = (GoogleAiStudioModel) model; + var actionCreator = new GoogleAiStudioActionCreator(getSender(), getServiceComponents()); + + var batchedRequests = new EmbeddingRequestChunker(input, EMBEDDING_MAX_BATCH_SIZE).batchRequestsWithListeners(listener); + for (var request : batchedRequests) { + var action = googleAiStudioModel.accept(actionCreator, taskSettings, inputType); + action.execute(new DocumentsOnlyInput(request.batch().inputs()), timeout, request.listener()); + } } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceFields.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceFields.java new file mode 100644 index 0000000000000..72471251fd86c --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceFields.java @@ -0,0 +1,29 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.googleaistudio; + +public class GoogleAiStudioServiceFields { + + /** + * Didn't find any documentation on this, but provoked it through a large enough request, which returned: + * + *
+     *     
+     *         {
+     *             "error": {
+     *                  "code": 400,
+    *                      "message": "* BatchEmbedContentsRequest.requests: at most 100 requests can be in one batch\n",
+     *                   "status": "INVALID_ARGUMENT"
+     *              }
+     *          }
+     *     
+     * 
+ */ + static final int EMBEDDING_MAX_BATCH_SIZE = 100; + +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/completion/GoogleAiStudioCompletionModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/completion/GoogleAiStudioCompletionModel.java index 6a11f678158b6..eafb0c372202c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/completion/GoogleAiStudioCompletionModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/completion/GoogleAiStudioCompletionModel.java @@ -19,7 +19,7 @@ import org.elasticsearch.xpack.inference.external.action.googleaistudio.GoogleAiStudioActionVisitor; import org.elasticsearch.xpack.inference.external.request.googleaistudio.GoogleAiStudioUtils; import org.elasticsearch.xpack.inference.services.googleaistudio.GoogleAiStudioModel; -import org.elasticsearch.xpack.inference.services.googleaistudio.GoogleAiStudioSecretSettings; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; import java.net.URI; import java.net.URISyntaxException; @@ -45,7 +45,7 @@ public GoogleAiStudioCompletionModel( service, GoogleAiStudioCompletionServiceSettings.fromMap(serviceSettings), EmptyTaskSettings.INSTANCE, - GoogleAiStudioSecretSettings.fromMap(secrets) + DefaultSecretSettings.fromMap(secrets) ); } @@ -56,7 +56,7 @@ public GoogleAiStudioCompletionModel( String service, GoogleAiStudioCompletionServiceSettings serviceSettings, TaskSettings taskSettings, - @Nullable GoogleAiStudioSecretSettings secrets + @Nullable DefaultSecretSettings secrets ) { super( new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings), @@ -78,7 +78,7 @@ public GoogleAiStudioCompletionModel( String url, GoogleAiStudioCompletionServiceSettings serviceSettings, TaskSettings taskSettings, - @Nullable GoogleAiStudioSecretSettings secrets + @Nullable DefaultSecretSettings secrets ) { super( new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings), @@ -102,8 +102,8 @@ public GoogleAiStudioCompletionServiceSettings getServiceSettings() { } @Override - public GoogleAiStudioSecretSettings getSecretSettings() { - return (GoogleAiStudioSecretSettings) super.getSecretSettings(); + public DefaultSecretSettings getSecretSettings() { + return (DefaultSecretSettings) super.getSecretSettings(); } public static URI buildUri(String model) throws URISyntaxException { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/embeddings/GoogleAiStudioEmbeddingsModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/embeddings/GoogleAiStudioEmbeddingsModel.java new file mode 100644 index 0000000000000..ad106797de51b --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/embeddings/GoogleAiStudioEmbeddingsModel.java @@ -0,0 +1,128 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.googleaistudio.embeddings; + +import org.apache.http.client.utils.URIBuilder; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.EmptyTaskSettings; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.TaskSettings; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.external.action.googleaistudio.GoogleAiStudioActionVisitor; +import org.elasticsearch.xpack.inference.external.request.googleaistudio.GoogleAiStudioUtils; +import org.elasticsearch.xpack.inference.services.googleaistudio.GoogleAiStudioModel; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; + +import java.net.URI; +import java.net.URISyntaxException; +import java.util.Map; + +import static org.elasticsearch.core.Strings.format; + +public class GoogleAiStudioEmbeddingsModel extends GoogleAiStudioModel { + + private URI uri; + + public GoogleAiStudioEmbeddingsModel( + String inferenceEntityId, + TaskType taskType, + String service, + Map serviceSettings, + Map taskSettings, + Map secrets + ) { + this( + inferenceEntityId, + taskType, + service, + GoogleAiStudioEmbeddingsServiceSettings.fromMap(serviceSettings), + EmptyTaskSettings.INSTANCE, + DefaultSecretSettings.fromMap(secrets) + ); + } + + public GoogleAiStudioEmbeddingsModel(GoogleAiStudioEmbeddingsModel model, GoogleAiStudioEmbeddingsServiceSettings serviceSettings) { + super(model, serviceSettings); + } + + // Should only be used directly for testing + GoogleAiStudioEmbeddingsModel( + String inferenceEntityId, + TaskType taskType, + String service, + GoogleAiStudioEmbeddingsServiceSettings serviceSettings, + TaskSettings taskSettings, + @Nullable DefaultSecretSettings secrets + ) { + super( + new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings), + new ModelSecrets(secrets), + serviceSettings + ); + try { + this.uri = buildUri(serviceSettings.modelId()); + } catch (URISyntaxException e) { + throw new RuntimeException(e); + } + } + + // Should only be used directly for testing + GoogleAiStudioEmbeddingsModel( + String inferenceEntityId, + TaskType taskType, + String service, + String uri, + GoogleAiStudioEmbeddingsServiceSettings serviceSettings, + TaskSettings taskSettings, + @Nullable DefaultSecretSettings secrets + ) { + super( + new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, taskSettings), + new ModelSecrets(secrets), + serviceSettings + ); + try { + this.uri = new URI(uri); + } catch (URISyntaxException e) { + throw new RuntimeException(e); + } + } + + @Override + public GoogleAiStudioEmbeddingsServiceSettings getServiceSettings() { + return (GoogleAiStudioEmbeddingsServiceSettings) super.getServiceSettings(); + } + + @Override + public DefaultSecretSettings getSecretSettings() { + return (DefaultSecretSettings) super.getSecretSettings(); + } + + public URI uri() { + return uri; + } + + @Override + public ExecutableAction accept(GoogleAiStudioActionVisitor visitor, Map taskSettings, InputType inputType) { + return visitor.create(this, taskSettings); + } + + public static URI buildUri(String model) throws URISyntaxException { + return new URIBuilder().setScheme("https") + .setHost(GoogleAiStudioUtils.HOST_SUFFIX) + .setPathSegments( + GoogleAiStudioUtils.V1, + GoogleAiStudioUtils.MODELS, + format("%s:%s", model, GoogleAiStudioUtils.BATCH_EMBED_CONTENTS_ACTION) + ) + .build(); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/embeddings/GoogleAiStudioEmbeddingsServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/embeddings/GoogleAiStudioEmbeddingsServiceSettings.java new file mode 100644 index 0000000000000..07d07dc533f06 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/embeddings/GoogleAiStudioEmbeddingsServiceSettings.java @@ -0,0 +1,197 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.googleaistudio.embeddings; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.TransportVersions; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.googleaistudio.GoogleAiStudioRateLimitServiceSettings; +import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +import java.io.IOException; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS; +import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS; +import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; +import static org.elasticsearch.xpack.inference.services.ServiceFields.SIMILARITY; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractSimilarity; + +public class GoogleAiStudioEmbeddingsServiceSettings extends FilteredXContentObject + implements + ServiceSettings, + GoogleAiStudioRateLimitServiceSettings { + + public static final String NAME = "google_ai_studio_embeddings_service_settings"; + + /** + * Rate limits are defined at Google Gemini API Pricing. + * For pay-as-you-go you've 360 requests per minute. + */ + private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(360); + + public static GoogleAiStudioEmbeddingsServiceSettings fromMap(Map map) { + ValidationException validationException = new ValidationException(); + + String model = extractRequiredString(map, MODEL_ID, ModelConfigurations.SERVICE_SETTINGS, validationException); + Integer maxInputTokens = extractOptionalPositiveInteger( + map, + MAX_INPUT_TOKENS, + ModelConfigurations.SERVICE_SETTINGS, + validationException + ); + SimilarityMeasure similarityMeasure = extractSimilarity(map, ModelConfigurations.SERVICE_SETTINGS, validationException); + Integer dims = extractOptionalPositiveInteger(map, DIMENSIONS, ModelConfigurations.SERVICE_SETTINGS, validationException); + RateLimitSettings rateLimitSettings = RateLimitSettings.of(map, DEFAULT_RATE_LIMIT_SETTINGS, validationException); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return new GoogleAiStudioEmbeddingsServiceSettings(model, maxInputTokens, dims, similarityMeasure, rateLimitSettings); + } + + private final String modelId; + + private final RateLimitSettings rateLimitSettings; + + private final Integer dims; + + private final Integer maxInputTokens; + + private final SimilarityMeasure similarity; + + public GoogleAiStudioEmbeddingsServiceSettings( + String modelId, + @Nullable Integer maxInputTokens, + @Nullable Integer dims, + @Nullable SimilarityMeasure similarity, + @Nullable RateLimitSettings rateLimitSettings + ) { + this.modelId = modelId; + this.maxInputTokens = maxInputTokens; + this.dims = dims; + this.similarity = similarity; + this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS); + } + + public GoogleAiStudioEmbeddingsServiceSettings(StreamInput in) throws IOException { + this.modelId = in.readString(); + this.maxInputTokens = in.readOptionalVInt(); + this.dims = in.readOptionalVInt(); + this.similarity = in.readOptionalEnum(SimilarityMeasure.class); + this.rateLimitSettings = new RateLimitSettings(in); + } + + @Override + public String modelId() { + return modelId; + } + + @Override + public RateLimitSettings rateLimitSettings() { + return rateLimitSettings; + } + + public Integer maxInputTokens() { + return maxInputTokens; + } + + @Override + public Integer dimensions() { + return dims; + } + + @Override + public SimilarityMeasure similarity() { + return similarity; + } + + @Override + public DenseVectorFieldMapper.ElementType elementType() { + return DenseVectorFieldMapper.ElementType.FLOAT; + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + + toXContentFragmentOfExposedFields(builder, params); + rateLimitSettings.toXContent(builder, params); + + builder.endObject(); + return builder; + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return TransportVersions.ML_INFERENCE_GOOGLE_AI_STUDIO_EMBEDDINGS_ADDED; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(modelId); + out.writeOptionalVInt(maxInputTokens); + out.writeOptionalVInt(dims); + out.writeOptionalEnum(similarity); + rateLimitSettings.writeTo(out); + } + + @Override + protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException { + builder.field(MODEL_ID, modelId); + + if (maxInputTokens != null) { + builder.field(MAX_INPUT_TOKENS, maxInputTokens); + } + + if (dims != null) { + builder.field(DIMENSIONS, dims); + } + + if (similarity != null) { + builder.field(SIMILARITY, similarity); + } + + return builder; + } + + @Override + public boolean equals(Object object) { + if (this == object) return true; + if (object == null || getClass() != object.getClass()) return false; + GoogleAiStudioEmbeddingsServiceSettings that = (GoogleAiStudioEmbeddingsServiceSettings) object; + return Objects.equals(modelId, that.modelId) + && Objects.equals(rateLimitSettings, that.rateLimitSettings) + && Objects.equals(dims, that.dims) + && Objects.equals(maxInputTokens, that.maxInputTokens) + && similarity == that.similarity; + } + + @Override + public int hashCode() { + return Objects.hash(modelId, rateLimitSettings, dims, maxInputTokens, similarity); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserModel.java index 8a947ce9a024b..9010571ea2e55 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserModel.java @@ -14,6 +14,7 @@ import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import org.elasticsearch.xpack.inference.external.action.huggingface.HuggingFaceActionVisitor; import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceModel; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; import java.util.Map; @@ -30,7 +31,7 @@ public HuggingFaceElserModel( taskType, service, HuggingFaceElserServiceSettings.fromMap(serviceSettings), - HuggingFaceElserSecretSettings.fromMap(secrets) + DefaultSecretSettings.fromMap(secrets) ); } @@ -39,7 +40,7 @@ public HuggingFaceElserModel( TaskType taskType, String service, HuggingFaceElserServiceSettings serviceSettings, - @Nullable HuggingFaceElserSecretSettings secretSettings + @Nullable DefaultSecretSettings secretSettings ) { super( new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings), @@ -55,8 +56,8 @@ public HuggingFaceElserServiceSettings getServiceSettings() { } @Override - public HuggingFaceElserSecretSettings getSecretSettings() { - return (HuggingFaceElserSecretSettings) super.getSecretSettings(); + public DefaultSecretSettings getSecretSettings() { + return (DefaultSecretSettings) super.getSecretSettings(); } @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserSecretSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserSecretSettings.java deleted file mode 100644 index 48c8997f2a1bd..0000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserSecretSettings.java +++ /dev/null @@ -1,78 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.services.huggingface.elser; - -import org.elasticsearch.TransportVersion; -import org.elasticsearch.TransportVersions; -import org.elasticsearch.common.ValidationException; -import org.elasticsearch.common.io.stream.StreamInput; -import org.elasticsearch.common.io.stream.StreamOutput; -import org.elasticsearch.common.settings.SecureString; -import org.elasticsearch.core.Nullable; -import org.elasticsearch.inference.ModelSecrets; -import org.elasticsearch.inference.SecretSettings; -import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xpack.inference.services.settings.ApiKeySecrets; - -import java.io.IOException; -import java.util.Map; -import java.util.Objects; - -import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredSecureString; - -public record HuggingFaceElserSecretSettings(SecureString apiKey) implements SecretSettings, ApiKeySecrets { - public static final String NAME = "hugging_face_elser_secret_settings"; - - static final String API_KEY = "api_key"; - - public static HuggingFaceElserSecretSettings fromMap(@Nullable Map map) { - if (map == null) { - return null; - } - - ValidationException validationException = new ValidationException(); - SecureString secureApiToken = extractRequiredSecureString(map, API_KEY, ModelSecrets.SECRET_SETTINGS, validationException); - - if (validationException.validationErrors().isEmpty() == false) { - throw validationException; - } - - return new HuggingFaceElserSecretSettings(secureApiToken); - } - - public HuggingFaceElserSecretSettings { - Objects.requireNonNull(apiKey); - } - - public HuggingFaceElserSecretSettings(StreamInput in) throws IOException { - this(in.readSecureString()); - } - - @Override - public String getWriteableName() { - return NAME; - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - builder.field(API_KEY, apiKey.toString()); - builder.endObject(); - return builder; - } - - @Override - public TransportVersion getMinimalSupportedVersion() { - return TransportVersions.V_8_12_0; - } - - @Override - public void writeTo(StreamOutput out) throws IOException { - out.writeSecureString(apiKey); - } -} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralConstants.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralConstants.java new file mode 100644 index 0000000000000..d059545ca1ea3 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralConstants.java @@ -0,0 +1,21 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.mistral; + +public class MistralConstants { + public static final String API_EMBEDDINGS_PATH = "https://api.mistral.ai/v1/embeddings"; + + // note - there is no bounds information available from Mistral, + // so we'll use a sane default here which is the same as Cohere's + public static final int MAX_BATCH_SIZE = 96; + + public static final String API_KEY_FIELD = "api_key"; + public static final String MODEL_FIELD = "model"; + public static final String INPUT_FIELD = "input"; + public static final String ENCODING_FORMAT_FIELD = "encoding_format"; +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java new file mode 100644 index 0000000000000..7ddb71d001e8c --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java @@ -0,0 +1,272 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.mistral; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.TransportVersion; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkingOptions; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.xpack.core.inference.results.ChunkedTextEmbeddingResults; +import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults; +import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults; +import org.elasticsearch.xpack.inference.common.EmbeddingRequestChunker; +import org.elasticsearch.xpack.inference.external.action.mistral.MistralActionCreator; +import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.SenderService; +import org.elasticsearch.xpack.inference.services.ServiceComponents; +import org.elasticsearch.xpack.inference.services.ServiceUtils; +import org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingsServiceSettings; + +import java.util.List; +import java.util.Map; +import java.util.Set; + +import static org.elasticsearch.TransportVersions.ADD_MISTRAL_EMBEDDINGS_INFERENCE; +import static org.elasticsearch.xpack.core.inference.results.ResultUtils.createInvalidChunkedResultException; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; + +public class MistralService extends SenderService { + public static final String NAME = "mistral"; + + public MistralService(HttpRequestSender.Factory factory, ServiceComponents serviceComponents) { + super(factory, serviceComponents); + } + + @Override + protected void doInfer( + Model model, + List input, + Map taskSettings, + InputType inputType, + TimeValue timeout, + ActionListener listener + ) { + var actionCreator = new MistralActionCreator(getSender(), getServiceComponents()); + + if (model instanceof MistralEmbeddingsModel mistralEmbeddingsModel) { + var action = mistralEmbeddingsModel.accept(actionCreator, taskSettings); + action.execute(new DocumentsOnlyInput(input), timeout, listener); + } else { + listener.onFailure(createInvalidModelException(model)); + } + } + + @Override + protected void doInfer( + Model model, + String query, + List input, + Map taskSettings, + InputType inputType, + TimeValue timeout, + ActionListener listener + ) { + throw new UnsupportedOperationException("Mistral service does not support inference with query input"); + } + + @Override + protected void doChunkedInfer( + Model model, + String query, + List input, + Map taskSettings, + InputType inputType, + ChunkingOptions chunkingOptions, + TimeValue timeout, + ActionListener> listener + ) { + var actionCreator = new MistralActionCreator(getSender(), getServiceComponents()); + + if (model instanceof MistralEmbeddingsModel mistralEmbeddingsModel) { + var batchedRequests = new EmbeddingRequestChunker(input, MistralConstants.MAX_BATCH_SIZE).batchRequestsWithListeners(listener); + + for (var request : batchedRequests) { + var action = mistralEmbeddingsModel.accept(actionCreator, taskSettings); + action.execute(new DocumentsOnlyInput(request.batch().inputs()), timeout, request.listener()); + } + } else { + listener.onFailure(createInvalidModelException(model)); + } + } + + private static List translateToChunkedResults( + List inputs, + InferenceServiceResults inferenceResults + ) { + if (inferenceResults instanceof TextEmbeddingResults textEmbeddingResults) { + return ChunkedTextEmbeddingResults.of(inputs, textEmbeddingResults); + } else if (inferenceResults instanceof ErrorInferenceResults error) { + return List.of(new ErrorChunkedInferenceResults(error.getException())); + } else { + throw createInvalidChunkedResultException(inferenceResults.getWriteableName()); + } + } + + @Override + public String name() { + return NAME; + } + + @Override + public void parseRequestConfig( + String modelId, + TaskType taskType, + Map config, + Set platfromArchitectures, + ActionListener parsedModelListener + ) { + try { + Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); + Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); + + MistralEmbeddingsModel model = createModel( + modelId, + taskType, + serviceSettingsMap, + taskSettingsMap, + serviceSettingsMap, + TaskType.unsupportedTaskTypeErrorMsg(taskType, NAME), + ConfigurationParseContext.REQUEST + ); + + throwIfNotEmptyMap(config, NAME); + throwIfNotEmptyMap(serviceSettingsMap, NAME); + throwIfNotEmptyMap(taskSettingsMap, NAME); + + parsedModelListener.onResponse(model); + } catch (Exception e) { + parsedModelListener.onFailure(e); + } + } + + @Override + public Model parsePersistedConfigWithSecrets( + String modelId, + TaskType taskType, + Map config, + Map secrets + ) { + Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); + Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); + Map secretSettingsMap = removeFromMapOrDefaultEmpty(secrets, ModelSecrets.SECRET_SETTINGS); + + return createModelFromPersistent( + modelId, + taskType, + serviceSettingsMap, + taskSettingsMap, + secretSettingsMap, + parsePersistedConfigErrorMsg(modelId, NAME) + ); + } + + @Override + public Model parsePersistedConfig(String modelId, TaskType taskType, Map config) { + Map serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS); + Map taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS); + + return createModelFromPersistent( + modelId, + taskType, + serviceSettingsMap, + taskSettingsMap, + null, + parsePersistedConfigErrorMsg(modelId, NAME) + ); + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return ADD_MISTRAL_EMBEDDINGS_INFERENCE; + } + + private static MistralEmbeddingsModel createModel( + String modelId, + TaskType taskType, + Map serviceSettings, + Map taskSettings, + @Nullable Map secretSettings, + String failureMessage, + ConfigurationParseContext context + ) { + if (taskType == TaskType.TEXT_EMBEDDING) { + return new MistralEmbeddingsModel(modelId, taskType, NAME, serviceSettings, taskSettings, secretSettings, context); + } + + throw new ElasticsearchStatusException(failureMessage, RestStatus.BAD_REQUEST); + } + + private MistralEmbeddingsModel createModelFromPersistent( + String inferenceEntityId, + TaskType taskType, + Map serviceSettings, + Map taskSettings, + Map secretSettings, + String failureMessage + ) { + return createModel( + inferenceEntityId, + taskType, + serviceSettings, + taskSettings, + secretSettings, + failureMessage, + ConfigurationParseContext.PERSISTENT + ); + } + + @Override + public void checkModelConfig(Model model, ActionListener listener) { + if (model instanceof MistralEmbeddingsModel embeddingsModel) { + ServiceUtils.getEmbeddingSize( + model, + this, + listener.delegateFailureAndWrap((l, size) -> l.onResponse(updateEmbeddingModelConfig(embeddingsModel, size))) + ); + } else { + listener.onResponse(model); + } + } + + private MistralEmbeddingsModel updateEmbeddingModelConfig(MistralEmbeddingsModel embeddingsModel, int embeddingsSize) { + var embeddingServiceSettings = embeddingsModel.getServiceSettings(); + + var similarityFromModel = embeddingsModel.getServiceSettings().similarity(); + var similarityToUse = similarityFromModel == null ? SimilarityMeasure.DOT_PRODUCT : similarityFromModel; + + MistralEmbeddingsServiceSettings serviceSettings = new MistralEmbeddingsServiceSettings( + embeddingServiceSettings.model(), + embeddingsSize, + embeddingServiceSettings.maxInputTokens(), + similarityToUse, + embeddingServiceSettings.rateLimitSettings() + ); + + return new MistralEmbeddingsModel(embeddingsModel, serviceSettings); + } + +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingsModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingsModel.java new file mode 100644 index 0000000000000..c3d261efea79a --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingsModel.java @@ -0,0 +1,134 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.mistral.embeddings; + +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.EmptyTaskSettings; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.TaskSettings; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xpack.inference.external.action.ExecutableAction; +import org.elasticsearch.xpack.inference.external.action.mistral.MistralActionVisitor; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +import java.net.URI; +import java.net.URISyntaxException; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.xpack.inference.services.mistral.MistralConstants.API_EMBEDDINGS_PATH; + +public class MistralEmbeddingsModel extends Model { + protected String model; + protected URI uri; + protected RateLimitSettings rateLimitSettings; + + public MistralEmbeddingsModel( + String inferenceEntityId, + TaskType taskType, + String service, + Map serviceSettings, + Map taskSettings, + @Nullable Map secrets, + ConfigurationParseContext context + ) { + this( + inferenceEntityId, + taskType, + service, + MistralEmbeddingsServiceSettings.fromMap(serviceSettings, context), + EmptyTaskSettings.INSTANCE, // no task settings for Mistral embeddings + DefaultSecretSettings.fromMap(secrets) + ); + } + + public MistralEmbeddingsModel(MistralEmbeddingsModel model, TaskSettings taskSettings, RateLimitSettings rateLimitSettings) { + super(model, taskSettings); + this.model = Objects.requireNonNull(model.model); + this.rateLimitSettings = Objects.requireNonNull(rateLimitSettings); + setEndpointUrl(); + } + + public MistralEmbeddingsModel(MistralEmbeddingsModel model, MistralEmbeddingsServiceSettings serviceSettings) { + super(model, serviceSettings); + setPropertiesFromServiceSettings(serviceSettings); + } + + protected MistralEmbeddingsModel(ModelConfigurations modelConfigurations, ModelSecrets modelSecrets) { + super(modelConfigurations, modelSecrets); + setPropertiesFromServiceSettings((MistralEmbeddingsServiceSettings) modelConfigurations.getServiceSettings()); + } + + public MistralEmbeddingsModel( + String inferenceEntityId, + TaskType taskType, + String service, + MistralEmbeddingsServiceSettings serviceSettings, + TaskSettings taskSettings, + DefaultSecretSettings secrets + ) { + super( + new ModelConfigurations(inferenceEntityId, taskType, service, serviceSettings, new EmptyTaskSettings()), + new ModelSecrets(secrets) + ); + setPropertiesFromServiceSettings(serviceSettings); + } + + private void setPropertiesFromServiceSettings(MistralEmbeddingsServiceSettings serviceSettings) { + this.model = serviceSettings.model(); + this.rateLimitSettings = serviceSettings.rateLimitSettings(); + setEndpointUrl(); + } + + @Override + public MistralEmbeddingsServiceSettings getServiceSettings() { + return (MistralEmbeddingsServiceSettings) super.getServiceSettings(); + } + + public String model() { + return this.model; + } + + public URI uri() { + return this.uri; + } + + public RateLimitSettings rateLimitSettings() { + return this.rateLimitSettings; + } + + private void setEndpointUrl() { + try { + this.uri = new URI(API_EMBEDDINGS_PATH); + } catch (URISyntaxException e) { + throw new RuntimeException(e); + } + } + + // Needed for testing only + public void setURI(String newUri) { + try { + this.uri = new URI(newUri); + } catch (URISyntaxException e) { + // swallow any error + } + } + + @Override + public DefaultSecretSettings getSecretSettings() { + return (DefaultSecretSettings) super.getSecretSettings(); + } + + public ExecutableAction accept(MistralActionVisitor creator, Map taskSettings) { + return creator.create(this, taskSettings); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingsServiceSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingsServiceSettings.java new file mode 100644 index 0000000000000..d2ea8ccbd18bd --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingsServiceSettings.java @@ -0,0 +1,182 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.mistral.embeddings; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.ValidationException; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ServiceSettings; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.settings.FilteredXContentObject; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +import java.io.IOException; +import java.util.Map; +import java.util.Objects; + +import static org.elasticsearch.TransportVersions.ADD_MISTRAL_EMBEDDINGS_INFERENCE; +import static org.elasticsearch.xpack.inference.services.ServiceFields.DIMENSIONS; +import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS; +import static org.elasticsearch.xpack.inference.services.ServiceFields.SIMILARITY; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractSimilarity; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeAsType; +import static org.elasticsearch.xpack.inference.services.mistral.MistralConstants.MODEL_FIELD; + +public class MistralEmbeddingsServiceSettings extends FilteredXContentObject implements ServiceSettings { + public static final String NAME = "mistral_embeddings_service_settings"; + + private final String model; + private final Integer dimensions; + private final SimilarityMeasure similarity; + private final Integer maxInputTokens; + private final RateLimitSettings rateLimitSettings; + + // default for Mistral is 5 requests / sec + // setting this to 240 (4 requests / sec) is a sane default for us + protected static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(240); + + public static MistralEmbeddingsServiceSettings fromMap(Map map, ConfigurationParseContext context) { + ValidationException validationException = new ValidationException(); + + String model = extractRequiredString(map, MODEL_FIELD, ModelConfigurations.SERVICE_SETTINGS, validationException); + SimilarityMeasure similarity = extractSimilarity(map, ModelConfigurations.SERVICE_SETTINGS, validationException); + Integer maxInputTokens = extractOptionalPositiveInteger( + map, + MAX_INPUT_TOKENS, + ModelConfigurations.SERVICE_SETTINGS, + validationException + ); + RateLimitSettings rateLimitSettings = RateLimitSettings.of(map, DEFAULT_RATE_LIMIT_SETTINGS, validationException); + Integer dims = removeAsType(map, DIMENSIONS, Integer.class); + + if (validationException.validationErrors().isEmpty() == false) { + throw validationException; + } + + return new MistralEmbeddingsServiceSettings(model, dims, maxInputTokens, similarity, rateLimitSettings); + } + + public MistralEmbeddingsServiceSettings(StreamInput in) throws IOException { + this.model = in.readString(); + this.dimensions = in.readOptionalVInt(); + this.similarity = in.readOptionalEnum(SimilarityMeasure.class); + this.maxInputTokens = in.readOptionalVInt(); + this.rateLimitSettings = new RateLimitSettings(in); + } + + public MistralEmbeddingsServiceSettings( + String model, + @Nullable Integer dimensions, + @Nullable Integer maxInputTokens, + @Nullable SimilarityMeasure similarity, + @Nullable RateLimitSettings rateLimitSettings + ) { + this.model = model; + this.dimensions = dimensions; + this.similarity = similarity; + this.maxInputTokens = maxInputTokens; + this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS); + } + + @Override + public String getWriteableName() { + return NAME; + } + + @Override + public TransportVersion getMinimalSupportedVersion() { + return ADD_MISTRAL_EMBEDDINGS_INFERENCE; + } + + public String model() { + return this.model; + } + + @Override + public Integer dimensions() { + return this.dimensions; + } + + public Integer maxInputTokens() { + return this.maxInputTokens; + } + + @Override + public SimilarityMeasure similarity() { + return this.similarity; + } + + @Override + public DenseVectorFieldMapper.ElementType elementType() { + return DenseVectorFieldMapper.ElementType.FLOAT; + } + + public RateLimitSettings rateLimitSettings() { + return this.rateLimitSettings; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(model); + out.writeOptionalVInt(dimensions); + out.writeOptionalEnum(SimilarityMeasure.translateSimilarity(similarity, out.getTransportVersion())); + out.writeOptionalVInt(maxInputTokens); + rateLimitSettings.writeTo(out); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + this.toXContentFragmentOfExposedFields(builder, params); + rateLimitSettings.toXContent(builder, params); + builder.endObject(); + return builder; + } + + @Override + protected XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder, Params params) throws IOException { + builder.field(MODEL_FIELD, this.model); + + if (dimensions != null) { + builder.field(DIMENSIONS, dimensions); + } + if (similarity != null) { + builder.field(SIMILARITY, similarity); + } + if (this.maxInputTokens != null) { + builder.field(MAX_INPUT_TOKENS, this.maxInputTokens); + } + + return builder; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + MistralEmbeddingsServiceSettings that = (MistralEmbeddingsServiceSettings) o; + return Objects.equals(model, that.model) + && Objects.equals(dimensions, that.dimensions) + && Objects.equals(maxInputTokens, that.maxInputTokens) + && Objects.equals(similarity, that.similarity); + } + + @Override + public int hashCode() { + return Objects.hash(model, dimensions, maxInputTokens, similarity); + } + +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/Utils.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/Utils.java index e6cfd565c2a17..58603526a9c56 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/Utils.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/Utils.java @@ -36,6 +36,7 @@ import java.util.stream.Collectors; import java.util.stream.Stream; +import static org.elasticsearch.test.ESTestCase.randomFrom; import static org.elasticsearch.xpack.inference.InferencePlugin.UTILITY_THREAD_POOL_NAME; import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.MatcherAssert.assertThat; @@ -153,4 +154,9 @@ public static Model getInvalidModel(String inferenceEntityId, String serviceName return mockModel; } + + public static SimilarityMeasure randomSimilarityMeasure() { + return randomFrom(SimilarityMeasure.values()); + } + } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/googleaistudio/GoogleAiStudioEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/googleaistudio/GoogleAiStudioEmbeddingsActionTests.java new file mode 100644 index 0000000000000..a55b3c5f5030c --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/googleaistudio/GoogleAiStudioEmbeddingsActionTests.java @@ -0,0 +1,191 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.action.googleaistudio; + +import org.apache.http.HttpHeaders; +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.Strings; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.http.MockResponse; +import org.elasticsearch.test.http.MockWebServer; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.inference.external.http.HttpClientManager; +import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; +import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.logging.ThrottlerManager; +import org.junit.After; +import org.junit.Before; + +import java.io.IOException; +import java.util.List; +import java.util.Map; +import java.util.concurrent.TimeUnit; + +import static org.elasticsearch.core.Strings.format; +import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; +import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; +import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; +import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; +import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat; +import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; +import static org.elasticsearch.xpack.inference.services.googleaistudio.embeddings.GoogleAiStudioEmbeddingsModelTests.createModel; +import static org.hamcrest.Matchers.aMapWithSize; +import static org.hamcrest.Matchers.endsWith; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.is; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; + +public class GoogleAiStudioEmbeddingsActionTests extends ESTestCase { + + private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); + private final MockWebServer webServer = new MockWebServer(); + private ThreadPool threadPool; + private HttpClientManager clientManager; + + @Before + public void init() throws Exception { + webServer.start(); + threadPool = createThreadPool(inferenceUtilityPool()); + clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class)); + } + + @After + public void shutdown() throws IOException { + clientManager.close(); + terminate(threadPool); + webServer.close(); + } + + public void testExecute_ReturnsSuccessfulResponse() throws IOException { + var apiKey = "apiKey"; + var model = "model"; + var input = "input"; + var senderFactory = new HttpRequestSender.Factory(createWithEmptySettings(threadPool), clientManager, mockClusterServiceEmpty()); + + try (var sender = senderFactory.createSender("test_service")) { + sender.start(); + + String responseJson = """ + { + "embeddings": [ + { + "values": [ + 0.0123, + -0.0123 + ] + } + ] + } + """; + + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var action = createAction(getUrl(webServer), apiKey, model, sender); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new DocumentsOnlyInput(List.of(input)), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var result = listener.actionGet(TIMEOUT); + + assertThat(result.asMap(), is(buildExpectationFloat(List.of(new float[] { 0.0123F, -0.0123F })))); + assertThat(webServer.requests(), hasSize(1)); + assertThat(webServer.requests().get(0).getUri().getQuery(), endsWith(apiKey)); + assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType())); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + assertThat(requestMap, aMapWithSize(1)); + assertThat( + requestMap.get("requests"), + is( + List.of( + Map.of( + "model", + Strings.format("%s/%s", "models", model), + "content", + Map.of("parts", List.of(Map.of("text", input))) + ) + ) + ) + ); + } + } + + public void testExecute_ThrowsElasticsearchException() { + var sender = mock(Sender.class); + doThrow(new ElasticsearchException("failed")).when(sender).send(any(), any(), any(), any()); + + var action = createAction(getUrl(webServer), "api_key", "model", sender); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); + + assertThat(thrownException.getMessage(), is("failed")); + } + + public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled() { + var sender = mock(Sender.class); + + doAnswer(invocation -> { + @SuppressWarnings("unchecked") + ActionListener listener = (ActionListener) invocation.getArguments()[2]; + listener.onFailure(new IllegalStateException("failed")); + + return Void.TYPE; + }).when(sender).send(any(), any(), any(), any()); + + var action = createAction(getUrl(webServer), "api_key", "model", sender); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); + + assertThat( + thrownException.getMessage(), + is(format("Failed to send Google AI Studio embeddings request to [%s]", getUrl(webServer))) + ); + } + + public void testExecute_ThrowsException() { + var sender = mock(Sender.class); + doThrow(new IllegalArgumentException("failed")).when(sender).send(any(), any(), any(), any()); + + var action = createAction(getUrl(webServer), "api_key", "model", sender); + + PlainActionFuture listener = new PlainActionFuture<>(); + action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + + var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); + + assertThat( + thrownException.getMessage(), + is(format("Failed to send Google AI Studio embeddings request to [%s]", getUrl(webServer))) + ); + } + + private GoogleAiStudioEmbeddingsAction createAction(String url, String apiKey, String modelName, Sender sender) { + var model = createModel(modelName, apiKey, url); + + return new GoogleAiStudioEmbeddingsAction(sender, model, createWithEmptySettings(threadPool)); + } + +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/GoogleAiStudioRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/GoogleAiStudioRequestTests.java index d77c88dacd06f..da6070f1f455f 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/GoogleAiStudioRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/GoogleAiStudioRequestTests.java @@ -12,7 +12,7 @@ import org.elasticsearch.common.settings.SecureString; import org.elasticsearch.core.Strings; import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xpack.inference.services.googleaistudio.GoogleAiStudioSecretSettings; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; import java.net.URI; import java.net.URISyntaxException; @@ -28,7 +28,7 @@ public void testDecorateWithApiKeyParameter() throws URISyntaxException { var uriString = "https://localhost:3000"; var secureApiKey = new SecureString("api_key".toCharArray()); var httpPost = new HttpPost(uriString); - var secretSettings = new GoogleAiStudioSecretSettings(secureApiKey); + var secretSettings = new DefaultSecretSettings(secureApiKey); GoogleAiStudioRequest.decorateWithApiKeyParameter(httpPost, secretSettings); @@ -45,7 +45,7 @@ public void testDecorateWithApiKeyParameter_ThrowsValidationException_WhenAnyExc ValidationException.class, () -> GoogleAiStudioRequest.decorateWithApiKeyParameter( httpPost, - new GoogleAiStudioSecretSettings(new SecureString("abc".toCharArray())) + new DefaultSecretSettings(new SecureString("abc".toCharArray())) ) ); assertThat(validationException.getCause(), is(cause)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/embeddings/GoogleAiStudioEmbeddingsRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/embeddings/GoogleAiStudioEmbeddingsRequestEntityTests.java new file mode 100644 index 0000000000000..4c3b33e1dc950 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/embeddings/GoogleAiStudioEmbeddingsRequestEntityTests.java @@ -0,0 +1,146 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.googleaistudio.embeddings; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.external.request.googleaistudio.GoogleAiStudioEmbeddingsRequestEntity; + +import java.io.IOException; +import java.util.List; + +import static org.elasticsearch.xpack.inference.MatchersUtils.equalToIgnoringWhitespaceInJsonString; + +public class GoogleAiStudioEmbeddingsRequestEntityTests extends ESTestCase { + + public void testXContent_SingleRequest_WritesDimensionsIfDefined() throws IOException { + var entity = new GoogleAiStudioEmbeddingsRequestEntity(List.of("abc"), "model", 8); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString(""" + { + "requests": [ + { + "model": "models/model", + "content": { + "parts": [ + { + "text": "abc" + } + ] + }, + "outputDimensionality": 8 + } + ] + } + """)); + } + + public void testXContent_SingleRequest_DoesNotWriteDimensionsIfNull() throws IOException { + var entity = new GoogleAiStudioEmbeddingsRequestEntity(List.of("abc"), "model", null); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString(""" + { + "requests": [ + { + "model": "models/model", + "content": { + "parts": [ + { + "text": "abc" + } + ] + } + } + ] + } + """)); + } + + public void testXContent_MultipleRequests_WritesDimensionsIfDefined() throws IOException { + var entity = new GoogleAiStudioEmbeddingsRequestEntity(List.of("abc", "def"), "model", 8); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString(""" + { + "requests": [ + { + "model": "models/model", + "content": { + "parts": [ + { + "text": "abc" + } + ] + }, + "outputDimensionality": 8 + }, + { + "model": "models/model", + "content": { + "parts": [ + { + "text": "def" + } + ] + }, + "outputDimensionality": 8 + } + ] + } + """)); + } + + public void testXContent_MultipleRequests_DoesNotWriteDimensionsIfNull() throws IOException { + var entity = new GoogleAiStudioEmbeddingsRequestEntity(List.of("abc", "def"), "model", null); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString(""" + { + "requests": [ + { + "model": "models/model", + "content": { + "parts": [ + { + "text": "abc" + } + ] + } + }, + { + "model": "models/model", + "content": { + "parts": [ + { + "text": "def" + } + ] + } + } + ] + } + """)); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/embeddings/GoogleAiStudioEmbeddingsRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/embeddings/GoogleAiStudioEmbeddingsRequestTests.java new file mode 100644 index 0000000000000..9ce254bd8e3da --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/embeddings/GoogleAiStudioEmbeddingsRequestTests.java @@ -0,0 +1,152 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.googleaistudio.embeddings; + +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpPost; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.Strings; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.common.Truncator; +import org.elasticsearch.xpack.inference.common.TruncatorTests; +import org.elasticsearch.xpack.inference.external.request.googleaistudio.GoogleAiStudioEmbeddingsRequest; +import org.elasticsearch.xpack.inference.services.googleaistudio.embeddings.GoogleAiStudioEmbeddingsModelTests; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; +import static org.hamcrest.Matchers.aMapWithSize; +import static org.hamcrest.Matchers.endsWith; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; + +public class GoogleAiStudioEmbeddingsRequestTests extends ESTestCase { + + public void testCreateRequest_WithoutDimensionsSet() throws IOException { + var model = "model"; + var apiKey = "api_key"; + var input = "input"; + + var request = createRequest(model, apiKey, input, null, null); + var httpRequest = request.createHttpRequest(); + + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + assertThat(httpPost.getURI().toString(), endsWith(Strings.format("%s=%s", "key", apiKey))); + assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + assertThat(requestMap, aMapWithSize(1)); + assertThat( + requestMap.get("requests"), + is( + List.of( + Map.of("model", Strings.format("%s/%s", "models", model), "content", Map.of("parts", List.of(Map.of("text", input)))) + ) + ) + ); + } + + public void testCreateRequest_WithDimensionsSet() throws IOException { + var model = "model"; + var apiKey = "api_key"; + var input = "input"; + var dimensions = 8; + + var request = createRequest(model, apiKey, input, null, dimensions); + var httpRequest = request.createHttpRequest(); + + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + assertThat(httpPost.getURI().toString(), endsWith(Strings.format("%s=%s", "key", apiKey))); + assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + assertThat(requestMap, aMapWithSize(1)); + assertThat( + requestMap.get("requests"), + is( + List.of( + Map.of( + "model", + Strings.format("%s/%s", "models", model), + "content", + Map.of("parts", List.of(Map.of("text", input))), + "outputDimensionality", + dimensions + ) + ) + ) + ); + } + + public void testTruncate_ReducesInputTextSizeByHalf() throws IOException { + var model = "model"; + var apiKey = "api_key"; + var input = "abcd"; + var dimensions = 8; + + var request = createRequest(model, apiKey, input, null, dimensions); + var truncatedRequest = request.truncate(); + var httpRequest = truncatedRequest.createHttpRequest(); + + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + + assertThat(httpPost.getURI().toString(), endsWith(Strings.format("%s=%s", "key", apiKey))); + assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + assertThat(requestMap, aMapWithSize(1)); + assertThat( + requestMap.get("requests"), + is( + List.of( + Map.of( + "model", + Strings.format("%s/%s", "models", model), + "content", + // "abcd" reduced by half -> "ab" + Map.of("parts", List.of(Map.of("text", "ab"))), + "outputDimensionality", + dimensions + ) + ) + ) + ); + } + + public void testIsTruncated_ReturnsTrue() { + var request = createRequest("model", "api key", "input", null, null); + assertFalse(request.getTruncationInfo()[0]); + + var truncatedRequest = request.truncate(); + assertTrue(truncatedRequest.getTruncationInfo()[0]); + } + + public static GoogleAiStudioEmbeddingsRequest createRequest( + String model, + String apiKey, + String input, + @Nullable Integer maxTokens, + @Nullable Integer dimensions + ) { + var embeddingsModel = GoogleAiStudioEmbeddingsModelTests.createModel(model, apiKey, maxTokens, dimensions); + + return new GoogleAiStudioEmbeddingsRequest( + TruncatorTests.createTruncator(), + new Truncator.TruncationResult(List.of(input), new boolean[] { false }), + embeddingsModel + ); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/mistral/MistralEmbeddingsRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/mistral/MistralEmbeddingsRequestEntityTests.java new file mode 100644 index 0000000000000..181ca3d5145b8 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/mistral/MistralEmbeddingsRequestEntityTests.java @@ -0,0 +1,32 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.mistral; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; + +import java.io.IOException; +import java.util.List; + +import static org.hamcrest.CoreMatchers.is; + +public class MistralEmbeddingsRequestEntityTests extends ESTestCase { + public void testXContent_WritesModelInputAndFormat() throws IOException { + var entity = new MistralEmbeddingsRequestEntity("mistral-embed", List.of("abc")); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, is(""" + {"model":"mistral-embed","input":["abc"],"encoding_format":"float"}""")); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/mistral/MistralEmbeddingsRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/mistral/MistralEmbeddingsRequestTests.java new file mode 100644 index 0000000000000..8f78c70da0c61 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/mistral/MistralEmbeddingsRequestTests.java @@ -0,0 +1,81 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.request.mistral; + +import org.apache.http.HttpHeaders; +import org.apache.http.client.methods.HttpPost; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.common.Truncator; +import org.elasticsearch.xpack.inference.common.TruncatorTests; +import org.elasticsearch.xpack.inference.external.request.HttpRequest; +import org.elasticsearch.xpack.inference.services.mistral.MistralConstants; +import org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingModelTests; + +import java.io.IOException; +import java.util.List; + +import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; +import static org.hamcrest.Matchers.aMapWithSize; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; + +public class MistralEmbeddingsRequestTests extends ESTestCase { + public void testCreateRequest_Works() throws IOException { + var request = createRequest("mistral-embed", "apikey", "abcd"); + var httpRequest = request.createHttpRequest(); + var httpPost = validateRequestUrlAndContentType(httpRequest, MistralConstants.API_EMBEDDINGS_PATH); + assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer apikey")); + + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + assertThat(requestMap, aMapWithSize(3)); + assertThat(requestMap.get("input"), is(List.of("abcd"))); + assertThat(requestMap.get("model"), is("mistral-embed")); + assertThat(requestMap.get("encoding_format"), is("float")); + } + + public void testTruncate_ReducesInputTextSizeByHalf() throws IOException { + var request = createRequest("mistral-embed", "apikey", "abcd"); + var truncatedRequest = request.truncate(); + + var httpRequest = truncatedRequest.createHttpRequest(); + assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); + + var httpPost = (HttpPost) httpRequest.httpRequestBase(); + var requestMap = entityAsMap(httpPost.getEntity().getContent()); + assertThat(requestMap, aMapWithSize(3)); + assertThat(requestMap.get("input"), is(List.of("ab"))); + assertThat(requestMap.get("model"), is("mistral-embed")); + assertThat(requestMap.get("encoding_format"), is("float")); + } + + public void testIsTruncated_ReturnsTrue() { + var request = createRequest("mistral-embed", "apikey", "abcd"); + assertFalse(request.getTruncationInfo()[0]); + + var truncatedRequest = request.truncate(); + assertTrue(truncatedRequest.getTruncationInfo()[0]); + } + + private HttpPost validateRequestUrlAndContentType(HttpRequest request, String expectedUrl) throws IOException { + assertThat(request.httpRequestBase(), instanceOf(HttpPost.class)); + var httpPost = (HttpPost) request.httpRequestBase(); + assertThat(httpPost.getURI().toString(), is(expectedUrl)); + assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); + return httpPost; + } + + public static MistralEmbeddingsRequest createRequest(String model, String apiKey, String input) { + var embeddingsModel = MistralEmbeddingModelTests.createModel("id", model, apiKey, null, null, null, null); + return new MistralEmbeddingsRequest( + TruncatorTests.createTruncator(), + new Truncator.TruncationResult(List.of(input), new boolean[] { false }), + embeddingsModel + ); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/AzureAndOpenAiErrorResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/AzureAndOpenAiErrorResponseEntityTests.java index fd133a26f5532..48a560341f392 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/AzureAndOpenAiErrorResponseEntityTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/AzureAndOpenAiErrorResponseEntityTests.java @@ -26,7 +26,7 @@ public void testErrorResponse_ExtractsError() { var result = getMockResult(""" {"error":{"message":"test_error_message"}}"""); - var error = AzureAndOpenAiErrorResponseEntity.fromResponse(result); + var error = AzureMistralOpenAiErrorResponseEntity.fromResponse(result); assertNotNull(error); assertThat(error.getErrorMessage(), is("test_error_message")); } @@ -35,14 +35,14 @@ public void testErrorResponse_ReturnsNullIfNoError() { var result = getMockResult(""" {"noerror":true}"""); - var error = AzureAndOpenAiErrorResponseEntity.fromResponse(result); + var error = AzureMistralOpenAiErrorResponseEntity.fromResponse(result); assertNull(error); } public void testErrorResponse_ReturnsNullIfNotJson() { var result = getMockResult("not a json string"); - var error = AzureAndOpenAiErrorResponseEntity.fromResponse(result); + var error = AzureMistralOpenAiErrorResponseEntity.fromResponse(result); assertNull(error); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/AzureAndOpenAiExternalResponseHandlerTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/AzureAndOpenAiExternalResponseHandlerTests.java index 4c9fb143c3a5c..9ef9ab4daa0ae 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/AzureAndOpenAiExternalResponseHandlerTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/AzureAndOpenAiExternalResponseHandlerTests.java @@ -42,10 +42,10 @@ public void testCheckForFailureStatusCode() { var mockRequest = RequestTests.mockRequest("id"); var httpResult = new HttpResult(httpResponse, new byte[] {}); - var handler = new AzureAndOpenAiExternalResponseHandler( + var handler = new AzureMistralOpenAiExternalResponseHandler( "", (request, result) -> null, - AzureAndOpenAiErrorResponseEntity::fromResponse + AzureMistralOpenAiErrorResponseEntity::fromResponse ); // 200 ok @@ -157,20 +157,20 @@ public void testBuildRateLimitErrorMessage() { var httpResult = new HttpResult(response, new byte[] {}); { - when(response.getFirstHeader(AzureAndOpenAiExternalResponseHandler.REQUESTS_LIMIT)).thenReturn( - new BasicHeader(AzureAndOpenAiExternalResponseHandler.REQUESTS_LIMIT, "3000") + when(response.getFirstHeader(AzureMistralOpenAiExternalResponseHandler.REQUESTS_LIMIT)).thenReturn( + new BasicHeader(AzureMistralOpenAiExternalResponseHandler.REQUESTS_LIMIT, "3000") ); - when(response.getFirstHeader(AzureAndOpenAiExternalResponseHandler.REMAINING_REQUESTS)).thenReturn( - new BasicHeader(AzureAndOpenAiExternalResponseHandler.REMAINING_REQUESTS, "2999") + when(response.getFirstHeader(AzureMistralOpenAiExternalResponseHandler.REMAINING_REQUESTS)).thenReturn( + new BasicHeader(AzureMistralOpenAiExternalResponseHandler.REMAINING_REQUESTS, "2999") ); - when(response.getFirstHeader(AzureAndOpenAiExternalResponseHandler.TOKENS_LIMIT)).thenReturn( - new BasicHeader(AzureAndOpenAiExternalResponseHandler.TOKENS_LIMIT, "10000") + when(response.getFirstHeader(AzureMistralOpenAiExternalResponseHandler.TOKENS_LIMIT)).thenReturn( + new BasicHeader(AzureMistralOpenAiExternalResponseHandler.TOKENS_LIMIT, "10000") ); - when(response.getFirstHeader(AzureAndOpenAiExternalResponseHandler.REMAINING_TOKENS)).thenReturn( - new BasicHeader(AzureAndOpenAiExternalResponseHandler.REMAINING_TOKENS, "99800") + when(response.getFirstHeader(AzureMistralOpenAiExternalResponseHandler.REMAINING_TOKENS)).thenReturn( + new BasicHeader(AzureMistralOpenAiExternalResponseHandler.REMAINING_TOKENS, "99800") ); - var error = AzureAndOpenAiExternalResponseHandler.buildRateLimitErrorMessage(httpResult); + var error = AzureMistralOpenAiExternalResponseHandler.buildRateLimitErrorMessage(httpResult); assertThat( error, containsString("Token limit [10000], remaining tokens [99800]. Request limit [3000], remaining requests [2999]") @@ -178,9 +178,9 @@ public void testBuildRateLimitErrorMessage() { } { - when(response.getFirstHeader(AzureAndOpenAiExternalResponseHandler.TOKENS_LIMIT)).thenReturn(null); - when(response.getFirstHeader(AzureAndOpenAiExternalResponseHandler.REMAINING_TOKENS)).thenReturn(null); - var error = AzureAndOpenAiExternalResponseHandler.buildRateLimitErrorMessage(httpResult); + when(response.getFirstHeader(AzureMistralOpenAiExternalResponseHandler.TOKENS_LIMIT)).thenReturn(null); + when(response.getFirstHeader(AzureMistralOpenAiExternalResponseHandler.REMAINING_TOKENS)).thenReturn(null); + var error = AzureMistralOpenAiExternalResponseHandler.buildRateLimitErrorMessage(httpResult); assertThat( error, containsString("Token limit [unknown], remaining tokens [unknown]. Request limit [3000], remaining requests [2999]") @@ -188,26 +188,26 @@ public void testBuildRateLimitErrorMessage() { } { - when(response.getFirstHeader(AzureAndOpenAiExternalResponseHandler.REQUESTS_LIMIT)).thenReturn(null); - when(response.getFirstHeader(AzureAndOpenAiExternalResponseHandler.REMAINING_REQUESTS)).thenReturn( - new BasicHeader(AzureAndOpenAiExternalResponseHandler.REMAINING_REQUESTS, "2999") + when(response.getFirstHeader(AzureMistralOpenAiExternalResponseHandler.REQUESTS_LIMIT)).thenReturn(null); + when(response.getFirstHeader(AzureMistralOpenAiExternalResponseHandler.REMAINING_REQUESTS)).thenReturn( + new BasicHeader(AzureMistralOpenAiExternalResponseHandler.REMAINING_REQUESTS, "2999") ); - when(response.getFirstHeader(AzureAndOpenAiExternalResponseHandler.TOKENS_LIMIT)).thenReturn(null); - when(response.getFirstHeader(AzureAndOpenAiExternalResponseHandler.REMAINING_TOKENS)).thenReturn(null); - var error = AzureAndOpenAiExternalResponseHandler.buildRateLimitErrorMessage(httpResult); + when(response.getFirstHeader(AzureMistralOpenAiExternalResponseHandler.TOKENS_LIMIT)).thenReturn(null); + when(response.getFirstHeader(AzureMistralOpenAiExternalResponseHandler.REMAINING_TOKENS)).thenReturn(null); + var error = AzureMistralOpenAiExternalResponseHandler.buildRateLimitErrorMessage(httpResult); assertThat(error, containsString("Remaining tokens [unknown]. Remaining requests [2999]")); } { - when(response.getFirstHeader(AzureAndOpenAiExternalResponseHandler.REQUESTS_LIMIT)).thenReturn(null); - when(response.getFirstHeader(AzureAndOpenAiExternalResponseHandler.REMAINING_REQUESTS)).thenReturn( - new BasicHeader(AzureAndOpenAiExternalResponseHandler.REMAINING_REQUESTS, "2999") + when(response.getFirstHeader(AzureMistralOpenAiExternalResponseHandler.REQUESTS_LIMIT)).thenReturn(null); + when(response.getFirstHeader(AzureMistralOpenAiExternalResponseHandler.REMAINING_REQUESTS)).thenReturn( + new BasicHeader(AzureMistralOpenAiExternalResponseHandler.REMAINING_REQUESTS, "2999") ); - when(response.getFirstHeader(AzureAndOpenAiExternalResponseHandler.TOKENS_LIMIT)).thenReturn( - new BasicHeader(AzureAndOpenAiExternalResponseHandler.TOKENS_LIMIT, "10000") + when(response.getFirstHeader(AzureMistralOpenAiExternalResponseHandler.TOKENS_LIMIT)).thenReturn( + new BasicHeader(AzureMistralOpenAiExternalResponseHandler.TOKENS_LIMIT, "10000") ); - when(response.getFirstHeader(AzureAndOpenAiExternalResponseHandler.REMAINING_TOKENS)).thenReturn(null); - var error = AzureAndOpenAiExternalResponseHandler.buildRateLimitErrorMessage(httpResult); + when(response.getFirstHeader(AzureMistralOpenAiExternalResponseHandler.REMAINING_TOKENS)).thenReturn(null); + var error = AzureMistralOpenAiExternalResponseHandler.buildRateLimitErrorMessage(httpResult); assertThat( error, containsString("Token limit [10000], remaining tokens [unknown]. Request limit [unknown], remaining requests [2999]") diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/googleaistudio/GoogleAiStudioEmbeddingsResponseEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/googleaistudio/GoogleAiStudioEmbeddingsResponseEntityTests.java new file mode 100644 index 0000000000000..5d5096d0b1b51 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/response/googleaistudio/GoogleAiStudioEmbeddingsResponseEntityTests.java @@ -0,0 +1,114 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.response.googleaistudio; + +import org.apache.http.HttpResponse; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.core.inference.results.TextEmbeddingResults; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.request.Request; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.List; + +import static org.hamcrest.Matchers.is; +import static org.mockito.Mockito.mock; + +public class GoogleAiStudioEmbeddingsResponseEntityTests extends ESTestCase { + + public void testFromResponse_CreatesResultsForASingleItem() throws IOException { + String responseJson = """ + { + "embeddings": [ + { + "values": [ + -0.00606332, + 0.058092743 + ] + } + ] + } + """; + + TextEmbeddingResults parsedResults = GoogleAiStudioEmbeddingsResponseEntity.fromResponse( + mock(Request.class), + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ); + + assertThat(parsedResults.embeddings(), is(List.of(TextEmbeddingResults.Embedding.of(List.of(-0.00606332F, 0.058092743F))))); + } + + public void testFromResponse_CreatesResultsForMultipleItems() throws IOException { + String responseJson = """ + { + "embeddings": [ + { + "values": [ + -0.00606332, + 0.058092743 + ] + }, + { + "values": [ + 0.030681048, + 0.01714732 + ] + } + ] + } + """; + + TextEmbeddingResults parsedResults = GoogleAiStudioEmbeddingsResponseEntity.fromResponse( + mock(Request.class), + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ); + + assertThat( + parsedResults.embeddings(), + is( + List.of( + TextEmbeddingResults.Embedding.of(List.of(-0.00606332F, 0.058092743F)), + TextEmbeddingResults.Embedding.of(List.of(0.030681048F, 0.01714732F)) + ) + ) + ); + } + + public void testFromResponse_FailsWhenEmbeddingsFieldIsNotPresent() { + String responseJson = """ + { + "not_embeddings": [ + { + "values": [ + -0.00606332, + 0.058092743 + ] + }, + { + "values": [ + 0.030681048, + 0.01714732 + ] + } + ] + } + """; + + var thrownException = expectThrows( + IllegalStateException.class, + () -> GoogleAiStudioEmbeddingsResponseEntity.fromResponse( + mock(Request.class), + new HttpResult(mock(HttpResponse.class), responseJson.getBytes(StandardCharsets.UTF_8)) + ) + ); + + assertThat(thrownException.getMessage(), is("Failed to find required field [embeddings] in Google AI Studio embeddings response")); + } + +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioSecretSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioSecretSettingsTests.java deleted file mode 100644 index a0339934783d8..0000000000000 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioSecretSettingsTests.java +++ /dev/null @@ -1,71 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.services.googleaistudio; - -import org.elasticsearch.common.ValidationException; -import org.elasticsearch.common.io.stream.Writeable; -import org.elasticsearch.common.settings.SecureString; -import org.elasticsearch.test.AbstractWireSerializingTestCase; - -import java.io.IOException; -import java.util.HashMap; -import java.util.Map; - -import static org.hamcrest.Matchers.containsString; -import static org.hamcrest.Matchers.is; - -public class GoogleAiStudioSecretSettingsTests extends AbstractWireSerializingTestCase { - - public static GoogleAiStudioSecretSettings createRandom() { - return new GoogleAiStudioSecretSettings(randomSecureStringOfLength(15)); - } - - public void testFromMap() { - var apiKey = "abc"; - var secretSettings = GoogleAiStudioSecretSettings.fromMap(new HashMap<>(Map.of(GoogleAiStudioSecretSettings.API_KEY, apiKey))); - - assertThat(new GoogleAiStudioSecretSettings(new SecureString(apiKey.toCharArray())), is(secretSettings)); - } - - public void testFromMap_ReturnsNull_WhenMapIsNull() { - assertNull(GoogleAiStudioSecretSettings.fromMap(null)); - } - - public void testFromMap_ThrowsError_WhenApiKeyIsNull() { - var throwException = expectThrows(ValidationException.class, () -> GoogleAiStudioSecretSettings.fromMap(new HashMap<>())); - - assertThat(throwException.getMessage(), containsString("[secret_settings] must have [api_key] set")); - } - - public void testFromMap_ThrowsError_WhenApiKeyIsEmpty() { - var thrownException = expectThrows( - ValidationException.class, - () -> GoogleAiStudioSecretSettings.fromMap(new HashMap<>(Map.of(GoogleAiStudioSecretSettings.API_KEY, ""))) - ); - - assertThat( - thrownException.getMessage(), - containsString("[secret_settings] Invalid value empty string. [api_key] must be a non-empty string") - ); - } - - @Override - protected Writeable.Reader instanceReader() { - return GoogleAiStudioSecretSettings::new; - } - - @Override - protected GoogleAiStudioSecretSettings createTestInstance() { - return createRandom(); - } - - @Override - protected GoogleAiStudioSecretSettings mutateInstance(GoogleAiStudioSecretSettings instance) throws IOException { - return randomValueOtherThan(instance, GoogleAiStudioSecretSettingsTests::createRandom); - } -} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java index f157622ea7291..32e912ff8529a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java @@ -13,13 +13,17 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.Strings; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkingOptions; import org.elasticsearch.inference.EmptyTaskSettings; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.http.MockResponse; @@ -28,6 +32,7 @@ import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; +import org.elasticsearch.xpack.core.inference.results.ChunkedTextEmbeddingFloatResults; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; @@ -35,12 +40,16 @@ import org.elasticsearch.xpack.inference.logging.ThrottlerManager; import org.elasticsearch.xpack.inference.services.ServiceFields; import org.elasticsearch.xpack.inference.services.googleaistudio.completion.GoogleAiStudioCompletionModel; +import org.elasticsearch.xpack.inference.services.googleaistudio.completion.GoogleAiStudioCompletionModelTests; +import org.elasticsearch.xpack.inference.services.googleaistudio.embeddings.GoogleAiStudioEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.googleaistudio.embeddings.GoogleAiStudioEmbeddingsModelTests; import org.hamcrest.MatcherAssert; import org.hamcrest.Matchers; import org.junit.After; import org.junit.Before; import java.io.IOException; +import java.util.Arrays; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -53,14 +62,16 @@ import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; +import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettingsTests.getTaskSettingsMapEmpty; -import static org.elasticsearch.xpack.inference.services.googleaistudio.completion.GoogleAiStudioCompletionModelTests.createModel; import static org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettingsTests.getSecretSettingsMap; import static org.hamcrest.CoreMatchers.equalTo; import static org.hamcrest.CoreMatchers.instanceOf; import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.Matchers.aMapWithSize; import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.endsWith; import static org.hamcrest.Matchers.hasSize; import static org.mockito.ArgumentMatchers.anyString; import static org.mockito.Mockito.mock; @@ -118,6 +129,33 @@ public void testParseRequestConfig_CreatesAGoogleAiStudioCompletionModel() throw } } + public void testParseRequestConfig_CreatesAGoogleAiStudioEmbeddingsModel() throws IOException { + var apiKey = "apiKey"; + var modelId = "model"; + + try (var service = createGoogleAiStudioService()) { + ActionListener modelListener = ActionListener.wrap(model -> { + assertThat(model, instanceOf(GoogleAiStudioEmbeddingsModel.class)); + + var embeddingsModel = (GoogleAiStudioEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().modelId(), is(modelId)); + assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is(apiKey)); + }, e -> fail("Model parsing should have succeeded, but failed: " + e.getMessage())); + + service.parseRequestConfig( + "id", + TaskType.TEXT_EMBEDDING, + getRequestConfigMap( + new HashMap<>(Map.of(ServiceFields.MODEL_ID, modelId)), + new HashMap<>(Map.of()), + getSecretSettingsMap(apiKey) + ), + Set.of(), + modelListener + ); + } + } + public void testParseRequestConfig_ThrowsUnsupportedModelType() throws IOException { try (var service = createGoogleAiStudioService()) { var failureListener = getModelListenerForException( @@ -236,6 +274,33 @@ public void testParsePersistedConfigWithSecrets_CreatesAGoogleAiStudioCompletion } } + public void testParsePersistedConfigWithSecrets_CreatesAGoogleAiStudioEmbeddingsModel() throws IOException { + var modelId = "model"; + var apiKey = "apiKey"; + + try (var service = createGoogleAiStudioService()) { + var persistedConfig = getPersistedConfigMap( + new HashMap<>(Map.of(ServiceFields.MODEL_ID, modelId)), + getTaskSettingsMapEmpty(), + getSecretSettingsMap(apiKey) + ); + + var model = service.parsePersistedConfigWithSecrets( + "id", + TaskType.TEXT_EMBEDDING, + persistedConfig.config(), + persistedConfig.secrets() + ); + + assertThat(model, instanceOf(GoogleAiStudioEmbeddingsModel.class)); + + var embeddingsModel = (GoogleAiStudioEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().modelId(), is(modelId)); + assertThat(embeddingsModel.getTaskSettings(), is(EmptyTaskSettings.INSTANCE)); + assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is(apiKey)); + } + } + public void testParsePersistedConfigWithSecrets_DoesNotThrowWhenAnExtraKeyExistsInConfig() throws IOException { var modelId = "model"; var apiKey = "apiKey"; @@ -460,7 +525,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotGoogleAiStudioModel() throws IOEx verifyNoMoreInteractions(sender); } - public void testInfer_SendsRequest() throws IOException { + public void testInfer_SendsCompletionRequest() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); try (var service = new GoogleAiStudioService(senderFactory, createWithEmptySettings(threadPool))) { @@ -508,7 +573,7 @@ public void testInfer_SendsRequest() throws IOException { webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - var model = createModel("model", getUrl(webServer), "secret"); + var model = GoogleAiStudioCompletionModelTests.createModel("model", getUrl(webServer), "secret"); PlainActionFuture listener = new PlainActionFuture<>(); service.infer( model, @@ -541,6 +606,155 @@ public void testInfer_SendsRequest() throws IOException { } } + public void testInfer_SendsEmbeddingsRequest() throws IOException { + var modelId = "model"; + var apiKey = "apiKey"; + var input = "input"; + + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new GoogleAiStudioService(senderFactory, createWithEmptySettings(threadPool))) { + String responseJson = """ + { + "embeddings": [ + { + "values": [ + 0.0123, + -0.0123 + ] + } + ] + } + """; + + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = GoogleAiStudioEmbeddingsModelTests.createModel(modelId, apiKey, getUrl(webServer)); + PlainActionFuture listener = new PlainActionFuture<>(); + service.infer( + model, + null, + List.of(input), + new HashMap<>(), + InputType.INGEST, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + var result = listener.actionGet(TIMEOUT); + + assertThat(result.asMap(), is(buildExpectationFloat(List.of(new float[] { 0.0123F, -0.0123F })))); + assertThat(webServer.requests(), hasSize(1)); + assertThat(webServer.requests().get(0).getUri().getQuery(), endsWith(apiKey)); + assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), Matchers.equalTo(XContentType.JSON.mediaType())); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + assertThat(requestMap, aMapWithSize(1)); + assertThat( + requestMap.get("requests"), + Matchers.is( + List.of( + Map.of( + "model", + Strings.format("%s/%s", "models", modelId), + "content", + Map.of("parts", List.of(Map.of("text", input))) + ) + ) + ) + ); + } + } + + public void testChunkedInfer_Batches() throws IOException { + var modelId = "modelId"; + var apiKey = "apiKey"; + var input = List.of("foo", "bar"); + + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new GoogleAiStudioService(senderFactory, createWithEmptySettings(threadPool))) { + String responseJson = """ + { + "embeddings": [ + { + "values": [ + 0.0123, + -0.0123 + ] + }, + { + "values": [ + 0.0456, + -0.0456 + ] + } + ] + } + """; + + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = GoogleAiStudioEmbeddingsModelTests.createModel(modelId, apiKey, getUrl(webServer)); + PlainActionFuture> listener = new PlainActionFuture<>(); + service.chunkedInfer( + model, + input, + new HashMap<>(), + InputType.INGEST, + new ChunkingOptions(null, null), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var results = listener.actionGet(TIMEOUT); + assertThat(results, hasSize(2)); + + // first result + { + assertThat(results.get(0), instanceOf(ChunkedTextEmbeddingFloatResults.class)); + var floatResult = (ChunkedTextEmbeddingFloatResults) results.get(0); + assertThat(floatResult.chunks(), hasSize(1)); + assertEquals(input.get(0), floatResult.chunks().get(0).matchedText()); + assertTrue(Arrays.equals(new float[] { 0.0123f, -0.0123f }, floatResult.chunks().get(0).embedding())); + } + + // second result + { + assertThat(results.get(1), instanceOf(ChunkedTextEmbeddingFloatResults.class)); + var floatResult = (ChunkedTextEmbeddingFloatResults) results.get(1); + assertThat(floatResult.chunks(), hasSize(1)); + assertEquals(input.get(1), floatResult.chunks().get(0).matchedText()); + assertTrue(Arrays.equals(new float[] { 0.0456f, -0.0456f }, floatResult.chunks().get(0).embedding())); + } + + assertThat(webServer.requests(), hasSize(1)); + assertThat(webServer.requests().get(0).getUri().getQuery(), endsWith(apiKey)); + assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), Matchers.equalTo(XContentType.JSON.mediaType())); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + assertThat(requestMap, aMapWithSize(1)); + assertThat( + requestMap.get("requests"), + is( + List.of( + Map.of( + "model", + Strings.format("%s/%s", "models", modelId), + "content", + Map.of("parts", List.of(Map.of("text", input.get(0)))) + ), + Map.of( + "model", + Strings.format("%s/%s", "models", modelId), + "content", + Map.of("parts", List.of(Map.of("text", input.get(1)))) + ) + ) + ) + ); + } + } + public void testInfer_ResourceNotFound() throws IOException { var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); @@ -555,7 +769,7 @@ public void testInfer_ResourceNotFound() throws IOException { """; webServer.enqueue(new MockResponse().setResponseCode(404).setBody(responseJson)); - var model = createModel("model", getUrl(webServer), "secret"); + var model = GoogleAiStudioCompletionModelTests.createModel("model", getUrl(webServer), "secret"); PlainActionFuture listener = new PlainActionFuture<>(); service.infer( model, @@ -574,6 +788,132 @@ public void testInfer_ResourceNotFound() throws IOException { } } + public void testCheckModelConfig_UpdatesDimensions() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + var similarityMeasure = SimilarityMeasure.DOT_PRODUCT; + var modelId = "model"; + var apiKey = "apiKey"; + + try (var service = new GoogleAiStudioService(senderFactory, createWithEmptySettings(threadPool))) { + String responseJson = """ + { + "embeddings": [ + { + "values": [ + 0.0123, + -0.0123 + ] + } + ] + } + """; + + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = GoogleAiStudioEmbeddingsModelTests.createModel(getUrl(webServer), modelId, apiKey, 1, similarityMeasure); + + PlainActionFuture listener = new PlainActionFuture<>(); + service.checkModelConfig(model, listener); + var result = listener.actionGet(TIMEOUT); + + // Updates dimensions to two as two embeddings were returned instead of one as specified before + assertThat( + result, + is(GoogleAiStudioEmbeddingsModelTests.createModel(getUrl(webServer), modelId, apiKey, 2, similarityMeasure)) + ); + } + } + + public void testCheckModelConfig_UpdatesSimilarityToDotProduct_WhenItIsNull() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + var oneDimension = 1; + var modelId = "model"; + var apiKey = "apiKey"; + + try (var service = new GoogleAiStudioService(senderFactory, createWithEmptySettings(threadPool))) { + String responseJson = """ + { + "embeddings": [ + { + "values": [ + 0.0123 + ] + } + ] + } + """; + + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = GoogleAiStudioEmbeddingsModelTests.createModel(getUrl(webServer), modelId, apiKey, oneDimension, null); + + PlainActionFuture listener = new PlainActionFuture<>(); + service.checkModelConfig(model, listener); + var result = listener.actionGet(TIMEOUT); + + assertThat( + result, + is( + GoogleAiStudioEmbeddingsModelTests.createModel( + getUrl(webServer), + modelId, + apiKey, + oneDimension, + SimilarityMeasure.DOT_PRODUCT + ) + ) + ); + } + } + + public void testCheckModelConfig_DoesNotUpdateSimilarity_WhenItIsSpecifiedAsCosine() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + var oneDimension = 1; + var modelId = "model"; + var apiKey = "apiKey"; + + try (var service = new GoogleAiStudioService(senderFactory, createWithEmptySettings(threadPool))) { + String responseJson = """ + { + "embeddings": [ + { + "values": [ + 0.0123 + ] + } + ] + } + """; + + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = GoogleAiStudioEmbeddingsModelTests.createModel( + getUrl(webServer), + modelId, + apiKey, + oneDimension, + SimilarityMeasure.COSINE + ); + + PlainActionFuture listener = new PlainActionFuture<>(); + service.checkModelConfig(model, listener); + var result = listener.actionGet(TIMEOUT); + + assertThat( + result, + is( + GoogleAiStudioEmbeddingsModelTests.createModel( + getUrl(webServer), + modelId, + apiKey, + oneDimension, + SimilarityMeasure.COSINE + ) + ) + ); + } + } + public static Map buildExpectationCompletions(List completions) { return Map.of( ChatCompletionResults.COMPLETION, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/completion/GoogleAiStudioCompletionModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/completion/GoogleAiStudioCompletionModelTests.java index 1f8233f7eb103..025317fbe025a 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/completion/GoogleAiStudioCompletionModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/completion/GoogleAiStudioCompletionModelTests.java @@ -11,7 +11,7 @@ import org.elasticsearch.inference.EmptyTaskSettings; import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xpack.inference.services.googleaistudio.GoogleAiStudioSecretSettings; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; import java.net.URISyntaxException; import java.util.HashMap; @@ -48,7 +48,7 @@ public static GoogleAiStudioCompletionModel createModel(String model, String api "service", new GoogleAiStudioCompletionServiceSettings(model, null), EmptyTaskSettings.INSTANCE, - new GoogleAiStudioSecretSettings(new SecureString(apiKey.toCharArray())) + new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) ); } @@ -60,7 +60,7 @@ public static GoogleAiStudioCompletionModel createModel(String model, String url url, new GoogleAiStudioCompletionServiceSettings(model, null), EmptyTaskSettings.INSTANCE, - new GoogleAiStudioSecretSettings(new SecureString(apiKey.toCharArray())) + new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) ); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/embeddings/GoogleAiStudioEmbeddingsModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/embeddings/GoogleAiStudioEmbeddingsModelTests.java new file mode 100644 index 0000000000000..5ea9bbfc9d970 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/embeddings/GoogleAiStudioEmbeddingsModelTests.java @@ -0,0 +1,66 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.googleaistudio.embeddings; + +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.EmptyTaskSettings; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; + +public class GoogleAiStudioEmbeddingsModelTests extends ESTestCase { + + public static GoogleAiStudioEmbeddingsModel createModel(String model, String apiKey, String url) { + return new GoogleAiStudioEmbeddingsModel( + "id", + TaskType.TEXT_EMBEDDING, + "service", + url, + new GoogleAiStudioEmbeddingsServiceSettings(model, null, null, SimilarityMeasure.DOT_PRODUCT, null), + EmptyTaskSettings.INSTANCE, + new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) + ); + } + + public static GoogleAiStudioEmbeddingsModel createModel( + String url, + String model, + String apiKey, + Integer dimensions, + @Nullable SimilarityMeasure similarityMeasure + ) { + return new GoogleAiStudioEmbeddingsModel( + "id", + TaskType.TEXT_EMBEDDING, + "service", + url, + new GoogleAiStudioEmbeddingsServiceSettings(model, null, dimensions, similarityMeasure, null), + EmptyTaskSettings.INSTANCE, + new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) + ); + } + + public static GoogleAiStudioEmbeddingsModel createModel( + String model, + String apiKey, + @Nullable Integer tokenLimit, + @Nullable Integer dimensions + ) { + return new GoogleAiStudioEmbeddingsModel( + "id", + TaskType.TEXT_EMBEDDING, + "service", + new GoogleAiStudioEmbeddingsServiceSettings(model, tokenLimit, dimensions, SimilarityMeasure.DOT_PRODUCT, null), + EmptyTaskSettings.INSTANCE, + new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) + ); + } + +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/embeddings/GoogleAiStudioEmbeddingsServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/embeddings/GoogleAiStudioEmbeddingsServiceSettingsTests.java new file mode 100644 index 0000000000000..b5fbd28b476ba --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/embeddings/GoogleAiStudioEmbeddingsServiceSettingsTests.java @@ -0,0 +1,114 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.googleaistudio.embeddings; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.test.AbstractWireSerializingTestCase; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.services.ServiceFields; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettingsTests; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +import static org.elasticsearch.xpack.inference.MatchersUtils.equalToIgnoringWhitespaceInJsonString; +import static org.elasticsearch.xpack.inference.Utils.randomSimilarityMeasure; +import static org.hamcrest.Matchers.is; + +public class GoogleAiStudioEmbeddingsServiceSettingsTests extends AbstractWireSerializingTestCase { + + private static GoogleAiStudioEmbeddingsServiceSettings createRandom() { + return new GoogleAiStudioEmbeddingsServiceSettings( + randomAlphaOfLength(8), + randomFrom(randomNonNegativeInt(), null), + randomFrom(randomNonNegativeInt(), null), + randomFrom(randomSimilarityMeasure(), null), + randomFrom(RateLimitSettingsTests.createRandom(), null) + ); + } + + public void testFromMap_Request_CreatesSettingsCorrectly() { + var model = randomAlphaOfLength(8); + var maxInputTokens = randomIntBetween(1, 1024); + var dims = randomIntBetween(1, 10000); + var similarity = randomSimilarityMeasure(); + + var serviceSettings = GoogleAiStudioEmbeddingsServiceSettings.fromMap( + new HashMap<>( + Map.of( + ServiceFields.MODEL_ID, + model, + ServiceFields.MAX_INPUT_TOKENS, + maxInputTokens, + ServiceFields.DIMENSIONS, + dims, + ServiceFields.SIMILARITY, + similarity.toString() + ) + ) + ); + + assertThat(serviceSettings, is(new GoogleAiStudioEmbeddingsServiceSettings(model, maxInputTokens, dims, similarity, null))); + } + + public void testToXContent_WritesAllValues() throws IOException { + var entity = new GoogleAiStudioEmbeddingsServiceSettings("model", 1024, 8, SimilarityMeasure.DOT_PRODUCT, null); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString(""" + { + "model_id":"model", + "max_input_tokens": 1024, + "dimensions": 8, + "similarity": "dot_product", + "rate_limit": { + "requests_per_minute":360 + } + }""")); + } + + public void testToFilteredXContent_WritesAllValues_Except_RateLimit() throws IOException { + var entity = new GoogleAiStudioEmbeddingsServiceSettings("model", 1024, 8, SimilarityMeasure.DOT_PRODUCT, null); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + var filteredXContent = entity.getFilteredXContentObject(); + filteredXContent.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, equalToIgnoringWhitespaceInJsonString(""" + { + "model_id":"model", + "max_input_tokens": 1024, + "dimensions": 8, + "similarity": "dot_product" + }""")); + } + + @Override + protected Writeable.Reader instanceReader() { + return GoogleAiStudioEmbeddingsServiceSettings::new; + } + + @Override + protected GoogleAiStudioEmbeddingsServiceSettings createTestInstance() { + return createRandom(); + } + + @Override + protected GoogleAiStudioEmbeddingsServiceSettings mutateInstance(GoogleAiStudioEmbeddingsServiceSettings instance) throws IOException { + return randomValueOtherThan(instance, GoogleAiStudioEmbeddingsServiceSettingsTests::createRandom); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserModelTests.java index 2ad2c12b4a97c..d7a62256f8d9c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserModelTests.java @@ -10,6 +10,7 @@ import org.elasticsearch.common.settings.SecureString; import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; import static org.hamcrest.Matchers.containsString; @@ -26,7 +27,7 @@ public static HuggingFaceElserModel createModel(String url, String apiKey) { TaskType.SPARSE_EMBEDDING, "service", new HuggingFaceElserServiceSettings(url), - new HuggingFaceElserSecretSettings(new SecureString(apiKey.toCharArray())) + new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) ); } @@ -36,7 +37,7 @@ public static HuggingFaceElserModel createModel(String url, String apiKey, Strin TaskType.SPARSE_EMBEDDING, "service", new HuggingFaceElserServiceSettings(url), - new HuggingFaceElserSecretSettings(new SecureString(apiKey.toCharArray())) + new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) ); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserSecretSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserSecretSettingsTests.java deleted file mode 100644 index f69a9b5a967e0..0000000000000 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserSecretSettingsTests.java +++ /dev/null @@ -1,82 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. - */ - -package org.elasticsearch.xpack.inference.services.huggingface.elser; - -import org.elasticsearch.common.Strings; -import org.elasticsearch.common.ValidationException; -import org.elasticsearch.common.io.stream.Writeable; -import org.elasticsearch.common.settings.SecureString; -import org.elasticsearch.test.AbstractWireSerializingTestCase; - -import java.io.IOException; -import java.util.HashMap; -import java.util.Map; - -import static org.hamcrest.Matchers.containsString; -import static org.hamcrest.Matchers.is; - -public class HuggingFaceElserSecretSettingsTests extends AbstractWireSerializingTestCase { - - public static HuggingFaceElserSecretSettings createRandom() { - return new HuggingFaceElserSecretSettings(new SecureString(randomAlphaOfLength(15).toCharArray())); - } - - public void testFromMap() { - var apiKey = "abc"; - var serviceSettings = HuggingFaceElserSecretSettings.fromMap(new HashMap<>(Map.of(HuggingFaceElserSecretSettings.API_KEY, apiKey))); - - assertThat(new HuggingFaceElserSecretSettings(new SecureString(apiKey.toCharArray())), is(serviceSettings)); - } - - public void testFromMap_ReturnsNull_WhenMapIsNull() { - assertNull(HuggingFaceElserSecretSettings.fromMap(null)); - } - - public void testFromMap_MissingApiKey_ThrowsError() { - var thrownException = expectThrows(ValidationException.class, () -> HuggingFaceElserSecretSettings.fromMap(new HashMap<>())); - - assertThat( - thrownException.getMessage(), - containsString( - Strings.format("[secret_settings] does not contain the required setting [%s]", HuggingFaceElserSecretSettings.API_KEY) - ) - ); - } - - public void testFromMap_EmptyApiKey_ThrowsError() { - var thrownException = expectThrows( - ValidationException.class, - () -> HuggingFaceElserSecretSettings.fromMap(new HashMap<>(Map.of(HuggingFaceElserSecretSettings.API_KEY, ""))) - ); - - assertThat( - thrownException.getMessage(), - containsString( - Strings.format( - "[secret_settings] Invalid value empty string. [%s] must be a non-empty string", - HuggingFaceElserSecretSettings.API_KEY - ) - ) - ); - } - - @Override - protected Writeable.Reader instanceReader() { - return HuggingFaceElserSecretSettings::new; - } - - @Override - protected HuggingFaceElserSecretSettings createTestInstance() { - return createRandom(); - } - - @Override - protected HuggingFaceElserSecretSettings mutateInstance(HuggingFaceElserSecretSettings instance) throws IOException { - return randomValueOtherThan(instance, HuggingFaceElserSecretSettingsTests::createRandom); - } -} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java new file mode 100644 index 0000000000000..3ead273e78110 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java @@ -0,0 +1,651 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.mistral; + +import org.apache.http.HttpHeaders; +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.PlainActionFuture; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.ChunkedInferenceServiceResults; +import org.elasticsearch.inference.ChunkingOptions; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.InputType; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.ModelSecrets; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.http.MockResponse; +import org.elasticsearch.test.http.MockWebServer; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.core.inference.results.ChunkedTextEmbeddingFloatResults; +import org.elasticsearch.xpack.inference.external.http.HttpClientManager; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; +import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; +import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.logging.ThrottlerManager; +import org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingModelTests; +import org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingsModel; +import org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingsServiceSettings; +import org.hamcrest.CoreMatchers; +import org.hamcrest.MatcherAssert; +import org.hamcrest.Matchers; +import org.junit.After; +import org.junit.Before; + +import java.io.IOException; +import java.net.URISyntaxException; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.TimeUnit; + +import static org.elasticsearch.xpack.inference.Utils.getInvalidModel; +import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; +import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; +import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; +import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; +import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; +import static org.elasticsearch.xpack.inference.services.mistral.MistralConstants.API_KEY_FIELD; +import static org.elasticsearch.xpack.inference.services.mistral.embeddings.MistralEmbeddingsServiceSettingsTests.createRequestSettingsMap; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.hasSize; +import static org.hamcrest.Matchers.instanceOf; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; + +public class MistralServiceTests extends ESTestCase { + private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); + private final MockWebServer webServer = new MockWebServer(); + private ThreadPool threadPool; + private HttpClientManager clientManager; + + @Before + public void init() throws Exception { + webServer.start(); + threadPool = createThreadPool(inferenceUtilityPool()); + clientManager = HttpClientManager.create(Settings.EMPTY, threadPool, mockClusterServiceEmpty(), mock(ThrottlerManager.class)); + } + + @After + public void shutdown() throws IOException { + clientManager.close(); + terminate(threadPool); + webServer.close(); + } + + public void testParseRequestConfig_CreatesAMistralEmbeddingsModel() throws IOException { + try (var service = createService()) { + ActionListener modelVerificationListener = ActionListener.wrap(model -> { + assertThat(model, instanceOf(MistralEmbeddingsModel.class)); + + var embeddingsModel = (MistralEmbeddingsModel) model; + var serviceSettings = (MistralEmbeddingsServiceSettings) model.getServiceSettings(); + assertThat(serviceSettings.model(), is("mistral-embed")); + assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + }, exception -> fail("Unexpected exception: " + exception)); + + service.parseRequestConfig( + "id", + TaskType.TEXT_EMBEDDING, + getRequestConfigMap( + getEmbeddingsServiceSettingsMap("mistral-embed", null, null, null), + getEmbeddingsTaskSettingsMap(), + getSecretSettingsMap("secret") + ), + Set.of(), + modelVerificationListener + ); + } + } + + public void testParseRequestConfig_ThrowsUnsupportedModelType() throws IOException { + try (var service = createService()) { + ActionListener modelVerificationListener = ActionListener.wrap( + model -> fail("Expected exception, but got model: " + model), + exception -> { + assertThat(exception, instanceOf(ElasticsearchStatusException.class)); + assertThat(exception.getMessage(), is("The [mistral] service does not support task type [sparse_embedding]")); + } + ); + + service.parseRequestConfig( + "id", + TaskType.SPARSE_EMBEDDING, + getRequestConfigMap( + getEmbeddingsServiceSettingsMap("mistral-embed", null, null, null), + getEmbeddingsTaskSettingsMap(), + getSecretSettingsMap("secret") + ), + Set.of(), + modelVerificationListener + ); + } + } + + public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInConfig() throws IOException { + try (var service = createService()) { + var config = getRequestConfigMap( + getEmbeddingsServiceSettingsMap("mistral-embed", null, null, null), + getEmbeddingsTaskSettingsMap(), + getSecretSettingsMap("secret") + ); + config.put("extra_key", "value"); + + ActionListener modelVerificationListener = ActionListener.wrap( + model -> fail("Expected exception, but got model: " + model), + exception -> { + assertThat(exception, instanceOf(ElasticsearchStatusException.class)); + assertThat( + exception.getMessage(), + is("Model configuration contains settings [{extra_key=value}] unknown to the [mistral] service") + ); + } + ); + + service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, Set.of(), modelVerificationListener); + } + } + + public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInEmbeddingTaskSettingsMap() throws IOException { + try (var service = createService()) { + var taskSettings = new HashMap(); + taskSettings.put("extra_key", "value"); + + var config = getRequestConfigMap( + getEmbeddingsServiceSettingsMap("mistral-embed", null, null, null), + taskSettings, + getSecretSettingsMap("secret") + ); + + ActionListener modelVerificationListener = ActionListener.wrap( + model -> fail("Expected exception, but got model: " + model), + exception -> { + assertThat(exception, instanceOf(ElasticsearchStatusException.class)); + assertThat( + exception.getMessage(), + is("Model configuration contains settings [{extra_key=value}] unknown to the [mistral] service") + ); + } + ); + + service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, Set.of(), modelVerificationListener); + } + } + + public void testParseRequestConfig_ThrowsWhenAnExtraKeyExistsInEmbeddingSecretSettingsMap() throws IOException { + try (var service = createService()) { + var secretSettings = getSecretSettingsMap("secret"); + secretSettings.put("extra_key", "value"); + + var config = getRequestConfigMap( + getEmbeddingsServiceSettingsMap("mistral-embed", null, null, null), + getEmbeddingsTaskSettingsMap(), + secretSettings + ); + + ActionListener modelVerificationListener = ActionListener.wrap( + model -> fail("Expected exception, but got model: " + model), + exception -> { + assertThat(exception, instanceOf(ElasticsearchStatusException.class)); + assertThat( + exception.getMessage(), + is("Model configuration contains settings [{extra_key=value}] unknown to the [mistral] service") + ); + } + ); + + service.parseRequestConfig("id", TaskType.TEXT_EMBEDDING, config, Set.of(), modelVerificationListener); + } + } + + public void testParsePersistedConfig_CreatesAMistralEmbeddingsModel() throws IOException { + try (var service = createService()) { + var config = getPersistedConfigMap( + getEmbeddingsServiceSettingsMap("mistral-embed", 1024, 512, null), + getEmbeddingsTaskSettingsMap(), + getSecretSettingsMap("secret") + ); + + var model = service.parsePersistedConfigWithSecrets("id", TaskType.TEXT_EMBEDDING, config.config(), config.secrets()); + + assertThat(model, instanceOf(MistralEmbeddingsModel.class)); + + var embeddingsModel = (MistralEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().model(), is("mistral-embed")); + assertThat(embeddingsModel.getServiceSettings().dimensions(), is(1024)); + assertThat(embeddingsModel.getServiceSettings().maxInputTokens(), is(512)); + assertThat(embeddingsModel.getSecretSettings().apiKey().toString(), is("secret")); + } + } + + public void testParsePersistedConfig_ThrowsUnsupportedModelType() throws IOException { + try (var service = createService()) { + ActionListener modelVerificationListener = ActionListener.wrap( + model -> fail("Expected exception, but got model: " + model), + exception -> { + assertThat(exception, instanceOf(ElasticsearchStatusException.class)); + assertThat(exception.getMessage(), is("The [mistral] service does not support task type [sparse_embedding]")); + } + ); + + service.parseRequestConfig( + "id", + TaskType.SPARSE_EMBEDDING, + getRequestConfigMap( + getEmbeddingsServiceSettingsMap("mistral-embed", null, null, null), + getEmbeddingsTaskSettingsMap(), + getSecretSettingsMap("secret") + ), + Set.of(), + modelVerificationListener + ); + } + } + + public void testParsePersistedConfigWithSecrets_ThrowsErrorTryingToParseInvalidModel() throws IOException { + try (var service = createService()) { + var config = getPersistedConfigMap( + getEmbeddingsServiceSettingsMap("mistral-embed", null, null, null), + getEmbeddingsTaskSettingsMap(), + getSecretSettingsMap("secret") + ); + + var thrownException = expectThrows( + ElasticsearchStatusException.class, + () -> service.parsePersistedConfigWithSecrets("id", TaskType.SPARSE_EMBEDDING, config.config(), config.secrets()) + ); + + assertThat( + thrownException.getMessage(), + is("Failed to parse stored model [id] for [mistral] service, please delete and add the service again") + ); + } + } + + public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInConfig() throws IOException { + try (var service = createService()) { + var serviceSettings = getEmbeddingsServiceSettingsMap("mistral-embed", 1024, 512, null); + var taskSettings = getEmbeddingsTaskSettingsMap(); + var secretSettings = getSecretSettingsMap("secret"); + var config = getPersistedConfigMap(serviceSettings, taskSettings, secretSettings); + config.config().put("extra_key", "value"); + + var model = service.parsePersistedConfigWithSecrets("id", TaskType.TEXT_EMBEDDING, config.config(), config.secrets()); + + assertThat(model, instanceOf(MistralEmbeddingsModel.class)); + } + } + + public void testParsePersistedConfig_DoesNotThrowWhenExtraKeyExistsInEmbeddingServiceSettingsMap() throws IOException { + try (var service = createService()) { + var serviceSettings = getEmbeddingsServiceSettingsMap("mistral-embed", 1024, 512, null); + serviceSettings.put("extra_key", "value"); + + var taskSettings = getEmbeddingsTaskSettingsMap(); + var secretSettings = getSecretSettingsMap("secret"); + var config = getPersistedConfigMap(serviceSettings, taskSettings, secretSettings); + + var model = service.parsePersistedConfigWithSecrets("id", TaskType.TEXT_EMBEDDING, config.config(), config.secrets()); + + assertThat(model, instanceOf(MistralEmbeddingsModel.class)); + } + } + + public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInEmbeddingTaskSettingsMap() throws IOException { + try (var service = createService()) { + var serviceSettings = getEmbeddingsServiceSettingsMap("mistral-embed", 1024, 512, null); + var taskSettings = new HashMap(); + taskSettings.put("extra_key", "value"); + + var secretSettings = getSecretSettingsMap("secret"); + var config = getPersistedConfigMap(serviceSettings, taskSettings, secretSettings); + + var model = service.parsePersistedConfigWithSecrets("id", TaskType.TEXT_EMBEDDING, config.config(), config.secrets()); + + assertThat(model, instanceOf(MistralEmbeddingsModel.class)); + } + } + + public void testParsePersistedConfig_DoesNotThrowWhenAnExtraKeyExistsInEmbeddingSecretSettingsMap() throws IOException { + try (var service = createService()) { + var serviceSettings = getEmbeddingsServiceSettingsMap("mistral-embed", 1024, 512, null); + var taskSettings = getEmbeddingsTaskSettingsMap(); + var secretSettings = getSecretSettingsMap("secret"); + secretSettings.put("extra_key", "value"); + + var config = getPersistedConfigMap(serviceSettings, taskSettings, secretSettings); + + var model = service.parsePersistedConfigWithSecrets("id", TaskType.TEXT_EMBEDDING, config.config(), config.secrets()); + + assertThat(model, instanceOf(MistralEmbeddingsModel.class)); + } + } + + public void testParsePersistedConfig_WithoutSecretsCreatesEmbeddingsModel() throws IOException { + try (var service = createService()) { + var config = getPersistedConfigMap( + getEmbeddingsServiceSettingsMap("mistral-embed", 1024, 512, null), + getEmbeddingsTaskSettingsMap(), + Map.of() + ); + + var model = service.parsePersistedConfig("id", TaskType.TEXT_EMBEDDING, config.config()); + + assertThat(model, instanceOf(MistralEmbeddingsModel.class)); + + var embeddingsModel = (MistralEmbeddingsModel) model; + assertThat(embeddingsModel.getServiceSettings().model(), is("mistral-embed")); + assertThat(embeddingsModel.getServiceSettings().dimensions(), is(1024)); + assertThat(embeddingsModel.getServiceSettings().maxInputTokens(), is(512)); + } + } + + public void testCheckModelConfig_ForEmbeddingsModel_Works() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new MistralService(senderFactory, createWithEmptySettings(threadPool))) { + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(testEmbeddingResultJson)); + + var model = MistralEmbeddingModelTests.createModel("id", "mistral-embed", "apikey", null, null, null, null); + model.setURI(getUrl(webServer)); + + PlainActionFuture listener = new PlainActionFuture<>(); + service.checkModelConfig(model, listener); + + var result = listener.actionGet(TIMEOUT); + assertThat( + result, + is(MistralEmbeddingModelTests.createModel("id", "mistral-embed", "apikey", 2, null, SimilarityMeasure.DOT_PRODUCT, null)) + ); + + assertThat(webServer.requests(), hasSize(1)); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + MatcherAssert.assertThat( + requestMap, + Matchers.is(Map.of("input", List.of("how big"), "encoding_format", "float", "model", "mistral-embed")) + ); + } + } + + public void testInfer_ThrowsErrorWhenModelIsNotMistralEmbeddingsModel() throws IOException { + var sender = mock(Sender.class); + + var factory = mock(HttpRequestSender.Factory.class); + when(factory.createSender(anyString())).thenReturn(sender); + + var mockModel = getInvalidModel("model_id", "service_name"); + + try (var service = new MistralService(factory, createWithEmptySettings(threadPool))) { + PlainActionFuture listener = new PlainActionFuture<>(); + service.infer( + mockModel, + null, + List.of(""), + new HashMap<>(), + InputType.INGEST, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); + assertThat( + thrownException.getMessage(), + is("The internal model was invalid, please delete the service [service_name] with id [model_id] and add it again.") + ); + + verify(factory, times(1)).createSender(anyString()); + verify(sender, times(1)).start(); + } + + verify(sender, times(1)).close(); + verifyNoMoreInteractions(factory); + verifyNoMoreInteractions(sender); + } + + public void testChunkedInfer_Embeddings_CallsInfer_ConvertsFloatResponse() throws IOException, URISyntaxException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new MistralService(senderFactory, createWithEmptySettings(threadPool))) { + + String responseJson = """ + { + "object": "list", + "data": [ + { + "object": "embedding", + "index": 0, + "embedding": [ + 0.123, + -0.123 + ] + }, + { + "object": "embedding", + "index": 1, + "embedding": [ + 0.223, + -0.223 + ] + } + ], + "model": "text-embedding-ada-002-v2", + "usage": { + "prompt_tokens": 8, + "total_tokens": 8 + } + } + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var model = MistralEmbeddingModelTests.createModel("id", "mistral-embed", "apikey", null, null, null, null); + model.setURI(getUrl(webServer)); + + PlainActionFuture> listener = new PlainActionFuture<>(); + service.chunkedInfer( + model, + List.of("abc", "def"), + new HashMap<>(), + InputType.INGEST, + new ChunkingOptions(null, null), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var results = listener.actionGet(TIMEOUT); + + assertThat(results, hasSize(2)); + { + assertThat(results.get(0), CoreMatchers.instanceOf(ChunkedTextEmbeddingFloatResults.class)); + var floatResult = (ChunkedTextEmbeddingFloatResults) results.get(0); + assertThat(floatResult.chunks(), hasSize(1)); + assertTrue(Arrays.equals(new float[] { 0.123f, -0.123f }, floatResult.chunks().get(0).embedding())); + } + { + assertThat(results.get(1), CoreMatchers.instanceOf(ChunkedTextEmbeddingFloatResults.class)); + var floatResult = (ChunkedTextEmbeddingFloatResults) results.get(1); + assertThat(floatResult.chunks(), hasSize(1)); + assertTrue(Arrays.equals(new float[] { 0.223f, -0.223f }, floatResult.chunks().get(0).embedding())); + } + + assertThat(webServer.requests(), hasSize(1)); + assertNull(webServer.requests().get(0).getUri().getQuery()); + assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType())); + assertThat(webServer.requests().get(0).getHeader(HttpHeaders.AUTHORIZATION), equalTo("Bearer apikey")); + + var requestMap = entityAsMap(webServer.requests().get(0).getBody()); + assertThat(requestMap.size(), Matchers.is(3)); + assertThat(requestMap.get("input"), Matchers.is(List.of("abc", "def"))); + assertThat(requestMap.get("encoding_format"), Matchers.is("float")); + assertThat(requestMap.get("model"), Matchers.is("mistral-embed")); + } + } + + public void testInfer_ThrowsWhenQueryIsPresent() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new MistralService(senderFactory, createWithEmptySettings(threadPool))) { + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(testEmbeddingResultJson)); + + var model = MistralEmbeddingModelTests.createModel("id", "mistral-embed", "apikey", null, null, null, null); + model.setURI(getUrl(webServer)); + + PlainActionFuture listener = new PlainActionFuture<>(); + UnsupportedOperationException exception = expectThrows( + UnsupportedOperationException.class, + () -> service.infer( + model, + "should throw", + List.of("abc"), + new HashMap<>(), + InputType.INGEST, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ) + ); + + assertThat(exception.getMessage(), is("Mistral service does not support inference with query input")); + } + } + + public void testInfer_UnauthorisedResponse() throws IOException { + var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + + try (var service = new MistralService(senderFactory, createWithEmptySettings(threadPool))) { + + String responseJson = """ + { + "error": { + "message": "Incorrect API key provided:", + "type": "invalid_request_error", + "param": null, + "code": "invalid_api_key" + } + } + """; + webServer.enqueue(new MockResponse().setResponseCode(401).setBody(responseJson)); + + var model = MistralEmbeddingModelTests.createModel("id", "mistral-embed", "apikey", null, null, null, null); + model.setURI(getUrl(webServer)); + + PlainActionFuture listener = new PlainActionFuture<>(); + service.infer( + model, + null, + List.of("abc"), + new HashMap<>(), + InputType.INGEST, + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var error = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); + assertThat(error.getMessage(), containsString("Received an authentication error status code for request")); + assertThat(error.getMessage(), containsString("Error message: [Incorrect API key provided:]")); + assertThat(webServer.requests(), hasSize(1)); + } + } + + // ---------------------------------------------------------------- + + private MistralService createService() { + return new MistralService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool)); + } + + private Map getRequestConfigMap( + Map serviceSettings, + Map taskSettings, + Map secretSettings + ) { + var builtServiceSettings = new HashMap<>(); + builtServiceSettings.putAll(serviceSettings); + builtServiceSettings.putAll(secretSettings); + + return new HashMap<>( + Map.of(ModelConfigurations.SERVICE_SETTINGS, builtServiceSettings, ModelConfigurations.TASK_SETTINGS, taskSettings) + ); + } + + private record PeristedConfigRecord(Map config, Map secrets) {} + + private PeristedConfigRecord getPersistedConfigMap( + Map serviceSettings, + Map taskSettings, + Map secretSettings + ) { + + return new PeristedConfigRecord( + new HashMap<>(Map.of(ModelConfigurations.SERVICE_SETTINGS, serviceSettings, ModelConfigurations.TASK_SETTINGS, taskSettings)), + new HashMap<>(Map.of(ModelSecrets.SECRET_SETTINGS, secretSettings)) + ); + } + + private PeristedConfigRecord getPersistedConfigMap(Map serviceSettings, Map taskSettings) { + + return new PeristedConfigRecord( + new HashMap<>(Map.of(ModelConfigurations.SERVICE_SETTINGS, serviceSettings, ModelConfigurations.TASK_SETTINGS, taskSettings)), + null + ); + } + + private static Map getEmbeddingsServiceSettingsMap( + String model, + @Nullable Integer dimensions, + @Nullable Integer maxTokens, + @Nullable SimilarityMeasure similarityMeasure + ) { + return createRequestSettingsMap(model, dimensions, maxTokens, similarityMeasure); + } + + private static Map getEmbeddingsTaskSettingsMap() { + // no task settings for Mistral embeddings + return Map.of(); + } + + private static Map getSecretSettingsMap(String apiKey) { + return new HashMap<>(Map.of(API_KEY_FIELD, apiKey)); + } + + private static final String testEmbeddingResultJson = """ + { + "object": "list", + "data": [ + { + "object": "embedding", + "index": 0, + "embedding": [ + 0.0123, + -0.0123 + ] + } + ], + "model": "text-embedding-ada-002-v2", + "usage": { + "prompt_tokens": 8, + "total_tokens": 8 + } + } + """; + +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingModelTests.java new file mode 100644 index 0000000000000..0fe8723664c6e --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingModelTests.java @@ -0,0 +1,42 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.mistral.embeddings; + +import org.elasticsearch.common.settings.SecureString; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.EmptyTaskSettings; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; + +public class MistralEmbeddingModelTests extends ESTestCase { + public static MistralEmbeddingsModel createModel(String inferenceId, String model, String apiKey) { + return createModel(inferenceId, model, apiKey, null, null, null, null); + } + + public static MistralEmbeddingsModel createModel( + String inferenceId, + String model, + String apiKey, + @Nullable Integer dimensions, + @Nullable Integer maxTokens, + @Nullable SimilarityMeasure similarity, + RateLimitSettings rateLimitSettings + ) { + return new MistralEmbeddingsModel( + inferenceId, + TaskType.TEXT_EMBEDDING, + "mistral", + new MistralEmbeddingsServiceSettings(model, dimensions, maxTokens, similarity, rateLimitSettings), + EmptyTaskSettings.INSTANCE, + new DefaultSecretSettings(new SecureString(apiKey.toCharArray())) + ); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingsServiceSettingsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingsServiceSettingsTests.java new file mode 100644 index 0000000000000..13f43a5f31ad3 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/embeddings/MistralEmbeddingsServiceSettingsTests.java @@ -0,0 +1,149 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.services.mistral.embeddings; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.io.stream.ByteArrayStreamInput; +import org.elasticsearch.common.io.stream.BytesStreamOutput; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; +import org.elasticsearch.xpack.inference.services.ServiceFields; +import org.elasticsearch.xpack.inference.services.mistral.MistralConstants; +import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings; +import org.hamcrest.CoreMatchers; + +import java.io.IOException; +import java.util.HashMap; +import java.util.Map; + +import static org.elasticsearch.xpack.inference.services.ServiceFields.SIMILARITY; +import static org.hamcrest.Matchers.is; + +public class MistralEmbeddingsServiceSettingsTests extends ESTestCase { + public void testFromMap_Request_CreatesSettingsCorrectly() { + var model = "mistral-embed"; + var dims = 1536; + var maxInputTokens = 512; + var serviceSettings = MistralEmbeddingsServiceSettings.fromMap( + createRequestSettingsMap(model, dims, maxInputTokens, SimilarityMeasure.COSINE), + ConfigurationParseContext.REQUEST + ); + + assertThat(serviceSettings, is(new MistralEmbeddingsServiceSettings(model, dims, maxInputTokens, SimilarityMeasure.COSINE, null))); + } + + public void testFromMap_RequestWithRateLimit_CreatesSettingsCorrectly() { + var model = "mistral-embed"; + var dims = 1536; + var maxInputTokens = 512; + var settingsMap = createRequestSettingsMap(model, dims, maxInputTokens, SimilarityMeasure.COSINE); + settingsMap.put(RateLimitSettings.FIELD_NAME, new HashMap<>(Map.of(RateLimitSettings.REQUESTS_PER_MINUTE_FIELD, 3))); + + var serviceSettings = MistralEmbeddingsServiceSettings.fromMap(settingsMap, ConfigurationParseContext.REQUEST); + + assertThat( + serviceSettings, + is(new MistralEmbeddingsServiceSettings(model, dims, maxInputTokens, SimilarityMeasure.COSINE, new RateLimitSettings(3))) + ); + } + + public void testFromMap_Persistent_CreatesSettingsCorrectly() { + var model = "mistral-embed"; + var dims = 1536; + var maxInputTokens = 512; + + var settingsMap = createRequestSettingsMap(model, dims, maxInputTokens, SimilarityMeasure.COSINE); + var serviceSettings = MistralEmbeddingsServiceSettings.fromMap(settingsMap, ConfigurationParseContext.PERSISTENT); + + assertThat(serviceSettings, is(new MistralEmbeddingsServiceSettings(model, dims, maxInputTokens, SimilarityMeasure.COSINE, null))); + } + + public void testFromMap_PersistentContext_DoesNotThrowException_WhenDimensionsIsNull() { + var model = "mistral-embed"; + + var settingsMap = createRequestSettingsMap(model, null, null, null); + var serviceSettings = MistralEmbeddingsServiceSettings.fromMap(settingsMap, ConfigurationParseContext.PERSISTENT); + + assertThat(serviceSettings, is(new MistralEmbeddingsServiceSettings(model, null, null, null, null))); + } + + public void testFromMap_PersistentContext_DoesNotThrowException_WhenSimilarityIsPresent() { + var model = "mistral-embed"; + + var settingsMap = createRequestSettingsMap(model, null, null, SimilarityMeasure.DOT_PRODUCT); + var serviceSettings = MistralEmbeddingsServiceSettings.fromMap(settingsMap, ConfigurationParseContext.PERSISTENT); + + assertThat(serviceSettings, is(new MistralEmbeddingsServiceSettings(model, null, null, SimilarityMeasure.DOT_PRODUCT, null))); + } + + public void testToXContent_WritesAllValues() throws IOException { + var entity = new MistralEmbeddingsServiceSettings("model_name", 1024, 512, null, new RateLimitSettings(3)); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + entity.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, CoreMatchers.is(""" + {"model":"model_name","dimensions":1024,"max_input_tokens":512,""" + """ + "rate_limit":{"requests_per_minute":3}}""")); + } + + public void testToFilteredXContent_WritesFilteredValues() throws IOException { + var entity = new MistralEmbeddingsServiceSettings("model_name", 1024, 512, null, new RateLimitSettings(3)); + + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + var filteredXContent = entity.getFilteredXContentObject(); + filteredXContent.toXContent(builder, null); + String xContentResult = Strings.toString(builder); + + assertThat(xContentResult, CoreMatchers.is(""" + {"model":"model_name","dimensions":1024,"max_input_tokens":512}""")); + } + + public void testStreamInputAndOutput_WritesValuesCorrectly() throws IOException { + var outputBuffer = new BytesStreamOutput(); + var settings = new MistralEmbeddingsServiceSettings("model_name", 1024, 512, null, new RateLimitSettings(3)); + settings.writeTo(outputBuffer); + + var outputBufferRef = outputBuffer.bytes(); + var inputBuffer = new ByteArrayStreamInput(outputBufferRef.array()); + + var settingsFromBuffer = new MistralEmbeddingsServiceSettings(inputBuffer); + + assertEquals(settings, settingsFromBuffer); + } + + public static HashMap createRequestSettingsMap( + String model, + @Nullable Integer dimensions, + @Nullable Integer maxTokens, + @Nullable SimilarityMeasure similarityMeasure + ) { + var map = new HashMap(Map.of(MistralConstants.MODEL_FIELD, model)); + + if (dimensions != null) { + map.put(ServiceFields.DIMENSIONS, dimensions); + } + + if (maxTokens != null) { + map.put(ServiceFields.MAX_INPUT_TOKENS, maxTokens); + } + + if (similarityMeasure != null) { + map.put(SIMILARITY, similarityMeasure.toString()); + } + + return map; + } + +}