diff --git a/docs/changelog/117589.yaml b/docs/changelog/117589.yaml new file mode 100644 index 000000000000..e6880fd9477b --- /dev/null +++ b/docs/changelog/117589.yaml @@ -0,0 +1,5 @@ +pr: 117589 +summary: "Add Inference Unified API for chat completions for OpenAI" +area: Machine Learning +type: enhancement +issues: [] diff --git a/docs/changelog/117657.yaml b/docs/changelog/117657.yaml new file mode 100644 index 000000000000..0a72e9dabe9e --- /dev/null +++ b/docs/changelog/117657.yaml @@ -0,0 +1,5 @@ +pr: 117657 +summary: Ignore cancellation exceptions +area: ES|QL +type: bug +issues: [] diff --git a/docs/changelog/118064.yaml b/docs/changelog/118064.yaml new file mode 100644 index 000000000000..7d12f365bf14 --- /dev/null +++ b/docs/changelog/118064.yaml @@ -0,0 +1,5 @@ +pr: 118064 +summary: Add Highlighter for Semantic Text Fields +area: Highlighting +type: feature +issues: [] diff --git a/docs/plugins/analysis-kuromoji.asciidoc b/docs/plugins/analysis-kuromoji.asciidoc index 0a167bf3f024..217d88f36122 100644 --- a/docs/plugins/analysis-kuromoji.asciidoc +++ b/docs/plugins/analysis-kuromoji.asciidoc @@ -750,3 +750,39 @@ Which results in: ] } -------------------------------------------------- + +[[analysis-kuromoji-completion]] +==== `kuromoji_completion` token filter + +The `kuromoji_completion` token filter adds Japanese romanized tokens to the term attributes along with the original tokens (surface forms). + +[source,console] +-------------------------------------------------- +GET _analyze +{ + "analyzer": "kuromoji_completion", + "text": "寿司" <1> +} +-------------------------------------------------- + +<1> Returns `寿司`, `susi` (Kunrei-shiki) and `sushi` (Hepburn-shiki). + +The `kuromoji_completion` token filter accepts the following settings: + +`mode`:: ++ +-- + +The tokenization mode determines how the tokenizer handles compound and +unknown words. It can be set to: + +`index`:: + + Simple romanization. Expected to be used when indexing. + +`query`:: + + Input Method aware romanization. Expected to be used when querying. + +Defaults to `index`. +-- diff --git a/docs/reference/mapping/types/semantic-text.asciidoc b/docs/reference/mapping/types/semantic-text.asciidoc index f76a9352c2fe..b3e103ec6dbd 100644 --- a/docs/reference/mapping/types/semantic-text.asciidoc +++ b/docs/reference/mapping/types/semantic-text.asciidoc @@ -112,50 +112,43 @@ Trying to <> that is used on a {infer-cap} endpoints have a limit on the amount of text they can process. To allow for large amounts of text to be used in semantic search, `semantic_text` automatically generates smaller passages if needed, called _chunks_. -Each chunk will include the text subpassage and the corresponding embedding generated from it. +Each chunk refers to a passage of the text and the corresponding embedding generated from it. When querying, the individual passages will be automatically searched for each document, and the most relevant passage will be used to compute a score. For more details on chunking and how to configure chunking settings, see <> in the Inference API documentation. +Refer to <> to learn more about +semantic search using `semantic_text` and the `semantic` query. [discrete] -[[semantic-text-structure]] -==== `semantic_text` structure +[[semantic-text-highlighting]] +==== Extracting Relevant Fragments from Semantic Text -Once a document is ingested, a `semantic_text` field will have the following structure: +You can extract the most relevant fragments from a semantic text field by using the <> in the <>. 
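The console request below shows the highlighting API itself. As a purely conceptual sketch (not the actual highlighter implementation; the class and method names are invented for illustration), the following Java snippet shows what `number_of_fragments` combined with `"order": "score"` means for a chunked `semantic_text` field: keep the top-scoring passages and emit them by descending score rather than in field order.

[source,java]
----
import java.util.Comparator;
import java.util.List;
import java.util.Map;

// Hypothetical illustration: rank chunk passages by their relevance score and
// keep only the requested number of fragments, highest score first.
public final class SemanticFragmentSketch {

    static List<String> topFragments(Map<String, Double> chunkScores, int numberOfFragments) {
        return chunkScores.entrySet()
            .stream()
            .sorted(Map.Entry.<String, Double>comparingByValue(Comparator.reverseOrder()))
            .limit(numberOfFragments)
            .map(Map.Entry::getKey)
            .toList();
    }

    public static void main(String[] args) {
        Map<String, Double> scores = Map.of("passage one", 0.42, "passage two", 0.91, "passage three", 0.10);
        System.out.println(topFragments(scores, 2)); // [passage two, passage one]
    }
}
----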
-[source,console-result] +[source,console] ------------------------------------------------------------ -"inference_field": { - "text": "these are not the droids you're looking for", <1> - "inference": { - "inference_id": "my-elser-endpoint", <2> - "model_settings": { <3> - "task_type": "sparse_embedding" +PUT test-index +{ + "query": { + "semantic": { + "field": "my_semantic_field" + } }, - "chunks": [ <4> - { - "text": "these are not the droids you're looking for", - "embeddings": { - (...) + "highlight": { + "fields": { + "my_semantic_field": { + "type": "semantic", + "number_of_fragments": 2, <1> + "order": "score" <2> + } } - } - ] - } + } } ------------------------------------------------------------ -// TEST[skip:TBD] -<1> The field will become an object structure to accommodate both the original -text and the inference results. -<2> The `inference_id` used to generate the embeddings. -<3> Model settings, including the task type and dimensions/similarity if -applicable. -<4> Inference results will be grouped in chunks, each with its corresponding -text and embeddings. - -Refer to <> to learn more about -semantic search using `semantic_text` and the `semantic` query. - +// TEST[skip:Requires inference endpoint] +<1> Specifies the maximum number of fragments to return. +<2> Sorts highlighted fragments by score when set to `score`. By default, fragments will be output in the order they appear in the field (order: none). [discrete] [[custom-indexing]] diff --git a/modules/repository-s3/src/javaRestTest/java/org/elasticsearch/repositories/s3/AbstractRepositoryS3RestTestCase.java b/modules/repository-s3/src/javaRestTest/java/org/elasticsearch/repositories/s3/AbstractRepositoryS3RestTestCase.java index 2199a6452175..67ada622efee 100644 --- a/modules/repository-s3/src/javaRestTest/java/org/elasticsearch/repositories/s3/AbstractRepositoryS3RestTestCase.java +++ b/modules/repository-s3/src/javaRestTest/java/org/elasticsearch/repositories/s3/AbstractRepositoryS3RestTestCase.java @@ -19,6 +19,7 @@ import org.elasticsearch.rest.RestStatus; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.rest.ESRestTestCase; +import org.elasticsearch.test.rest.ObjectPath; import java.io.Closeable; import java.io.IOException; @@ -27,7 +28,6 @@ import java.util.function.UnaryOperator; import java.util.stream.Collectors; -import static org.hamcrest.Matchers.allOf; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.equalTo; @@ -152,10 +152,9 @@ private void testNonexistentBucket(Boolean readonly) throws Exception { final var responseException = expectThrows(ResponseException.class, () -> client().performRequest(registerRequest)); assertEquals(RestStatus.INTERNAL_SERVER_ERROR.getStatus(), responseException.getResponse().getStatusLine().getStatusCode()); - assertThat( - responseException.getMessage(), - allOf(containsString("repository_verification_exception"), containsString("is not accessible on master node")) - ); + final var responseObjectPath = ObjectPath.createFromResponse(responseException.getResponse()); + assertThat(responseObjectPath.evaluate("error.type"), equalTo("repository_verification_exception")); + assertThat(responseObjectPath.evaluate("error.reason"), containsString("is not accessible on master node")); } public void testNonexistentClient() throws Exception { @@ -181,15 +180,11 @@ private void testNonexistentClient(Boolean readonly) throws Exception { final var responseException = expectThrows(ResponseException.class, () -> 
client().performRequest(registerRequest)); assertEquals(RestStatus.INTERNAL_SERVER_ERROR.getStatus(), responseException.getResponse().getStatusLine().getStatusCode()); - assertThat( - responseException.getMessage(), - allOf( - containsString("repository_verification_exception"), - containsString("is not accessible on master node"), - containsString("illegal_argument_exception"), - containsString("Unknown s3 client name") - ) - ); + final var responseObjectPath = ObjectPath.createFromResponse(responseException.getResponse()); + assertThat(responseObjectPath.evaluate("error.type"), equalTo("repository_verification_exception")); + assertThat(responseObjectPath.evaluate("error.reason"), containsString("is not accessible on master node")); + assertThat(responseObjectPath.evaluate("error.caused_by.type"), equalTo("illegal_argument_exception")); + assertThat(responseObjectPath.evaluate("error.caused_by.reason"), containsString("Unknown s3 client name")); } public void testNonexistentSnapshot() throws Exception { @@ -212,7 +207,8 @@ private void testNonexistentSnapshot(Boolean readonly) throws Exception { final var getSnapshotRequest = new Request("GET", "/_snapshot/" + repositoryName + "/" + randomIdentifier()); final var getSnapshotException = expectThrows(ResponseException.class, () -> client().performRequest(getSnapshotRequest)); assertEquals(RestStatus.NOT_FOUND.getStatus(), getSnapshotException.getResponse().getStatusLine().getStatusCode()); - assertThat(getSnapshotException.getMessage(), containsString("snapshot_missing_exception")); + final var getResponseObjectPath = ObjectPath.createFromResponse(getSnapshotException.getResponse()); + assertThat(getResponseObjectPath.evaluate("error.type"), equalTo("snapshot_missing_exception")); final var restoreRequest = new Request("POST", "/_snapshot/" + repositoryName + "/" + randomIdentifier() + "/_restore"); if (randomBoolean()) { @@ -220,13 +216,15 @@ private void testNonexistentSnapshot(Boolean readonly) throws Exception { } final var restoreException = expectThrows(ResponseException.class, () -> client().performRequest(restoreRequest)); assertEquals(RestStatus.INTERNAL_SERVER_ERROR.getStatus(), restoreException.getResponse().getStatusLine().getStatusCode()); - assertThat(restoreException.getMessage(), containsString("snapshot_restore_exception")); + final var restoreResponseObjectPath = ObjectPath.createFromResponse(restoreException.getResponse()); + assertThat(restoreResponseObjectPath.evaluate("error.type"), equalTo("snapshot_restore_exception")); if (readonly != Boolean.TRUE) { final var deleteRequest = new Request("DELETE", "/_snapshot/" + repositoryName + "/" + randomIdentifier()); final var deleteException = expectThrows(ResponseException.class, () -> client().performRequest(deleteRequest)); assertEquals(RestStatus.NOT_FOUND.getStatus(), deleteException.getResponse().getStatusLine().getStatusCode()); - assertThat(deleteException.getMessage(), containsString("snapshot_missing_exception")); + final var deleteResponseObjectPath = ObjectPath.createFromResponse(deleteException.getResponse()); + assertThat(deleteResponseObjectPath.evaluate("error.type"), equalTo("snapshot_missing_exception")); } } } diff --git a/muted-tests.yml b/muted-tests.yml index fb2fea908ef9..a36723469189 100644 --- a/muted-tests.yml +++ b/muted-tests.yml @@ -117,9 +117,6 @@ tests: - class: org.elasticsearch.xpack.deprecation.DeprecationHttpIT method: testDeprecatedSettingsReturnWarnings issue: https://github.com/elastic/elasticsearch/issues/108628 -- class: 
org.elasticsearch.action.search.SearchQueryThenFetchAsyncActionTests - method: testBottomFieldSort - issue: https://github.com/elastic/elasticsearch/issues/116249 - class: org.elasticsearch.xpack.shutdown.NodeShutdownIT method: testAllocationPreventedForRemoval issue: https://github.com/elastic/elasticsearch/issues/116363 @@ -242,12 +239,12 @@ tests: - class: org.elasticsearch.packaging.test.ConfigurationTests method: test30SymlinkedDataPath issue: https://github.com/elastic/elasticsearch/issues/118111 -- class: org.elasticsearch.datastreams.ResolveClusterDataStreamIT - method: testClusterResolveWithDataStreamsUsingAlias - issue: https://github.com/elastic/elasticsearch/issues/118124 - class: org.elasticsearch.packaging.test.KeystoreManagementTests method: test30KeystorePasswordFromFile issue: https://github.com/elastic/elasticsearch/issues/118123 +- class: org.elasticsearch.packaging.test.KeystoreManagementTests + method: test31WrongKeystorePasswordFromFile + issue: https://github.com/elastic/elasticsearch/issues/118123 - class: org.elasticsearch.packaging.test.ArchiveTests method: test41AutoconfigurationNotTriggeredWhenNodeCannotContainData issue: https://github.com/elastic/elasticsearch/issues/118110 @@ -260,6 +257,43 @@ tests: - class: org.elasticsearch.xpack.remotecluster.CrossClusterEsqlRCS2UnavailableRemotesIT method: testEsqlRcs2UnavailableRemoteScenarios issue: https://github.com/elastic/elasticsearch/issues/117419 +- class: org.elasticsearch.packaging.test.DebPreservationTests + method: test40RestartOnUpgrade + issue: https://github.com/elastic/elasticsearch/issues/118170 +- class: org.elasticsearch.xpack.inference.DefaultEndPointsIT + method: testInferDeploysDefaultRerank + issue: https://github.com/elastic/elasticsearch/issues/118184 +- class: org.elasticsearch.xpack.esql.action.EsqlActionTaskIT + method: testCancelRequestWhenFailingFetchingPages + issue: https://github.com/elastic/elasticsearch/issues/118193 +- class: org.elasticsearch.packaging.test.MemoryLockingTests + method: test20MemoryLockingEnabled + issue: https://github.com/elastic/elasticsearch/issues/118195 +- class: org.elasticsearch.packaging.test.ArchiveTests + method: test42AutoconfigurationNotTriggeredWhenNodeCannotBecomeMaster + issue: https://github.com/elastic/elasticsearch/issues/118196 +- class: org.elasticsearch.packaging.test.ArchiveTests + method: test43AutoconfigurationNotTriggeredWhenTlsAlreadyConfigured + issue: https://github.com/elastic/elasticsearch/issues/118202 +- class: org.elasticsearch.packaging.test.ArchiveTests + method: test44AutoConfigurationNotTriggeredOnNotWriteableConfDir + issue: https://github.com/elastic/elasticsearch/issues/118208 +- class: org.elasticsearch.packaging.test.ArchiveTests + method: test51AutoConfigurationWithPasswordProtectedKeystore + issue: https://github.com/elastic/elasticsearch/issues/118212 +- class: org.elasticsearch.xpack.inference.InferenceCrudIT + method: testUnifiedCompletionInference + issue: https://github.com/elastic/elasticsearch/issues/118210 +- class: org.elasticsearch.ingest.common.IngestCommonClientYamlTestSuiteIT + issue: https://github.com/elastic/elasticsearch/issues/118215 +- class: org.elasticsearch.datastreams.DataStreamsClientYamlTestSuiteIT + method: test {p0=data_stream/120_data_streams_stats/Multiple data stream} + issue: https://github.com/elastic/elasticsearch/issues/118217 +- class: org.elasticsearch.xpack.security.operator.OperatorPrivilegesIT + method: testEveryActionIsEitherOperatorOnlyOrNonOperator + issue: 
https://github.com/elastic/elasticsearch/issues/118220 +- class: org.elasticsearch.validation.DotPrefixClientYamlTestSuiteIT + issue: https://github.com/elastic/elasticsearch/issues/118224 # Examples: # diff --git a/rest-api-spec/src/main/resources/rest-api-spec/api/migrate.reindex.json b/rest-api-spec/src/main/resources/rest-api-spec/api/migrate.reindex.json new file mode 100644 index 000000000000..149a90bc198b --- /dev/null +++ b/rest-api-spec/src/main/resources/rest-api-spec/api/migrate.reindex.json @@ -0,0 +1,29 @@ +{ + "migrate.reindex":{ + "documentation":{ + "url":"https://www.elastic.co/guide/en/elasticsearch/reference/master/data-stream-reindex.html", + "description":"This API reindexes all legacy backing indices for a data stream. It does this in a persistent task. The persistent task id is returned immediately, and the reindexing work is completed in that task" + }, + "stability":"experimental", + "visibility":"private", + "headers":{ + "accept": [ "application/json"], + "content_type": ["application/json"] + }, + "url":{ + "paths":[ + { + "path":"/_migration/reindex", + "methods":[ + "POST" + ] + } + ] + }, + "body":{ + "description":"The body contains the fields `mode` and `source.index, where the only mode currently supported is `upgrade`, and the `source.index` must be a data stream name", + "required":true + } + } +} + diff --git a/server/src/main/java/org/elasticsearch/common/xcontent/ChunkedToXContentHelper.java b/server/src/main/java/org/elasticsearch/common/xcontent/ChunkedToXContentHelper.java index 2e78cc6f516b..6a5aa2943de9 100644 --- a/server/src/main/java/org/elasticsearch/common/xcontent/ChunkedToXContentHelper.java +++ b/server/src/main/java/org/elasticsearch/common/xcontent/ChunkedToXContentHelper.java @@ -12,6 +12,7 @@ import org.elasticsearch.common.collect.Iterators; import org.elasticsearch.xcontent.ToXContent; +import java.util.Collections; import java.util.Iterator; public enum ChunkedToXContentHelper { @@ -53,6 +54,14 @@ public static Iterator field(String name, String value) { return Iterators.single(((builder, params) -> builder.field(name, value))); } + public static Iterator optionalField(String name, String value) { + if (value == null) { + return Collections.emptyIterator(); + } else { + return field(name, value); + } + } + /** * Creates an Iterator of a single ToXContent object that serializes the given object as a single chunk. Just wraps {@link * Iterators#single}, but still useful because it avoids any type ambiguity. diff --git a/server/src/main/java/org/elasticsearch/inference/InferenceService.java b/server/src/main/java/org/elasticsearch/inference/InferenceService.java index 4497254aad1f..c2d690d8160a 100644 --- a/server/src/main/java/org/elasticsearch/inference/InferenceService.java +++ b/server/src/main/java/org/elasticsearch/inference/InferenceService.java @@ -112,6 +112,23 @@ void infer( ); /** + * Perform completion inference on the model using the unified schema. + * + * @param model The model + * @param request Parameters for the request + * @param timeout The timeout for the request + * @param listener Inference result listener + */ + void unifiedCompletionInfer( + Model model, + UnifiedCompletionRequest request, + TimeValue timeout, + ActionListener listener + ); + + /** + * Chunk long text. 
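One note on the `ChunkedToXContentHelper.optionalField` helper added above: a `null` value contributes no chunk at all, so callers can concatenate it unconditionally instead of branching on null. A minimal sketch of that behaviour, assuming the standard XContent classes (the wrapper class name is invented for illustration):

[source,java]
----
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.collect.Iterators;
import org.elasticsearch.common.xcontent.ChunkedToXContentHelper;
import org.elasticsearch.xcontent.ToXContent;
import org.elasticsearch.xcontent.XContentBuilder;
import org.elasticsearch.xcontent.XContentFactory;

import java.io.IOException;
import java.util.Iterator;

// Hypothetical sketch: optionalField emits nothing for a null value, so the field is simply absent.
class OptionalFieldExample {
    static String render(String refusal) throws IOException { // pass null to omit the field
        Iterator<ToXContent> chunks = Iterators.concat(
            ChunkedToXContentHelper.startObject(),
            ChunkedToXContentHelper.field("model", "gpt-4o"),
            ChunkedToXContentHelper.optionalField("refusal", refusal),
            ChunkedToXContentHelper.endObject()
        );
        XContentBuilder builder = XContentFactory.jsonBuilder();
        while (chunks.hasNext()) {
            chunks.next().toXContent(builder, ToXContent.EMPTY_PARAMS);
        }
        return Strings.toString(builder); // render(null) -> {"model":"gpt-4o"}
    }
}
----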
+ * * @param model The model * @param query Inference query, mainly for re-ranking * @param input Inference input diff --git a/server/src/main/java/org/elasticsearch/inference/TaskType.java b/server/src/main/java/org/elasticsearch/inference/TaskType.java index b0e5bababbbc..fcb8ea721379 100644 --- a/server/src/main/java/org/elasticsearch/inference/TaskType.java +++ b/server/src/main/java/org/elasticsearch/inference/TaskType.java @@ -38,6 +38,10 @@ public static TaskType fromString(String name) { } public static TaskType fromStringOrStatusException(String name) { + if (name == null) { + throw new ElasticsearchStatusException("Task type must not be null", RestStatus.BAD_REQUEST); + } + try { TaskType taskType = TaskType.fromString(name); return Objects.requireNonNull(taskType); diff --git a/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java b/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java new file mode 100644 index 000000000000..e596be626b51 --- /dev/null +++ b/server/src/main/java/org/elasticsearch/inference/UnifiedCompletionRequest.java @@ -0,0 +1,425 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the "Elastic License + * 2.0", the "GNU Affero General Public License v3.0 only", and the "Server Side + * Public License v 1"; you may not use this file except in compliance with, at + * your election, the "Elastic License 2.0", the "GNU Affero General Public + * License v3.0 only", or the "Server Side Public License, v 1". + */ + +package org.elasticsearch.inference; + +import org.elasticsearch.common.io.stream.NamedWriteable; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.xcontent.XContentParserUtils; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.xcontent.ConstructingObjectParser; +import org.elasticsearch.xcontent.ObjectParser; +import org.elasticsearch.xcontent.ParseField; +import org.elasticsearch.xcontent.XContentParseException; +import org.elasticsearch.xcontent.XContentParser; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +import static org.elasticsearch.xcontent.ConstructingObjectParser.constructorArg; +import static org.elasticsearch.xcontent.ConstructingObjectParser.optionalConstructorArg; + +public record UnifiedCompletionRequest( + List messages, + @Nullable String model, + @Nullable Long maxCompletionTokens, + @Nullable List stop, + @Nullable Float temperature, + @Nullable ToolChoice toolChoice, + @Nullable List tools, + @Nullable Float topP +) implements Writeable { + + public sealed interface Content extends NamedWriteable permits ContentObjects, ContentString {} + + @SuppressWarnings("unchecked") + public static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + UnifiedCompletionRequest.class.getSimpleName(), + args -> new UnifiedCompletionRequest( + (List) args[0], + (String) args[1], + (Long) args[2], + (List) args[3], + (Float) args[4], + (ToolChoice) args[5], + (List) args[6], + (Float) args[7] + ) + ); + + static { + PARSER.declareObjectArray(constructorArg(), Message.PARSER::apply, new ParseField("messages")); + PARSER.declareString(optionalConstructorArg(), new ParseField("model")); + 
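For orientation, the parser declared in this static block is driven the same way the serialization tests later in this change drive it: hand it an `XContentParser` positioned at the request body. A minimal, hypothetical parsing sketch (the wrapper class name is invented for illustration):

[source,java]
----
import org.elasticsearch.inference.UnifiedCompletionRequest;
import org.elasticsearch.xcontent.XContentParser;
import org.elasticsearch.xcontent.XContentParserConfiguration;
import org.elasticsearch.xcontent.json.JsonXContent;

import java.io.IOException;

// Hypothetical example: parse a minimal unified completion body into the request record.
class UnifiedCompletionRequestParsingExample {
    static UnifiedCompletionRequest parse() throws IOException {
        String json = """
            {
              "model": "gpt-4o",
              "messages": [ { "role": "user", "content": "Hello there" } ]
            }""";
        try (XContentParser parser = JsonXContent.jsonXContent.createParser(XContentParserConfiguration.EMPTY, json)) {
            UnifiedCompletionRequest request = UnifiedCompletionRequest.PARSER.apply(parser, null);
            assert "gpt-4o".equals(request.model());
            assert request.messages().size() == 1;
            return request;
        }
    }
}
----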
PARSER.declareLong(optionalConstructorArg(), new ParseField("max_completion_tokens")); + PARSER.declareStringArray(optionalConstructorArg(), new ParseField("stop")); + PARSER.declareFloat(optionalConstructorArg(), new ParseField("temperature")); + PARSER.declareField( + optionalConstructorArg(), + (p, c) -> parseToolChoice(p), + new ParseField("tool_choice"), + ObjectParser.ValueType.OBJECT_OR_STRING + ); + PARSER.declareObjectArray(optionalConstructorArg(), Tool.PARSER::apply, new ParseField("tools")); + PARSER.declareFloat(optionalConstructorArg(), new ParseField("top_p")); + } + + public static List getNamedWriteables() { + return List.of( + new NamedWriteableRegistry.Entry(Content.class, ContentObjects.NAME, ContentObjects::new), + new NamedWriteableRegistry.Entry(Content.class, ContentString.NAME, ContentString::new), + new NamedWriteableRegistry.Entry(ToolChoice.class, ToolChoiceObject.NAME, ToolChoiceObject::new), + new NamedWriteableRegistry.Entry(ToolChoice.class, ToolChoiceString.NAME, ToolChoiceString::new) + ); + } + + public static UnifiedCompletionRequest of(List messages) { + return new UnifiedCompletionRequest(messages, null, null, null, null, null, null, null); + } + + public UnifiedCompletionRequest(StreamInput in) throws IOException { + this( + in.readCollectionAsImmutableList(Message::new), + in.readOptionalString(), + in.readOptionalVLong(), + in.readOptionalStringCollectionAsList(), + in.readOptionalFloat(), + in.readOptionalNamedWriteable(ToolChoice.class), + in.readOptionalCollectionAsList(Tool::new), + in.readOptionalFloat() + ); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeCollection(messages); + out.writeOptionalString(model); + out.writeOptionalVLong(maxCompletionTokens); + out.writeOptionalStringCollection(stop); + out.writeOptionalFloat(temperature); + out.writeOptionalNamedWriteable(toolChoice); + out.writeOptionalCollection(tools); + out.writeOptionalFloat(topP); + } + + public record Message(Content content, String role, @Nullable String name, @Nullable String toolCallId, List toolCalls) + implements + Writeable { + + @SuppressWarnings("unchecked") + static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + Message.class.getSimpleName(), + args -> new Message((Content) args[0], (String) args[1], (String) args[2], (String) args[3], (List) args[4]) + ); + + static { + PARSER.declareField(constructorArg(), (p, c) -> parseContent(p), new ParseField("content"), ObjectParser.ValueType.VALUE_ARRAY); + PARSER.declareString(constructorArg(), new ParseField("role")); + PARSER.declareString(optionalConstructorArg(), new ParseField("name")); + PARSER.declareString(optionalConstructorArg(), new ParseField("tool_call_id")); + PARSER.declareObjectArray(optionalConstructorArg(), ToolCall.PARSER::apply, new ParseField("tool_calls")); + } + + private static Content parseContent(XContentParser parser) throws IOException { + var token = parser.currentToken(); + if (token == XContentParser.Token.START_ARRAY) { + var parsedContentObjects = XContentParserUtils.parseList(parser, (p) -> ContentObject.PARSER.apply(p, null)); + return new ContentObjects(parsedContentObjects); + } else if (token == XContentParser.Token.VALUE_STRING) { + return ContentString.of(parser); + } + + throw new XContentParseException("Expected an array start token or a value string token but found token [" + token + "]"); + } + + public Message(StreamInput in) throws IOException { + this( + in.readNamedWriteable(Content.class), + 
in.readString(), + in.readOptionalString(), + in.readOptionalString(), + in.readOptionalCollectionAsList(ToolCall::new) + ); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeNamedWriteable(content); + out.writeString(role); + out.writeOptionalString(name); + out.writeOptionalString(toolCallId); + out.writeOptionalCollection(toolCalls); + } + } + + public record ContentObjects(List contentObjects) implements Content, NamedWriteable { + + public static final String NAME = "content_objects"; + + public ContentObjects(StreamInput in) throws IOException { + this(in.readCollectionAsImmutableList(ContentObject::new)); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeCollection(contentObjects); + } + + @Override + public String getWriteableName() { + return NAME; + } + } + + public record ContentObject(String text, String type) implements Writeable { + static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + ContentObject.class.getSimpleName(), + args -> new ContentObject((String) args[0], (String) args[1]) + ); + + static { + PARSER.declareString(constructorArg(), new ParseField("text")); + PARSER.declareString(constructorArg(), new ParseField("type")); + } + + public ContentObject(StreamInput in) throws IOException { + this(in.readString(), in.readString()); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(text); + out.writeString(type); + } + + public String toString() { + return text + ":" + type; + } + + } + + public record ContentString(String content) implements Content, NamedWriteable { + public static final String NAME = "content_string"; + + public static ContentString of(XContentParser parser) throws IOException { + var content = parser.text(); + return new ContentString(content); + } + + public ContentString(StreamInput in) throws IOException { + this(in.readString()); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(content); + } + + @Override + public String getWriteableName() { + return NAME; + } + + public String toString() { + return content; + } + } + + public record ToolCall(String id, FunctionField function, String type) implements Writeable { + + static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + ToolCall.class.getSimpleName(), + args -> new ToolCall((String) args[0], (FunctionField) args[1], (String) args[2]) + ); + + static { + PARSER.declareString(constructorArg(), new ParseField("id")); + PARSER.declareObject(constructorArg(), FunctionField.PARSER::apply, new ParseField("function")); + PARSER.declareString(constructorArg(), new ParseField("type")); + } + + public ToolCall(StreamInput in) throws IOException { + this(in.readString(), new FunctionField(in), in.readString()); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(id); + function.writeTo(out); + out.writeString(type); + } + + public record FunctionField(String arguments, String name) implements Writeable { + static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + "tool_call_function_field", + args -> new FunctionField((String) args[0], (String) args[1]) + ); + + static { + PARSER.declareString(constructorArg(), new ParseField("arguments")); + PARSER.declareString(constructorArg(), new ParseField("name")); + } + + public FunctionField(StreamInput in) throws IOException { + this(in.readString(), in.readString()); + } + 
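Because `Content` is a `NamedWriteable`, deserializing a `Message` requires a `NamedWriteableRegistry` seeded with `UnifiedCompletionRequest.getNamedWriteables()`, which is how the serialization tests later in this change set up their registry. A hypothetical round-trip sketch (the helper class name is invented for illustration):

[source,java]
----
import org.elasticsearch.common.io.stream.BytesStreamOutput;
import org.elasticsearch.common.io.stream.NamedWriteableAwareStreamInput;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.inference.UnifiedCompletionRequest;

import java.io.IOException;

// Hypothetical sketch: serialize a Message and read it back through a registry-aware stream.
class MessageRoundTripExample {
    static UnifiedCompletionRequest.Message roundTrip(UnifiedCompletionRequest.Message message) throws IOException {
        var registry = new NamedWriteableRegistry(UnifiedCompletionRequest.getNamedWriteables());
        try (BytesStreamOutput out = new BytesStreamOutput()) {
            message.writeTo(out);
            try (StreamInput in = new NamedWriteableAwareStreamInput(out.bytes().streamInput(), registry)) {
                return new UnifiedCompletionRequest.Message(in);
            }
        }
    }

    static UnifiedCompletionRequest.Message example() throws IOException {
        var original = new UnifiedCompletionRequest.Message(
            new UnifiedCompletionRequest.ContentString("Hello there"), "user", null, null, null);
        return roundTrip(original); // equals(original) holds after the round trip
    }
}
----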
+ @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(arguments); + out.writeString(name); + } + } + } + + private static ToolChoice parseToolChoice(XContentParser parser) throws IOException { + var token = parser.currentToken(); + if (token == XContentParser.Token.START_OBJECT) { + return ToolChoiceObject.PARSER.apply(parser, null); + } else if (token == XContentParser.Token.VALUE_STRING) { + return ToolChoiceString.of(parser); + } + + throw new XContentParseException("Unsupported token [" + token + "]"); + } + + public sealed interface ToolChoice extends NamedWriteable permits ToolChoiceObject, ToolChoiceString {} + + public record ToolChoiceObject(String type, FunctionField function) implements ToolChoice, NamedWriteable { + + public static final String NAME = "tool_choice_object"; + + static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + ToolChoiceObject.class.getSimpleName(), + args -> new ToolChoiceObject((String) args[0], (FunctionField) args[1]) + ); + + static { + PARSER.declareString(constructorArg(), new ParseField("type")); + PARSER.declareObject(constructorArg(), FunctionField.PARSER::apply, new ParseField("function")); + } + + public ToolChoiceObject(StreamInput in) throws IOException { + this(in.readString(), new FunctionField(in)); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(type); + function.writeTo(out); + } + + @Override + public String getWriteableName() { + return NAME; + } + + public record FunctionField(String name) implements Writeable { + static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + "tool_choice_function_field", + args -> new FunctionField((String) args[0]) + ); + + static { + PARSER.declareString(constructorArg(), new ParseField("name")); + } + + public FunctionField(StreamInput in) throws IOException { + this(in.readString()); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(name); + } + } + } + + public record ToolChoiceString(String value) implements ToolChoice, NamedWriteable { + public static final String NAME = "tool_choice_string"; + + public static ToolChoiceString of(XContentParser parser) throws IOException { + var content = parser.text(); + return new ToolChoiceString(content); + } + + public ToolChoiceString(StreamInput in) throws IOException { + this(in.readString()); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(value); + } + + @Override + public String getWriteableName() { + return NAME; + } + } + + public record Tool(String type, FunctionField function) implements Writeable { + + static final ConstructingObjectParser PARSER = new ConstructingObjectParser<>( + Tool.class.getSimpleName(), + args -> new Tool((String) args[0], (FunctionField) args[1]) + ); + + static { + PARSER.declareString(constructorArg(), new ParseField("type")); + PARSER.declareObject(constructorArg(), FunctionField.PARSER::apply, new ParseField("function")); + } + + public Tool(StreamInput in) throws IOException { + this(in.readString(), new FunctionField(in)); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeString(type); + function.writeTo(out); + } + + public record FunctionField( + @Nullable String description, + String name, + @Nullable Map parameters, + @Nullable Boolean strict + ) implements Writeable { + + @SuppressWarnings("unchecked") + static final ConstructingObjectParser PARSER = 
new ConstructingObjectParser<>( + "tool_function_field", + args -> new FunctionField((String) args[0], (String) args[1], (Map) args[2], (Boolean) args[3]) + ); + + static { + PARSER.declareString(optionalConstructorArg(), new ParseField("description")); + PARSER.declareString(constructorArg(), new ParseField("name")); + PARSER.declareObject(optionalConstructorArg(), (p, c) -> p.mapOrdered(), new ParseField("parameters")); + PARSER.declareBoolean(optionalConstructorArg(), new ParseField("strict")); + } + + public FunctionField(StreamInput in) throws IOException { + this(in.readOptionalString(), in.readString(), in.readGenericMap(), in.readOptionalBoolean()); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + out.writeOptionalString(description); + out.writeString(name); + out.writeGenericMap(parameters); + out.writeOptionalBoolean(strict); + } + } + } +} diff --git a/server/src/main/java/org/elasticsearch/monitor/os/OsService.java b/server/src/main/java/org/elasticsearch/monitor/os/OsService.java index 7609cc14c6b3..ceed2b0e41fc 100644 --- a/server/src/main/java/org/elasticsearch/monitor/os/OsService.java +++ b/server/src/main/java/org/elasticsearch/monitor/os/OsService.java @@ -25,7 +25,6 @@ public class OsService implements ReportingService { private static final Logger logger = LogManager.getLogger(OsService.class); - private final OsProbe probe; private final OsInfo info; private final SingleObjectCache osStatsCache; @@ -37,10 +36,9 @@ public class OsService implements ReportingService { ); public OsService(Settings settings) throws IOException { - this.probe = OsProbe.getInstance(); TimeValue refreshInterval = REFRESH_INTERVAL_SETTING.get(settings); - this.info = probe.osInfo(refreshInterval.millis(), EsExecutors.nodeProcessors(settings)); - this.osStatsCache = new OsStatsCache(refreshInterval, probe.osStats()); + this.info = OsProbe.getInstance().osInfo(refreshInterval.millis(), EsExecutors.nodeProcessors(settings)); + this.osStatsCache = new OsStatsCache(refreshInterval); logger.debug("using refresh_interval [{}]", refreshInterval); } @@ -53,14 +51,28 @@ public OsStats stats() { return osStatsCache.getOrRefresh(); } - private class OsStatsCache extends SingleObjectCache { - OsStatsCache(TimeValue interval, OsStats initValue) { - super(interval, initValue); + private static class OsStatsCache extends SingleObjectCache { + + private static final OsStats MISSING = new OsStats( + 0L, + new OsStats.Cpu((short) 0, new double[0]), + new OsStats.Mem(0, 0, 0), + new OsStats.Swap(0, 0), + null + ); + + OsStatsCache(TimeValue interval) { + super(interval, MISSING); } @Override protected OsStats refresh() { - return probe.osStats(); + return OsProbe.getInstance().osStats(); + } + + @Override + protected boolean needsRefresh() { + return getNoRefresh() == MISSING || super.needsRefresh(); } } } diff --git a/test/framework/src/main/java/org/elasticsearch/test/AbstractMultiClustersTestCase.java b/test/framework/src/main/java/org/elasticsearch/test/AbstractMultiClustersTestCase.java index b4f91f68b8bb..7cd7bce4db18 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/AbstractMultiClustersTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/test/AbstractMultiClustersTestCase.java @@ -9,6 +9,8 @@ package org.elasticsearch.test; +import com.carrotsearch.randomizedtesting.RandomizedTest; + import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.elasticsearch.action.admin.cluster.remote.RemoteInfoRequest; @@ 
-108,6 +110,11 @@ public final void startClusters() throws Exception { MockTransportService.TestPlugin.class, getTestTransportPlugin() ); + // We are going to initialize multiple clusters concurrently, but there is a race condition around the lazy initialization of test + // groups in GroupEvaluator across multiple threads. See https://github.com/randomizedtesting/randomizedtesting/issues/311. + // Calling isNightly before parallelizing is enough to work around that issue. + @SuppressWarnings("unused") + boolean nightly = RandomizedTest.isNightly(); runInParallel(clusterAliases.size(), i -> { String clusterAlias = clusterAliases.get(i); final String clusterName = clusterAlias.equals(LOCAL_CLUSTER) ? "main-cluster" : clusterAlias; diff --git a/test/framework/src/main/java/org/elasticsearch/test/ESTestCase.java b/test/framework/src/main/java/org/elasticsearch/test/ESTestCase.java index d983fc854bdf..a71f61740e17 100644 --- a/test/framework/src/main/java/org/elasticsearch/test/ESTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/test/ESTestCase.java @@ -1205,10 +1205,30 @@ public static SecureString randomSecureStringOfLength(int codeUnits) { return new SecureString(randomAlpha.toCharArray()); } - public static String randomNullOrAlphaOfLength(int codeUnits) { + public static String randomAlphaOfLengthOrNull(int codeUnits) { return randomBoolean() ? null : randomAlphaOfLength(codeUnits); } + public static Long randomLongOrNull() { + return randomBoolean() ? null : randomLong(); + } + + public static Long randomPositiveLongOrNull() { + return randomBoolean() ? null : randomNonNegativeLong(); + } + + public static Integer randomIntOrNull() { + return randomBoolean() ? null : randomInt(); + } + + public static Integer randomPositiveIntOrNull() { + return randomBoolean() ? null : randomNonNegativeInt(); + } + + public static Float randomFloatOrNull() { + return randomBoolean() ? 
null : randomFloat(); + } + /** * Creates a valid random identifier such as node id or index name */ diff --git a/test/test-clusters/src/main/java/org/elasticsearch/test/cluster/local/DefaultLocalClusterHandle.java b/test/test-clusters/src/main/java/org/elasticsearch/test/cluster/local/DefaultLocalClusterHandle.java index eb45aacda68d..13adde1da8a6 100644 --- a/test/test-clusters/src/main/java/org/elasticsearch/test/cluster/local/DefaultLocalClusterHandle.java +++ b/test/test-clusters/src/main/java/org/elasticsearch/test/cluster/local/DefaultLocalClusterHandle.java @@ -176,8 +176,9 @@ public long getPid(int index) { return nodes.get(index).getPid(); } + @Override public void stopNode(int index, boolean forcibly) { - nodes.get(index).stop(false); + nodes.get(index).stop(forcibly); } @Override @@ -252,9 +253,8 @@ private void writeUnicastHostsFile() { execute(() -> nodes.parallelStream().forEach(node -> { try { Path hostsFile = node.getWorkingDir().resolve("config").resolve("unicast_hosts.txt"); - if (Files.notExists(hostsFile)) { - Files.writeString(hostsFile, transportUris); - } + LOGGER.info("Writing unicast hosts file {} for node {}", hostsFile, node.getName()); + Files.writeString(hostsFile, transportUris); } catch (IOException e) { throw new UncheckedIOException("Failed to write unicast_hosts for: " + node, e); } diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/BaseInferenceActionRequest.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/BaseInferenceActionRequest.java new file mode 100644 index 000000000000..e426574c52ce --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/BaseInferenceActionRequest.java @@ -0,0 +1,31 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.core.inference.action; + +import org.elasticsearch.action.ActionRequest; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.inference.TaskType; + +import java.io.IOException; + +public abstract class BaseInferenceActionRequest extends ActionRequest { + + public BaseInferenceActionRequest() { + super(); + } + + public BaseInferenceActionRequest(StreamInput in) throws IOException { + super(in); + } + + public abstract boolean isStreaming(); + + public abstract TaskType getTaskType(); + + public abstract String getInferenceEntityId(); +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java index a19edd5a0816..f88909ba4208 100644 --- a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/InferenceAction.java @@ -10,7 +10,6 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.TransportVersion; import org.elasticsearch.TransportVersions; -import org.elasticsearch.action.ActionRequest; import org.elasticsearch.action.ActionRequestValidationException; import org.elasticsearch.action.ActionResponse; import org.elasticsearch.action.ActionType; @@ -54,7 +53,7 @@ public InferenceAction() { super(NAME); } - public static class Request extends ActionRequest { + public static class Request extends BaseInferenceActionRequest { public static final TimeValue DEFAULT_TIMEOUT = TimeValue.timeValueSeconds(30); public static final ParseField INPUT = new ParseField("input"); diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionAction.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionAction.java new file mode 100644 index 000000000000..8d121463fb46 --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionAction.java @@ -0,0 +1,129 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.core.inference.action; + +import org.elasticsearch.action.ActionRequestValidationException; +import org.elasticsearch.action.ActionType; +import org.elasticsearch.common.io.stream.StreamInput; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.xcontent.XContentParser; + +import java.io.IOException; +import java.util.Objects; + +public class UnifiedCompletionAction extends ActionType { + public static final UnifiedCompletionAction INSTANCE = new UnifiedCompletionAction(); + public static final String NAME = "cluster:monitor/xpack/inference/unified"; + + public UnifiedCompletionAction() { + super(NAME); + } + + public static class Request extends BaseInferenceActionRequest { + public static Request parseRequest(String inferenceEntityId, TaskType taskType, TimeValue timeout, XContentParser parser) + throws IOException { + var unifiedRequest = UnifiedCompletionRequest.PARSER.apply(parser, null); + return new Request(inferenceEntityId, taskType, unifiedRequest, timeout); + } + + private final String inferenceEntityId; + private final TaskType taskType; + private final UnifiedCompletionRequest unifiedCompletionRequest; + private final TimeValue timeout; + + public Request(String inferenceEntityId, TaskType taskType, UnifiedCompletionRequest unifiedCompletionRequest, TimeValue timeout) { + this.inferenceEntityId = Objects.requireNonNull(inferenceEntityId); + this.taskType = Objects.requireNonNull(taskType); + this.unifiedCompletionRequest = Objects.requireNonNull(unifiedCompletionRequest); + this.timeout = Objects.requireNonNull(timeout); + } + + public Request(StreamInput in) throws IOException { + super(in); + this.inferenceEntityId = in.readString(); + this.taskType = TaskType.fromStream(in); + this.unifiedCompletionRequest = new UnifiedCompletionRequest(in); + this.timeout = in.readTimeValue(); + } + + public TaskType getTaskType() { + return taskType; + } + + public String getInferenceEntityId() { + return inferenceEntityId; + } + + public UnifiedCompletionRequest getUnifiedCompletionRequest() { + return unifiedCompletionRequest; + } + + /** + * The Unified API only supports streaming so we always return true here. 
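To make that contract concrete, here is a small, hypothetical usage sketch (the class name and the endpoint id are invented for illustration): a well-formed completion request validates cleanly, and `isStreaming()` is unconditionally `true`.

[source,java]
----
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.inference.UnifiedCompletionRequest;
import org.elasticsearch.xpack.core.inference.action.UnifiedCompletionAction;

import java.util.List;

// Hypothetical sketch: build a unified completion request and check its basic contract.
class UnifiedCompletionRequestUsageSketch {
    static UnifiedCompletionAction.Request buildRequest() {
        var message = new UnifiedCompletionRequest.Message(
            new UnifiedCompletionRequest.ContentString("Hello there"), "user", null, null, null);
        var request = new UnifiedCompletionAction.Request(
            "my-openai-endpoint",                     // assumed inference endpoint id
            TaskType.COMPLETION,
            UnifiedCompletionRequest.of(List.of(message)),
            TimeValue.timeValueSeconds(30));
        assert request.validate() == null;            // messages present, task type is completion
        assert request.isStreaming();                 // the unified API always streams
        return request;
    }
}
----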
+ * @return true + */ + public boolean isStreaming() { + return true; + } + + public TimeValue getTimeout() { + return timeout; + } + + @Override + public ActionRequestValidationException validate() { + if (unifiedCompletionRequest == null || unifiedCompletionRequest.messages() == null) { + var e = new ActionRequestValidationException(); + e.addValidationError("Field [messages] cannot be null"); + return e; + } + + if (unifiedCompletionRequest.messages().isEmpty()) { + var e = new ActionRequestValidationException(); + e.addValidationError("Field [messages] cannot be an empty array"); + return e; + } + + if (taskType.isAnyOrSame(TaskType.COMPLETION) == false) { + var e = new ActionRequestValidationException(); + e.addValidationError("Field [taskType] must be [completion]"); + return e; + } + + return null; + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + super.writeTo(out); + out.writeString(inferenceEntityId); + taskType.writeTo(out); + unifiedCompletionRequest.writeTo(out); + out.writeTimeValue(timeout); + } + + @Override + public boolean equals(Object o) { + if (o == null || getClass() != o.getClass()) return false; + Request request = (Request) o; + return Objects.equals(inferenceEntityId, request.inferenceEntityId) + && taskType == request.taskType + && Objects.equals(unifiedCompletionRequest, request.unifiedCompletionRequest) + && Objects.equals(timeout, request.timeout); + } + + @Override + public int hashCode() { + return Objects.hash(inferenceEntityId, taskType, unifiedCompletionRequest, timeout); + } + } + +} diff --git a/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResults.java b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResults.java new file mode 100644 index 000000000000..90038c67036c --- /dev/null +++ b/x-pack/plugin/core/src/main/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResults.java @@ -0,0 +1,329 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.core.inference.results; + +import org.elasticsearch.common.collect.Iterators; +import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.xcontent.ChunkedToXContent; +import org.elasticsearch.common.xcontent.ChunkedToXContentHelper; +import org.elasticsearch.inference.InferenceResults; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.xcontent.ToXContent; + +import java.io.IOException; +import java.util.Collections; +import java.util.Deque; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.concurrent.Flow; + +/** + * Chat Completion results that only contain a Flow.Publisher. 
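Because the only payload is a `java.util.concurrent.Flow.Publisher`, consumers follow the standard JDK reactive-streams handshake: subscribe, request chunks, and react to `onNext` and `onComplete`. The sketch below is a JDK-only illustration of that pattern; it publishes plain strings, whereas the real publisher emits chat completion chunks.

[source,java]
----
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Flow;
import java.util.concurrent.SubmissionPublisher;

// JDK-only illustration of the Flow handshake used by streaming inference results.
class FlowPatternExample {
    public static void main(String[] args) throws InterruptedException {
        CountDownLatch done = new CountDownLatch(1);
        try (SubmissionPublisher<String> publisher = new SubmissionPublisher<>()) {
            publisher.subscribe(new Flow.Subscriber<>() {
                private Flow.Subscription subscription;

                @Override
                public void onSubscribe(Flow.Subscription subscription) {
                    this.subscription = subscription;
                    subscription.request(1); // back-pressure: ask for one chunk at a time
                }

                @Override
                public void onNext(String chunk) {
                    System.out.println("chunk: " + chunk);
                    subscription.request(1);
                }

                @Override
                public void onError(Throwable throwable) {
                    done.countDown();
                }

                @Override
                public void onComplete() {
                    done.countDown(); // closing the publisher completes the stream
                }
            });
            publisher.submit("Hello");
            publisher.submit("world");
        }
        done.await();
    }
}
----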
+ */ +public record StreamingUnifiedChatCompletionResults(Flow.Publisher publisher) + implements + InferenceServiceResults { + + public static final String NAME = "chat_completion_chunk"; + public static final String MODEL_FIELD = "model"; + public static final String OBJECT_FIELD = "object"; + public static final String USAGE_FIELD = "usage"; + public static final String INDEX_FIELD = "index"; + public static final String ID_FIELD = "id"; + public static final String FUNCTION_NAME_FIELD = "name"; + public static final String FUNCTION_ARGUMENTS_FIELD = "arguments"; + public static final String FUNCTION_FIELD = "function"; + public static final String CHOICES_FIELD = "choices"; + public static final String DELTA_FIELD = "delta"; + public static final String CONTENT_FIELD = "content"; + public static final String REFUSAL_FIELD = "refusal"; + public static final String ROLE_FIELD = "role"; + private static final String TOOL_CALLS_FIELD = "tool_calls"; + public static final String FINISH_REASON_FIELD = "finish_reason"; + public static final String COMPLETION_TOKENS_FIELD = "completion_tokens"; + public static final String TOTAL_TOKENS_FIELD = "total_tokens"; + public static final String PROMPT_TOKENS_FIELD = "prompt_tokens"; + public static final String TYPE_FIELD = "type"; + + @Override + public boolean isStreaming() { + return true; + } + + @Override + public List transformToCoordinationFormat() { + throw new UnsupportedOperationException("Not implemented"); + } + + @Override + public List transformToLegacyFormat() { + throw new UnsupportedOperationException("Not implemented"); + } + + @Override + public Map asMap() { + throw new UnsupportedOperationException("Not implemented"); + } + + @Override + public String getWriteableName() { + throw new UnsupportedOperationException("Not implemented"); + } + + @Override + public void writeTo(StreamOutput out) throws IOException { + throw new UnsupportedOperationException("Not implemented"); + } + + @Override + public Iterator toXContentChunked(ToXContent.Params params) { + throw new UnsupportedOperationException("Not implemented"); + } + + public record Results(Deque chunks) implements ChunkedToXContent { + @Override + public Iterator toXContentChunked(ToXContent.Params params) { + return Iterators.concat(Iterators.flatMap(chunks.iterator(), c -> c.toXContentChunked(params))); + } + } + + public static class ChatCompletionChunk implements ChunkedToXContent { + private final String id; + + public String getId() { + return id; + } + + public List getChoices() { + return choices; + } + + public String getModel() { + return model; + } + + public String getObject() { + return object; + } + + public Usage getUsage() { + return usage; + } + + private final List choices; + private final String model; + private final String object; + private final ChatCompletionChunk.Usage usage; + + public ChatCompletionChunk(String id, List choices, String model, String object, ChatCompletionChunk.Usage usage) { + this.id = id; + this.choices = choices; + this.model = model; + this.object = object; + this.usage = usage; + } + + @Override + public Iterator toXContentChunked(ToXContent.Params params) { + + Iterator choicesIterator = Collections.emptyIterator(); + if (choices != null) { + choicesIterator = Iterators.concat( + ChunkedToXContentHelper.startArray(CHOICES_FIELD), + Iterators.flatMap(choices.iterator(), c -> c.toXContentChunked(params)), + ChunkedToXContentHelper.endArray() + ); + } + + Iterator usageIterator = Collections.emptyIterator(); + if (usage != null) { + 
usageIterator = Iterators.concat( + ChunkedToXContentHelper.startObject(USAGE_FIELD), + ChunkedToXContentHelper.field(COMPLETION_TOKENS_FIELD, usage.completionTokens()), + ChunkedToXContentHelper.field(PROMPT_TOKENS_FIELD, usage.promptTokens()), + ChunkedToXContentHelper.field(TOTAL_TOKENS_FIELD, usage.totalTokens()), + ChunkedToXContentHelper.endObject() + ); + } + + return Iterators.concat( + ChunkedToXContentHelper.startObject(), + ChunkedToXContentHelper.field(ID_FIELD, id), + choicesIterator, + ChunkedToXContentHelper.field(MODEL_FIELD, model), + ChunkedToXContentHelper.field(OBJECT_FIELD, object), + usageIterator, + ChunkedToXContentHelper.endObject() + ); + } + + public record Choice(ChatCompletionChunk.Choice.Delta delta, String finishReason, int index) { + + /* + choices: Array<{ + delta: { ... }; + finish_reason: string | null; + index: number; + }>; + */ + public Iterator toXContentChunked(ToXContent.Params params) { + return Iterators.concat( + ChunkedToXContentHelper.startObject(), + delta.toXContentChunked(params), + ChunkedToXContentHelper.optionalField(FINISH_REASON_FIELD, finishReason), + ChunkedToXContentHelper.field(INDEX_FIELD, index), + ChunkedToXContentHelper.endObject() + ); + } + + public static class Delta { + private final String content; + private final String refusal; + private final String role; + private List toolCalls; + + public Delta(String content, String refusal, String role, List toolCalls) { + this.content = content; + this.refusal = refusal; + this.role = role; + this.toolCalls = toolCalls; + } + + /* + delta: { + content?: string | null; + refusal?: string | null; + role?: 'system' | 'user' | 'assistant' | 'tool'; + tool_calls?: Array<{ ... }>; + }; + */ + public Iterator toXContentChunked(ToXContent.Params params) { + var xContent = Iterators.concat( + ChunkedToXContentHelper.startObject(DELTA_FIELD), + ChunkedToXContentHelper.optionalField(CONTENT_FIELD, content), + ChunkedToXContentHelper.optionalField(REFUSAL_FIELD, refusal), + ChunkedToXContentHelper.optionalField(ROLE_FIELD, role) + ); + + if (toolCalls != null && toolCalls.isEmpty() == false) { + xContent = Iterators.concat( + xContent, + ChunkedToXContentHelper.startArray(TOOL_CALLS_FIELD), + Iterators.flatMap(toolCalls.iterator(), t -> t.toXContentChunked(params)), + ChunkedToXContentHelper.endArray() + ); + } + xContent = Iterators.concat(xContent, ChunkedToXContentHelper.endObject()); + return xContent; + + } + + public String getContent() { + return content; + } + + public String getRefusal() { + return refusal; + } + + public String getRole() { + return role; + } + + public List getToolCalls() { + return toolCalls; + } + + public static class ToolCall { + private final int index; + private final String id; + public ChatCompletionChunk.Choice.Delta.ToolCall.Function function; + private final String type; + + public ToolCall(int index, String id, ChatCompletionChunk.Choice.Delta.ToolCall.Function function, String type) { + this.index = index; + this.id = id; + this.function = function; + this.type = type; + } + + public int getIndex() { + return index; + } + + public String getId() { + return id; + } + + public ChatCompletionChunk.Choice.Delta.ToolCall.Function getFunction() { + return function; + } + + public String getType() { + return type; + } + + /* + index: number; + id?: string; + function?: { + arguments?: string; + name?: string; + }; + type?: 'function'; + */ + public Iterator toXContentChunked(ToXContent.Params params) { + var content = Iterators.concat( + 
ChunkedToXContentHelper.startObject(), + ChunkedToXContentHelper.field(INDEX_FIELD, index), + ChunkedToXContentHelper.optionalField(ID_FIELD, id) + ); + + if (function != null) { + content = Iterators.concat( + content, + ChunkedToXContentHelper.startObject(FUNCTION_FIELD), + ChunkedToXContentHelper.optionalField(FUNCTION_ARGUMENTS_FIELD, function.getArguments()), + ChunkedToXContentHelper.optionalField(FUNCTION_NAME_FIELD, function.getName()), + ChunkedToXContentHelper.endObject() + ); + } + + content = Iterators.concat( + content, + ChunkedToXContentHelper.field(TYPE_FIELD, type), + ChunkedToXContentHelper.endObject() + ); + return content; + } + + public static class Function { + private final String arguments; + private final String name; + + public Function(String arguments, String name) { + this.arguments = arguments; + this.name = name; + } + + public String getArguments() { + return arguments; + } + + public String getName() { + return name; + } + } + } + } + } + + public record Usage(int completionTokens, int promptTokens, int totalTokens) {} + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/InferenceActionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/InferenceActionRequestTests.java index a9ca5e6da872..01c0ff88be22 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/InferenceActionRequestTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/InferenceActionRequestTests.java @@ -41,8 +41,7 @@ protected InferenceAction.Request createTestInstance() { return new InferenceAction.Request( randomFrom(TaskType.values()), randomAlphaOfLength(6), - // null, - randomNullOrAlphaOfLength(10), + randomAlphaOfLengthOrNull(10), randomList(1, 5, () -> randomAlphaOfLength(8)), randomMap(0, 3, () -> new Tuple<>(randomAlphaOfLength(4), randomAlphaOfLength(4))), randomFrom(InputType.values()), diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionActionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionActionRequestTests.java new file mode 100644 index 000000000000..1872ac3caa23 --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionActionRequestTests.java @@ -0,0 +1,97 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.core.inference.action; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; + +import java.io.IOException; +import java.util.List; + +import static org.hamcrest.Matchers.is; + +public class UnifiedCompletionActionRequestTests extends AbstractBWCWireSerializationTestCase { + + public void testValidation_ReturnsException_When_UnifiedCompletionRequestMessage_Is_Null() { + var request = new UnifiedCompletionAction.Request( + "inference_id", + TaskType.COMPLETION, + UnifiedCompletionRequest.of(null), + TimeValue.timeValueSeconds(10) + ); + var exception = request.validate(); + assertThat(exception.getMessage(), is("Validation Failed: 1: Field [messages] cannot be null;")); + } + + public void testValidation_ReturnsException_When_UnifiedCompletionRequest_Is_EmptyArray() { + var request = new UnifiedCompletionAction.Request( + "inference_id", + TaskType.COMPLETION, + UnifiedCompletionRequest.of(List.of()), + TimeValue.timeValueSeconds(10) + ); + var exception = request.validate(); + assertThat(exception.getMessage(), is("Validation Failed: 1: Field [messages] cannot be an empty array;")); + } + + public void testValidation_ReturnsException_When_TaskType_IsNot_Completion() { + var request = new UnifiedCompletionAction.Request( + "inference_id", + TaskType.SPARSE_EMBEDDING, + UnifiedCompletionRequest.of(List.of(UnifiedCompletionRequestTests.randomMessage())), + TimeValue.timeValueSeconds(10) + ); + var exception = request.validate(); + assertThat(exception.getMessage(), is("Validation Failed: 1: Field [taskType] must be [completion];")); + } + + public void testValidation_ReturnsNull_When_TaskType_IsAny() { + var request = new UnifiedCompletionAction.Request( + "inference_id", + TaskType.ANY, + UnifiedCompletionRequest.of(List.of(UnifiedCompletionRequestTests.randomMessage())), + TimeValue.timeValueSeconds(10) + ); + assertNull(request.validate()); + } + + @Override + protected UnifiedCompletionAction.Request mutateInstanceForVersion(UnifiedCompletionAction.Request instance, TransportVersion version) { + return instance; + } + + @Override + protected Writeable.Reader instanceReader() { + return UnifiedCompletionAction.Request::new; + } + + @Override + protected UnifiedCompletionAction.Request createTestInstance() { + return new UnifiedCompletionAction.Request( + randomAlphaOfLength(10), + randomFrom(TaskType.values()), + UnifiedCompletionRequestTests.randomUnifiedCompletionRequest(), + TimeValue.timeValueMillis(randomLongBetween(1, 2048)) + ); + } + + @Override + protected UnifiedCompletionAction.Request mutateInstance(UnifiedCompletionAction.Request instance) throws IOException { + return randomValueOtherThan(instance, this::createTestInstance); + } + + @Override + protected NamedWriteableRegistry getNamedWriteableRegistry() { + return new NamedWriteableRegistry(UnifiedCompletionRequest.getNamedWriteables()); + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionRequestTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionRequestTests.java new file mode 100644 index 000000000000..47a0814a584b --- /dev/null 
+++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/action/UnifiedCompletionRequestTests.java @@ -0,0 +1,293 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.core.inference.action; + +import org.elasticsearch.TransportVersion; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.xcontent.json.JsonXContent; +import org.elasticsearch.xpack.core.ml.AbstractBWCWireSerializationTestCase; + +import java.io.IOException; +import java.util.List; +import java.util.Map; + +import static org.hamcrest.Matchers.is; + +public class UnifiedCompletionRequestTests extends AbstractBWCWireSerializationTestCase { + + public void testParseAllFields() throws IOException { + String requestJson = """ + { + "model": "gpt-4o", + "messages": [ + { + "role": "user", + "content": [ + { + "text": "some text", + "type": "string" + } + ], + "name": "a name", + "tool_call_id": "100", + "tool_calls": [ + { + "id": "call_62136354", + "type": "function", + "function": { + "arguments": "{'order_id': 'order_12345'}", + "name": "get_delivery_date" + } + } + ] + } + ], + "max_completion_tokens": 100, + "stop": ["stop"], + "temperature": 0.1, + "tools": [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object" + } + } + } + ], + "tool_choice": { + "type": "function", + "function": { + "name": "some function" + } + }, + "top_p": 0.2 + } + """; + + try (var parser = createParser(JsonXContent.jsonXContent, requestJson)) { + var request = UnifiedCompletionRequest.PARSER.apply(parser, null); + var expected = new UnifiedCompletionRequest( + List.of( + new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentObjects( + List.of(new UnifiedCompletionRequest.ContentObject("some text", "string")) + ), + "user", + "a name", + "100", + List.of( + new UnifiedCompletionRequest.ToolCall( + "call_62136354", + new UnifiedCompletionRequest.ToolCall.FunctionField("{'order_id': 'order_12345'}", "get_delivery_date"), + "function" + ) + ) + ) + ), + "gpt-4o", + 100L, + List.of("stop"), + 0.1F, + new UnifiedCompletionRequest.ToolChoiceObject( + "function", + new UnifiedCompletionRequest.ToolChoiceObject.FunctionField("some function") + ), + List.of( + new UnifiedCompletionRequest.Tool( + "function", + new UnifiedCompletionRequest.Tool.FunctionField( + "Get the current weather in a given location", + "get_current_weather", + Map.of("type", "object"), + null + ) + ) + ), + 0.2F + ); + + assertThat(request, is(expected)); + } + } + + public void testParsing() throws IOException { + String requestJson = """ + { + "model": "gpt-4o", + "messages": [ + { + "role": "user", + "content": "What is the weather like in Boston today?" 
+ } + ], + "stop": "none", + "tools": [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object" + } + } + } + ], + "tool_choice": "auto" + } + """; + + try (var parser = createParser(JsonXContent.jsonXContent, requestJson)) { + var request = UnifiedCompletionRequest.PARSER.apply(parser, null); + var expected = new UnifiedCompletionRequest( + List.of( + new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentString("What is the weather like in Boston today?"), + "user", + null, + null, + null + ) + ), + "gpt-4o", + null, + List.of("none"), + null, + new UnifiedCompletionRequest.ToolChoiceString("auto"), + List.of( + new UnifiedCompletionRequest.Tool( + "function", + new UnifiedCompletionRequest.Tool.FunctionField( + "Get the current weather in a given location", + "get_current_weather", + Map.of("type", "object"), + null + ) + ) + ), + null + ); + + assertThat(request, is(expected)); + } + } + + public static UnifiedCompletionRequest randomUnifiedCompletionRequest() { + return new UnifiedCompletionRequest( + randomList(5, UnifiedCompletionRequestTests::randomMessage), + randomAlphaOfLengthOrNull(10), + randomPositiveLongOrNull(), + randomStopOrNull(), + randomFloatOrNull(), + randomToolChoiceOrNull(), + randomToolListOrNull(), + randomFloatOrNull() + ); + } + + public static UnifiedCompletionRequest.Message randomMessage() { + return new UnifiedCompletionRequest.Message( + randomContent(), + randomAlphaOfLength(10), + randomAlphaOfLengthOrNull(10), + randomAlphaOfLengthOrNull(10), + randomToolCallListOrNull() + ); + } + + public static UnifiedCompletionRequest.Content randomContent() { + return randomBoolean() + ? new UnifiedCompletionRequest.ContentString(randomAlphaOfLength(10)) + : new UnifiedCompletionRequest.ContentObjects(randomList(10, UnifiedCompletionRequestTests::randomContentObject)); + } + + public static UnifiedCompletionRequest.ContentObject randomContentObject() { + return new UnifiedCompletionRequest.ContentObject(randomAlphaOfLength(10), randomAlphaOfLength(10)); + } + + public static List randomToolCallListOrNull() { + return randomBoolean() ? randomList(10, UnifiedCompletionRequestTests::randomToolCall) : null; + } + + public static UnifiedCompletionRequest.ToolCall randomToolCall() { + return new UnifiedCompletionRequest.ToolCall(randomAlphaOfLength(10), randomToolCallFunctionField(), randomAlphaOfLength(10)); + } + + public static UnifiedCompletionRequest.ToolCall.FunctionField randomToolCallFunctionField() { + return new UnifiedCompletionRequest.ToolCall.FunctionField(randomAlphaOfLength(10), randomAlphaOfLength(10)); + } + + public static List randomStopOrNull() { + return randomBoolean() ? randomStop() : null; + } + + public static List randomStop() { + return randomList(5, () -> randomAlphaOfLength(10)); + } + + public static UnifiedCompletionRequest.ToolChoice randomToolChoiceOrNull() { + return randomBoolean() ? randomToolChoice() : null; + } + + public static UnifiedCompletionRequest.ToolChoice randomToolChoice() { + return randomBoolean() + ? 
new UnifiedCompletionRequest.ToolChoiceString(randomAlphaOfLength(10)) + : new UnifiedCompletionRequest.ToolChoiceObject(randomAlphaOfLength(10), randomToolChoiceObjectFunctionField()); + } + + public static UnifiedCompletionRequest.ToolChoiceObject.FunctionField randomToolChoiceObjectFunctionField() { + return new UnifiedCompletionRequest.ToolChoiceObject.FunctionField(randomAlphaOfLength(10)); + } + + public static List randomToolListOrNull() { + return randomBoolean() ? randomList(10, UnifiedCompletionRequestTests::randomTool) : null; + } + + public static UnifiedCompletionRequest.Tool randomTool() { + return new UnifiedCompletionRequest.Tool(randomAlphaOfLength(10), randomToolFunctionField()); + } + + public static UnifiedCompletionRequest.Tool.FunctionField randomToolFunctionField() { + return new UnifiedCompletionRequest.Tool.FunctionField( + randomAlphaOfLengthOrNull(10), + randomAlphaOfLength(10), + null, + randomOptionalBoolean() + ); + } + + @Override + protected UnifiedCompletionRequest mutateInstanceForVersion(UnifiedCompletionRequest instance, TransportVersion version) { + return instance; + } + + @Override + protected Writeable.Reader instanceReader() { + return UnifiedCompletionRequest::new; + } + + @Override + protected UnifiedCompletionRequest createTestInstance() { + return randomUnifiedCompletionRequest(); + } + + @Override + protected UnifiedCompletionRequest mutateInstance(UnifiedCompletionRequest instance) throws IOException { + return randomValueOtherThan(instance, this::createTestInstance); + } + + @Override + protected NamedWriteableRegistry getNamedWriteableRegistry() { + return new NamedWriteableRegistry(UnifiedCompletionRequest.getNamedWriteables()); + } +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResultsTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResultsTests.java new file mode 100644 index 000000000000..a8f569dbef9d --- /dev/null +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/inference/results/StreamingUnifiedChatCompletionResultsTests.java @@ -0,0 +1,198 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ * + * this file was contributed to by a generative AI + */ + +package org.elasticsearch.xpack.core.inference.results; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.json.JsonXContent; + +import java.io.IOException; +import java.util.ArrayDeque; +import java.util.Deque; +import java.util.List; + +public class StreamingUnifiedChatCompletionResultsTests extends ESTestCase { + + public void testResults_toXContentChunked() throws IOException { + String expected = """ + { + "id": "chunk1", + "choices": [ + { + "delta": { + "content": "example_content", + "refusal": "example_refusal", + "role": "assistant", + "tool_calls": [ + { + "index": 1, + "id": "tool1", + "function": { + "arguments": "example_arguments", + "name": "example_function" + }, + "type": "function" + } + ] + }, + "finish_reason": "example_reason", + "index": 0 + } + ], + "model": "example_model", + "object": "example_object", + "usage": { + "completion_tokens": 10, + "prompt_tokens": 5, + "total_tokens": 15 + } + } + """; + + StreamingUnifiedChatCompletionResults.ChatCompletionChunk chunk = new StreamingUnifiedChatCompletionResults.ChatCompletionChunk( + "chunk1", + List.of( + new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice( + new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta( + "example_content", + "example_refusal", + "assistant", + List.of( + new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall( + 1, + "tool1", + new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall.Function( + "example_arguments", + "example_function" + ), + "function" + ) + ) + ), + "example_reason", + 0 + ) + ), + "example_model", + "example_object", + new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Usage(10, 5, 15) + ); + + Deque deque = new ArrayDeque<>(); + deque.add(chunk); + StreamingUnifiedChatCompletionResults.Results results = new StreamingUnifiedChatCompletionResults.Results(deque); + XContentBuilder builder = JsonXContent.contentBuilder(); + results.toXContentChunked(null).forEachRemaining(xContent -> { + try { + xContent.toXContent(builder, null); + } catch (IOException e) { + throw new RuntimeException(e); + } + }); + + assertEquals(expected.replaceAll("\\s+", ""), Strings.toString(builder.prettyPrint()).trim()); + } + + public void testChoiceToXContentChunked() throws IOException { + String expected = """ + { + "delta": { + "content": "example_content", + "refusal": "example_refusal", + "role": "assistant", + "tool_calls": [ + { + "index": 1, + "id": "tool1", + "function": { + "arguments": "example_arguments", + "name": "example_function" + }, + "type": "function" + } + ] + }, + "finish_reason": "example_reason", + "index": 0 + } + """; + + StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice choice = + new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice( + new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta( + "example_content", + "example_refusal", + "assistant", + List.of( + new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall( + 1, + "tool1", + new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall.Function( + "example_arguments", + "example_function" + ), + "function" + ) + ) + ), + "example_reason", + 0 + ); + + XContentBuilder builder = JsonXContent.contentBuilder(); + 
choice.toXContentChunked(null).forEachRemaining(xContent -> { + try { + xContent.toXContent(builder, null); + } catch (IOException e) { + throw new RuntimeException(e); + } + }); + + assertEquals(expected.replaceAll("\\s+", ""), Strings.toString(builder.prettyPrint()).trim()); + } + + public void testToolCallToXContentChunked() throws IOException { + String expected = """ + { + "index": 1, + "id": "tool1", + "function": { + "arguments": "example_arguments", + "name": "example_function" + }, + "type": "function" + } + """; + + StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall toolCall = + new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall( + 1, + "tool1", + new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall.Function( + "example_arguments", + "example_function" + ), + "function" + ); + + XContentBuilder builder = JsonXContent.contentBuilder(); + toolCall.toXContentChunked(null).forEachRemaining(xContent -> { + try { + xContent.toXContent(builder, null); + } catch (IOException e) { + throw new RuntimeException(e); + } + }); + + assertEquals(expected.replaceAll("\\s+", ""), Strings.toString(builder.prettyPrint()).trim()); + } + +} diff --git a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/security/authz/store/ReservedRolesStoreTests.java b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/security/authz/store/ReservedRolesStoreTests.java index 17579fd6368c..eeffa1db5485 100644 --- a/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/security/authz/store/ReservedRolesStoreTests.java +++ b/x-pack/plugin/core/src/test/java/org/elasticsearch/xpack/core/security/authz/store/ReservedRolesStoreTests.java @@ -4175,6 +4175,7 @@ public void testInferenceUserRole() { assertTrue(role.cluster().check("cluster:monitor/xpack/inference", request, authentication)); assertTrue(role.cluster().check("cluster:monitor/xpack/inference/get", request, authentication)); assertFalse(role.cluster().check("cluster:admin/xpack/inference/put", request, authentication)); + assertTrue(role.cluster().check("cluster:monitor/xpack/inference/unified", request, authentication)); assertFalse(role.cluster().check("cluster:admin/xpack/inference/delete", request, authentication)); assertTrue(role.cluster().check("cluster:monitor/xpack/ml/trained_models/deployment/infer", request, authentication)); assertFalse(role.cluster().check("cluster:admin/xpack/ml/trained_models/deployment/start", request, authentication)); diff --git a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/Expressions.java b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/Expressions.java index 8baffbf887e4..4e4338aad370 100644 --- a/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/Expressions.java +++ b/x-pack/plugin/esql-core/src/main/java/org/elasticsearch/xpack/esql/core/expression/Expressions.java @@ -132,8 +132,16 @@ public static String name(Expression e) { return e instanceof NamedExpression ne ? ne.name() : e.sourceText(); } - public static boolean isNull(Expression e) { - return e.dataType() == DataType.NULL || (e.foldable() && e.fold() == null); + /** + * Is this {@linkplain Expression} guaranteed to have + * only the {@code null} value. 
{@linkplain Expression}s that + * {@link Expression#fold()} to {@code null} may + * return {@code false} here, but should eventually be folded + * into a {@link Literal} containing {@code null} which will return + * {@code true} from here. + */ + public static boolean isGuaranteedNull(Expression e) { + return e.dataType() == DataType.NULL || (e instanceof Literal lit && lit.value() == null); } public static List names(Collection e) { diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/EsqlRefCountingListener.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/EsqlRefCountingListener.java new file mode 100644 index 000000000000..69df0fb8ceff --- /dev/null +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/EsqlRefCountingListener.java @@ -0,0 +1,47 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.compute; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.RefCountingRunnable; +import org.elasticsearch.compute.operator.FailureCollector; +import org.elasticsearch.core.Releasable; + +/** + * Similar to {@link org.elasticsearch.action.support.RefCountingListener}, + * but prefers non-task-cancelled exceptions over task-cancelled ones as they are more useful for diagnosing issues. + * @see FailureCollector + */ +public final class EsqlRefCountingListener implements Releasable { + private final FailureCollector failureCollector; + private final RefCountingRunnable refs; + + public EsqlRefCountingListener(ActionListener delegate) { + this.failureCollector = new FailureCollector(); + this.refs = new RefCountingRunnable(() -> { + Exception error = failureCollector.getFailure(); + if (error != null) { + delegate.onFailure(error); + } else { + delegate.onResponse(null); + } + }); + } + + public ActionListener acquire() { + return refs.acquireListener().delegateResponse((l, e) -> { + failureCollector.unwrapAndCollect(e); + l.onFailure(e); + }); + } + + @Override + public void close() { + refs.close(); + } +} diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/FailureCollector.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/FailureCollector.java index 943ba4dc1f4f..337075edbdcf 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/FailureCollector.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/FailureCollector.java @@ -13,9 +13,8 @@ import org.elasticsearch.tasks.TaskCancelledException; import org.elasticsearch.transport.TransportException; -import java.util.List; import java.util.Queue; -import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.Semaphore; /** * {@code FailureCollector} is responsible for collecting exceptions that occur in the compute engine. 
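The `EsqlRefCountingListener` added above follows the familiar try-with-resources pattern of `RefCountingRunnable`. A minimal usage sketch, assuming `startDriver` is a hypothetical async sub-task and `delegate` is the caller's listener; only `acquire()` and `close()` come from the class itself:

[source,java]
--------------------------------------------------
void runDrivers(int drivers, ActionListener<Void> delegate) {
    try (EsqlRefCountingListener refs = new EsqlRefCountingListener(delegate)) {
        for (int i = 0; i < drivers; i++) {
            // each sub-task failure passes through FailureCollector before reaching the delegate
            startDriver(i, refs.acquire());
        }
    }
    // close() releases the initial reference; delegate completes once the last acquired
    // listener finishes, reporting a non-cancellation failure if one was collected
}
--------------------------------------------------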
@@ -26,12 +25,11 @@ */ public final class FailureCollector { private final Queue cancelledExceptions = ConcurrentCollections.newQueue(); - private final AtomicInteger cancelledExceptionsCount = new AtomicInteger(); + private final Semaphore cancelledExceptionsPermits; private final Queue nonCancelledExceptions = ConcurrentCollections.newQueue(); - private final AtomicInteger nonCancelledExceptionsCount = new AtomicInteger(); + private final Semaphore nonCancelledExceptionsPermits; - private final int maxExceptions; private volatile boolean hasFailure = false; private Exception finalFailure = null; @@ -43,7 +41,8 @@ public FailureCollector(int maxExceptions) { if (maxExceptions <= 0) { throw new IllegalArgumentException("maxExceptions must be at least one"); } - this.maxExceptions = maxExceptions; + this.cancelledExceptionsPermits = new Semaphore(maxExceptions); + this.nonCancelledExceptionsPermits = new Semaphore(maxExceptions); } private static Exception unwrapTransportException(TransportException te) { @@ -60,13 +59,12 @@ private static Exception unwrapTransportException(TransportException te) { public void unwrapAndCollect(Exception e) { e = e instanceof TransportException te ? unwrapTransportException(te) : e; if (ExceptionsHelper.unwrap(e, TaskCancelledException.class) != null) { - if (cancelledExceptionsCount.incrementAndGet() <= maxExceptions) { + if (nonCancelledExceptions.isEmpty() && cancelledExceptionsPermits.tryAcquire()) { cancelledExceptions.add(e); } - } else { - if (nonCancelledExceptionsCount.incrementAndGet() <= maxExceptions) { - nonCancelledExceptions.add(e); - } + } else if (nonCancelledExceptionsPermits.tryAcquire()) { + nonCancelledExceptions.add(e); + cancelledExceptions.clear(); } hasFailure = true; } @@ -99,20 +97,22 @@ public Exception getFailure() { private Exception buildFailure() { assert hasFailure; assert Thread.holdsLock(this); - int total = 0; Exception first = null; - for (var exceptions : List.of(nonCancelledExceptions, cancelledExceptions)) { - for (Exception e : exceptions) { - if (first == null) { - first = e; - total++; - } else if (first != e) { - first.addSuppressed(e); - total++; - } - if (total >= maxExceptions) { - return first; - } + for (Exception e : nonCancelledExceptions) { + if (first == null) { + first = e; + } else if (first != e) { + first.addSuppressed(e); + } + } + if (first != null) { + return first; + } + for (Exception e : cancelledExceptions) { + if (first == null) { + first = e; + } else if (first != e) { + first.addSuppressed(e); } } assert first != null; diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/ExchangeService.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/ExchangeService.java index 00c68c4f48e8..62cc4daf5fde 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/ExchangeService.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/ExchangeService.java @@ -14,6 +14,7 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.ActionListenerResponseHandler; import org.elasticsearch.action.support.ChannelActionListener; +import org.elasticsearch.action.support.SubscribableListener; import org.elasticsearch.common.component.AbstractLifecycleComponent; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; @@ -23,6 +24,7 @@ import 
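The reworked `FailureCollector` now prefers non-cancellation failures outright: collecting one clears any queued cancellation exceptions. A minimal sketch of that behavior, with two illustrative exceptions that are not part of the diff:

[source,java]
--------------------------------------------------
FailureCollector collector = new FailureCollector(10);
collector.unwrapAndCollect(new TaskCancelledException("task cancelled"));
collector.unwrapAndCollect(new java.io.IOException("disk full"));
// getFailure() returns the IOException; the cancellation exception was dropped when the
// more informative non-cancelled failure arrived, so callers see the real root cause.
Exception failure = collector.getFailure();
--------------------------------------------------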
org.elasticsearch.common.util.concurrent.EsRejectedExecutionException; import org.elasticsearch.compute.data.BlockFactory; import org.elasticsearch.compute.data.BlockStreamInput; +import org.elasticsearch.compute.data.Page; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.TimeValue; import org.elasticsearch.tasks.CancellableTask; @@ -40,10 +42,11 @@ import java.io.IOException; import java.util.Map; +import java.util.Objects; import java.util.Set; import java.util.concurrent.Executor; -import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicReference; /** * {@link ExchangeService} is responsible for exchanging pages between exchange sinks and sources on the same or different nodes. @@ -293,7 +296,7 @@ static final class TransportRemoteSink implements RemoteSink { final Executor responseExecutor; final AtomicLong estimatedPageSizeInBytes = new AtomicLong(0L); - final AtomicBoolean finished = new AtomicBoolean(false); + final AtomicReference> completionListenerRef = new AtomicReference<>(null); TransportRemoteSink( TransportService transportService, @@ -318,13 +321,14 @@ public void fetchPageAsync(boolean allSourcesFinished, ActionListener completionListener = completionListenerRef.get(); + if (completionListener != null) { + completionListener.addListener(listener.map(unused -> new ExchangeResponse(blockFactory, null, true))); return; } doFetchPageAsync(false, ActionListener.wrap(r -> { if (r.finished()) { - finished.set(true); + completionListenerRef.compareAndSet(null, SubscribableListener.newSucceeded(null)); } listener.onResponse(r); }, e -> close(ActionListener.running(() -> listener.onFailure(e))))); @@ -356,10 +360,19 @@ private void doFetchPageAsync(boolean allSourcesFinished, ActionListener listener) { - if (finished.compareAndSet(false, true)) { - doFetchPageAsync(true, listener.delegateFailure((l, unused) -> l.onResponse(null))); - } else { - listener.onResponse(null); + final SubscribableListener candidate = new SubscribableListener<>(); + final SubscribableListener actual = completionListenerRef.updateAndGet( + curr -> Objects.requireNonNullElse(curr, candidate) + ); + actual.addListener(listener); + if (candidate == actual) { + doFetchPageAsync(true, ActionListener.wrap(r -> { + final Page page = r.takePage(); + if (page != null) { + page.releaseBlocks(); + } + candidate.onResponse(null); + }, e -> candidate.onResponse(null))); } } } diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/ExchangeSourceHandler.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/ExchangeSourceHandler.java index 375016a5d51d..aa722695b841 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/ExchangeSourceHandler.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/ExchangeSourceHandler.java @@ -9,15 +9,18 @@ import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.action.ActionListener; -import org.elasticsearch.action.support.RefCountingListener; +import org.elasticsearch.action.support.RefCountingRunnable; import org.elasticsearch.action.support.SubscribableListener; import org.elasticsearch.common.util.concurrent.AbstractRunnable; +import org.elasticsearch.common.util.concurrent.ConcurrentCollections; +import org.elasticsearch.compute.EsqlRefCountingListener; import org.elasticsearch.compute.data.Page; import 
org.elasticsearch.compute.operator.FailureCollector; import org.elasticsearch.compute.operator.IsBlockedResult; import org.elasticsearch.core.Releasable; import java.util.List; +import java.util.Map; import java.util.concurrent.Executor; import java.util.concurrent.atomic.AtomicInteger; @@ -40,6 +43,9 @@ public final class ExchangeSourceHandler { // The final failure collected will be notified to callers via the {@code completionListener}. private final FailureCollector failure = new FailureCollector(); + private final AtomicInteger nextSinkId = new AtomicInteger(); + private final Map remoteSinks = ConcurrentCollections.newConcurrentMap(); + /** * Creates a new ExchangeSourceHandler. * @@ -52,22 +58,25 @@ public ExchangeSourceHandler(int maxBufferSize, Executor fetchExecutor, ActionLi this.buffer = new ExchangeBuffer(maxBufferSize); this.fetchExecutor = fetchExecutor; this.outstandingSinks = new PendingInstances(() -> buffer.finish(false)); - this.outstandingSources = new PendingInstances(() -> buffer.finish(true)); + final PendingInstances closingSinks = new PendingInstances(() -> {}); + closingSinks.trackNewInstance(); + this.outstandingSources = new PendingInstances(() -> finishEarly(true, ActionListener.running(closingSinks::finishInstance))); buffer.addCompletionListener(ActionListener.running(() -> { - final ActionListener listener = ActionListener.assertAtLeastOnce(completionListener).delegateFailure((l, unused) -> { + final ActionListener listener = ActionListener.assertAtLeastOnce(completionListener); + try (RefCountingRunnable refs = new RefCountingRunnable(() -> { final Exception e = failure.getFailure(); if (e != null) { - l.onFailure(e); + listener.onFailure(e); } else { - l.onResponse(null); + listener.onResponse(null); } - }); - try (RefCountingListener refs = new RefCountingListener(listener)) { + })) { + closingSinks.completion.addListener(refs.acquireListener()); for (PendingInstances pending : List.of(outstandingSinks, outstandingSources)) { // Create an outstanding instance and then finish to complete the completionListener // if we haven't registered any instances of exchange sinks or exchange sources before. 
pending.trackNewInstance(); - pending.completion.addListener(refs.acquire()); + pending.completion.addListener(refs.acquireListener()); pending.finishInstance(); } } @@ -256,7 +265,11 @@ void onSinkComplete() { * @see ExchangeSinkHandler#fetchPageAsync(boolean, ActionListener) */ public void addRemoteSink(RemoteSink remoteSink, boolean failFast, int instances, ActionListener listener) { - final ActionListener sinkListener = ActionListener.assertAtLeastOnce(ActionListener.notifyOnce(listener)); + final int sinkId = nextSinkId.incrementAndGet(); + remoteSinks.put(sinkId, remoteSink); + final ActionListener sinkListener = ActionListener.assertAtLeastOnce( + ActionListener.notifyOnce(ActionListener.runBefore(listener, () -> remoteSinks.remove(sinkId))) + ); fetchExecutor.execute(new AbstractRunnable() { @Override public void onFailure(Exception e) { @@ -269,7 +282,7 @@ public void onFailure(Exception e) { @Override protected void doRun() { - try (RefCountingListener refs = new RefCountingListener(sinkListener)) { + try (EsqlRefCountingListener refs = new EsqlRefCountingListener(sinkListener)) { for (int i = 0; i < instances; i++) { var fetcher = new RemoteSinkFetcher(remoteSink, failFast, refs.acquire()); fetcher.fetchPage(); @@ -290,6 +303,22 @@ public Releasable addEmptySink() { return outstandingSinks::finishInstance; } + /** + * Gracefully terminates the exchange source early by instructing all remote exchange sinks to stop their computations. + * This can happen when the exchange source has accumulated enough data (e.g., reaching the LIMIT) or when users want to + * see the current result immediately. + * + * @param drainingPages whether to discard pages already fetched in the exchange + */ + public void finishEarly(boolean drainingPages, ActionListener listener) { + buffer.finish(drainingPages); + try (EsqlRefCountingListener refs = new EsqlRefCountingListener(listener)) { + for (RemoteSink remoteSink : remoteSinks.values()) { + remoteSink.close(refs.acquire()); + } + } + } + private static class PendingInstances { private final AtomicInteger instances = new AtomicInteger(); private final SubscribableListener completion = new SubscribableListener<>(); diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/RemoteSink.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/RemoteSink.java index aaa937ef17c0..63b5d324ce85 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/RemoteSink.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/RemoteSink.java @@ -8,6 +8,7 @@ package org.elasticsearch.compute.operator.exchange; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.compute.data.Page; public interface RemoteSink { @@ -15,11 +16,11 @@ public interface RemoteSink { default void close(ActionListener listener) { fetchPageAsync(true, listener.delegateFailure((l, r) -> { - try { - r.close(); - } finally { - l.onResponse(null); + final Page page = r.takePage(); + if (page != null) { + page.releaseBlocks(); } + l.onResponse(null); })); } } diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/FailureCollectorTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/FailureCollectorTests.java index 637cbe8892b3..5fec82b32dda 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/FailureCollectorTests.java +++ 
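`finishEarly` gives the coordinator a graceful way to stop remote computation once it has what it needs. A minimal sketch, assuming a hypothetical row/LIMIT check; `finishEarly` and `ActionListener.noop()` are used exactly as in the new `testFinishEarly` below:

[source,java]
--------------------------------------------------
if (rowsCollected >= limit) {
    // true: also discard pages already buffered in the exchange; every registered
    // remote sink is asked to stop, and the listener fires once all have acknowledged
    sourceHandler.finishEarly(true, ActionListener.noop());
}
--------------------------------------------------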
b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/FailureCollectorTests.java @@ -7,6 +7,7 @@ package org.elasticsearch.compute.operator; +import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.cluster.node.DiscoveryNodeUtils; import org.elasticsearch.common.Randomness; import org.elasticsearch.common.breaker.CircuitBreaker; @@ -86,6 +87,14 @@ public void testCollect() throws Exception { assertNotNull(failure); assertThat(failure, Matchers.in(nonCancelledExceptions)); assertThat(failure.getSuppressed().length, lessThan(maxExceptions)); + assertTrue( + "cancellation exceptions must be ignored", + ExceptionsHelper.unwrapCausesAndSuppressed(failure, t -> t instanceof TaskCancelledException).isEmpty() + ); + assertTrue( + "remote transport exception must be unwrapped", + ExceptionsHelper.unwrapCausesAndSuppressed(failure, t -> t instanceof TransportException).isEmpty() + ); } public void testEmpty() { diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/exchange/ExchangeServiceTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/exchange/ExchangeServiceTests.java index fc6c850ba187..8f7532b582bc 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/exchange/ExchangeServiceTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/exchange/ExchangeServiceTests.java @@ -55,7 +55,9 @@ import java.util.ArrayList; import java.util.Collections; import java.util.List; +import java.util.Queue; import java.util.Set; +import java.util.concurrent.Semaphore; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicInteger; @@ -421,7 +423,7 @@ public void testExchangeSourceContinueOnFailure() { } } - public void testEarlyTerminate() { + public void testClosingSinks() { BlockFactory blockFactory = blockFactory(); IntBlock block1 = blockFactory.newConstantIntBlockWith(1, 2); IntBlock block2 = blockFactory.newConstantIntBlockWith(1, 2); @@ -441,6 +443,57 @@ public void testEarlyTerminate() { assertTrue(sink.isFinished()); } + public void testFinishEarly() throws Exception { + ExchangeSourceHandler sourceHandler = new ExchangeSourceHandler(20, threadPool.generic(), ActionListener.noop()); + Semaphore permits = new Semaphore(between(1, 5)); + BlockFactory blockFactory = blockFactory(); + Queue pages = ConcurrentCollections.newQueue(); + ExchangeSource exchangeSource = sourceHandler.createExchangeSource(); + AtomicBoolean sinkClosed = new AtomicBoolean(); + PlainActionFuture sinkCompleted = new PlainActionFuture<>(); + sourceHandler.addRemoteSink((allSourcesFinished, listener) -> { + if (allSourcesFinished) { + sinkClosed.set(true); + permits.release(10); + listener.onResponse(new ExchangeResponse(blockFactory, null, sinkClosed.get())); + } else { + try { + if (permits.tryAcquire(between(0, 100), TimeUnit.MICROSECONDS)) { + boolean closed = sinkClosed.get(); + final Page page; + if (closed) { + page = new Page(blockFactory.newConstantIntBlockWith(1, 1)); + pages.add(page); + } else { + page = null; + } + listener.onResponse(new ExchangeResponse(blockFactory, page, closed)); + } else { + listener.onResponse(new ExchangeResponse(blockFactory, null, sinkClosed.get())); + } + } catch (Exception e) { + throw new AssertionError(e); + } + } + }, false, between(1, 3), sinkCompleted); + threadPool.schedule( + () -> sourceHandler.finishEarly(randomBoolean(), 
ActionListener.noop()), + TimeValue.timeValueMillis(between(0, 10)), + threadPool.generic() + ); + sinkCompleted.actionGet(); + Page p; + while ((p = exchangeSource.pollPage()) != null) { + assertSame(p, pages.poll()); + p.releaseBlocks(); + } + while ((p = pages.poll()) != null) { + p.releaseBlocks(); + } + assertTrue(exchangeSource.isFinished()); + exchangeSource.finish(); + } + public void testConcurrentWithTransportActions() { MockTransportService node0 = newTransportService(); ExchangeService exchange0 = new ExchangeService(Settings.EMPTY, threadPool, ESQL_TEST_EXECUTOR, blockFactory()); diff --git a/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/EsqlTestUtils.java b/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/EsqlTestUtils.java index ec9af33dd669..5535e801b1b0 100644 --- a/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/EsqlTestUtils.java +++ b/x-pack/plugin/esql/qa/testFixtures/src/main/java/org/elasticsearch/xpack/esql/EsqlTestUtils.java @@ -10,6 +10,7 @@ import org.apache.lucene.document.InetAddressPoint; import org.apache.lucene.sandbox.document.HalfFloatPoint; import org.apache.lucene.util.BytesRef; +import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.common.Strings; import org.elasticsearch.common.breaker.CircuitBreaker; import org.elasticsearch.common.breaker.NoopCircuitBreaker; @@ -30,7 +31,9 @@ import org.elasticsearch.geo.ShapeTestUtils; import org.elasticsearch.index.IndexMode; import org.elasticsearch.license.XPackLicenseState; +import org.elasticsearch.tasks.TaskCancelledException; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.transport.RemoteTransportException; import org.elasticsearch.xcontent.json.JsonXContent; import org.elasticsearch.xpack.esql.action.EsqlQueryResponse; import org.elasticsearch.xpack.esql.analysis.EnrichResolution; @@ -129,6 +132,8 @@ import static org.elasticsearch.xpack.esql.parser.ParserUtils.ParamClassification.PATTERN; import static org.elasticsearch.xpack.esql.parser.ParserUtils.ParamClassification.VALUE; import static org.hamcrest.Matchers.instanceOf; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; public final class EsqlTestUtils { @@ -784,4 +789,17 @@ public static QueryParam paramAsIdentifier(String name, Object value) { public static QueryParam paramAsPattern(String name, Object value) { return new QueryParam(name, value, NULL, PATTERN); } + + /** + * Asserts that: + * 1. Cancellation exceptions are ignored when more relevant exceptions exist. + * 2. Transport exceptions are unwrapped, and the actual causes are reported to users. 
+ */ + public static void assertEsqlFailure(Exception e) { + assertNotNull(e); + var cancellationFailure = ExceptionsHelper.unwrapCausesAndSuppressed(e, t -> t instanceof TaskCancelledException).orElse(null); + assertNull("cancellation exceptions must be ignored", cancellationFailure); + ExceptionsHelper.unwrapCausesAndSuppressed(e, t -> t instanceof RemoteTransportException) + .ifPresent(transportFailure -> assertNull("remote transport exception must be unwrapped", transportFailure.getCause())); + } } diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EnrichIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EnrichIT.java index dab99a0f719d..c4da0bf32ef9 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EnrichIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EnrichIT.java @@ -143,6 +143,7 @@ protected EsqlQueryResponse run(EsqlQueryRequest request) { return client.execute(EsqlQueryAction.INSTANCE, request).actionGet(2, TimeUnit.MINUTES); } catch (Exception e) { logger.info("request failed", e); + EsqlTestUtils.assertEsqlFailure(e); ensureBlocksReleased(); } finally { setRequestCircuitBreakerLimit(null); diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EsqlActionBreakerIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EsqlActionBreakerIT.java index 37833d8aed2d..ec7ee8b61c2d 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EsqlActionBreakerIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EsqlActionBreakerIT.java @@ -23,6 +23,7 @@ import org.elasticsearch.plugins.Plugin; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.test.junit.annotations.TestLogging; +import org.elasticsearch.xpack.esql.EsqlTestUtils; import java.util.ArrayList; import java.util.Collection; @@ -85,6 +86,7 @@ private EsqlQueryResponse runWithBreaking(EsqlQueryRequest request) throws Circu } catch (Exception e) { logger.info("request failed", e); ensureBlocksReleased(); + EsqlTestUtils.assertEsqlFailure(e); throw e; } finally { setRequestCircuitBreakerLimit(null); diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EsqlActionTaskIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EsqlActionTaskIT.java index 1939f81353c0..abd4f6b49d7b 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EsqlActionTaskIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EsqlActionTaskIT.java @@ -36,6 +36,7 @@ import org.elasticsearch.test.transport.MockTransportService; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.esql.EsqlTestUtils; import org.elasticsearch.xpack.esql.plugin.QueryPragmas; import org.junit.Before; @@ -338,7 +339,15 @@ private void assertCancelled(ActionFuture response) throws Ex */ assertThat( cancelException.getMessage(), - in(List.of("test cancel", "task cancelled", "request cancelled test cancel", "parent task was cancelled [test cancel]")) + in( + List.of( + "test cancel", + "task cancelled", + "request cancelled test cancel", + "parent task was cancelled [test cancel]", + 
"cancelled on failure" + ) + ) ); assertBusy( () -> assertThat( @@ -434,6 +443,7 @@ protected void doRun() throws Exception { allowedFetching.countDown(); } Exception failure = expectThrows(Exception.class, () -> future.actionGet().close()); + EsqlTestUtils.assertEsqlFailure(failure); assertThat(failure.getMessage(), containsString("failed to fetch pages")); // If we proceed without waiting for pages, we might cancel the main request before starting the data-node request. // As a result, the exchange sinks on data-nodes won't be removed until the inactive_timeout elapses, which is diff --git a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EsqlDisruptionIT.java b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EsqlDisruptionIT.java index e9eada5def0d..72a60a6b6b92 100644 --- a/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EsqlDisruptionIT.java +++ b/x-pack/plugin/esql/src/internalClusterTest/java/org/elasticsearch/xpack/esql/action/EsqlDisruptionIT.java @@ -23,6 +23,7 @@ import org.elasticsearch.test.disruption.ServiceDisruptionScheme; import org.elasticsearch.test.transport.MockTransportService; import org.elasticsearch.transport.TransportSettings; +import org.elasticsearch.xpack.esql.EsqlTestUtils; import java.util.ArrayList; import java.util.Collection; @@ -111,6 +112,7 @@ private EsqlQueryResponse runQueryWithDisruption(EsqlQueryRequest request) { assertTrue("request must be failed or completed after clearing disruption", future.isDone()); ensureBlocksReleased(); logger.info("--> failed to execute esql query with disruption; retrying...", e); + EsqlTestUtils.assertEsqlFailure(e); return client().execute(EsqlQueryAction.INSTANCE, request).actionGet(2, TimeUnit.MINUTES); } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/In.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/In.java index eda6aadccc86..f6c23304c189 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/In.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/expression/predicate/operator/comparison/In.java @@ -151,14 +151,14 @@ public Expression replaceChildren(List newChildren) { public boolean foldable() { // QL's In fold()s to null, if value() is null, but isn't foldable() unless all children are // TODO: update this null check in QL too? 
- return Expressions.isNull(value) + return Expressions.isGuaranteedNull(value) || Expressions.foldable(children()) - || (Expressions.foldable(list) && list.stream().allMatch(Expressions::isNull)); + || (Expressions.foldable(list) && list.stream().allMatch(Expressions::isGuaranteedNull)); } @Override public Object fold() { - if (Expressions.isNull(value) || list.stream().allMatch(Expressions::isNull)) { + if (Expressions.isGuaranteedNull(value) || list.stream().allMatch(Expressions::isGuaranteedNull)) { return null; } return super.fold(); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/FoldNull.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/FoldNull.java index 638fa1b8db45..4f97bf60bd86 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/FoldNull.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/FoldNull.java @@ -30,7 +30,7 @@ public Expression rule(Expression e) { // perform this early to prevent the rule from converting the null filter into nullifying the whole expression // P.S. this could be done inside the Aggregate but this place better centralizes the logic if (e instanceof AggregateFunction agg) { - if (Expressions.isNull(agg.filter())) { + if (Expressions.isGuaranteedNull(agg.filter())) { return agg.withFilter(Literal.of(agg.filter(), false)); } } @@ -38,13 +38,13 @@ public Expression rule(Expression e) { if (result != e) { return result; } else if (e instanceof In in) { - if (Expressions.isNull(in.value())) { + if (Expressions.isGuaranteedNull(in.value())) { return Literal.of(in, null); } } else if (e instanceof Alias == false && e.nullable() == Nullability.TRUE && e instanceof Categorize == false - && Expressions.anyMatch(e.children(), Expressions::isNull)) { + && Expressions.anyMatch(e.children(), Expressions::isGuaranteedNull)) { return Literal.of(e, null); } return e; diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PruneFilters.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PruneFilters.java index b6f7ac9e464f..00698d009ea2 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PruneFilters.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/PruneFilters.java @@ -29,7 +29,7 @@ protected LogicalPlan rule(Filter filter) { if (TRUE.equals(condition)) { return filter.child(); } - if (FALSE.equals(condition) || Expressions.isNull(condition)) { + if (FALSE.equals(condition) || Expressions.isGuaranteedNull(condition)) { return PruneEmptyPlans.skipPlan(filter); } } @@ -42,8 +42,8 @@ protected LogicalPlan rule(Filter filter) { private static Expression foldBinaryLogic(BinaryLogic binaryLogic) { if (binaryLogic instanceof Or or) { - boolean nullLeft = Expressions.isNull(or.left()); - boolean nullRight = Expressions.isNull(or.right()); + boolean nullLeft = Expressions.isGuaranteedNull(or.left()); + boolean nullRight = Expressions.isGuaranteedNull(or.right()); if (nullLeft && nullRight) { return new Literal(binaryLogic.source(), null, DataType.NULL); } @@ -55,7 +55,7 @@ private static Expression foldBinaryLogic(BinaryLogic binaryLogic) { } } if (binaryLogic instanceof And and) { - if (Expressions.isNull(and.left()) || Expressions.isNull(and.right())) { + if (Expressions.isGuaranteedNull(and.left()) || 
Expressions.isGuaranteedNull(and.right())) { return new Literal(binaryLogic.source(), null, DataType.NULL); } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/SplitInWithFoldableValue.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/SplitInWithFoldableValue.java index 930b485dbd37..9e9ae6a9a559 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/SplitInWithFoldableValue.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/optimizer/rules/logical/SplitInWithFoldableValue.java @@ -30,7 +30,7 @@ public Expression rule(In in) { List foldables = new ArrayList<>(in.list().size()); List nonFoldables = new ArrayList<>(in.list().size()); in.list().forEach(e -> { - if (e.foldable() && Expressions.isNull(e) == false) { // keep `null`s, needed for the 3VL + if (e.foldable() && Expressions.isGuaranteedNull(e) == false) { // keep `null`s, needed for the 3VL foldables.add(e); } else { nonFoldables.add(e); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeListener.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeListener.java index 8d041ffbdf0e..8bd23230fcde 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeListener.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeListener.java @@ -9,8 +9,8 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.RefCountingListener; +import org.elasticsearch.compute.EsqlRefCountingListener; import org.elasticsearch.compute.operator.DriverProfile; -import org.elasticsearch.compute.operator.FailureCollector; import org.elasticsearch.compute.operator.ResponseHeadersCollector; import org.elasticsearch.core.Nullable; import org.elasticsearch.core.Releasable; @@ -39,8 +39,7 @@ final class ComputeListener implements Releasable { private static final Logger LOGGER = LogManager.getLogger(ComputeService.class); - private final RefCountingListener refs; - private final FailureCollector failureCollector = new FailureCollector(); + private final EsqlRefCountingListener refs; private final AtomicBoolean cancelled = new AtomicBoolean(); private final CancellableTask task; private final TransportService transportService; @@ -105,7 +104,7 @@ private ComputeListener( : "clusterAlias and executionInfo must both be null or both non-null"; // listener that executes after all the sub-listeners refs (created via acquireCompute) have completed - this.refs = new RefCountingListener(1, ActionListener.wrap(ignored -> { + this.refs = new EsqlRefCountingListener(delegate.delegateFailure((l, ignored) -> { responseHeaders.finish(); ComputeResponse result; @@ -131,7 +130,7 @@ private ComputeListener( } } delegate.onResponse(result); - }, e -> delegate.onFailure(failureCollector.getFailure()))); + })); } private static void setFinalStatusAndShardCounts(String clusterAlias, EsqlExecutionInfo executionInfo) { @@ -191,7 +190,6 @@ private boolean isCCSListener(String computeClusterAlias) { */ ActionListener acquireAvoid() { return refs.acquire().delegateResponse((l, e) -> { - failureCollector.unwrapAndCollect(e); try { if (cancelled.compareAndSet(false, true)) { LOGGER.debug("cancelling ESQL task {} on failure", task); diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeService.java 
b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeService.java index ed037d24139f..9b59b98a7cdc 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeService.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeService.java @@ -16,11 +16,11 @@ import org.elasticsearch.action.search.SearchShardsRequest; import org.elasticsearch.action.search.SearchShardsResponse; import org.elasticsearch.action.support.ChannelActionListener; -import org.elasticsearch.action.support.RefCountingListener; import org.elasticsearch.action.support.RefCountingRunnable; import org.elasticsearch.cluster.node.DiscoveryNode; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.util.BigArrays; +import org.elasticsearch.compute.EsqlRefCountingListener; import org.elasticsearch.compute.data.BlockFactory; import org.elasticsearch.compute.data.Page; import org.elasticsearch.compute.operator.Driver; @@ -375,7 +375,7 @@ private void startComputeOnDataNodes( var lookupListener = ActionListener.releaseAfter(computeListener.acquireAvoid(), exchangeSource.addEmptySink()); // SearchShards API can_match is done in lookupDataNodes lookupDataNodes(parentTask, clusterAlias, requestFilter, concreteIndices, originalIndices, ActionListener.wrap(dataNodeResult -> { - try (RefCountingListener refs = new RefCountingListener(lookupListener)) { + try (EsqlRefCountingListener refs = new EsqlRefCountingListener(lookupListener)) { // update ExecutionInfo with shard counts (total and skipped) executionInfo.swapCluster( clusterAlias, @@ -436,7 +436,7 @@ private void startComputeOnRemoteClusters( ) { var queryPragmas = configuration.pragmas(); var linkExchangeListeners = ActionListener.releaseAfter(computeListener.acquireAvoid(), exchangeSource.addEmptySink()); - try (RefCountingListener refs = new RefCountingListener(linkExchangeListeners)) { + try (EsqlRefCountingListener refs = new EsqlRefCountingListener(linkExchangeListeners)) { for (RemoteCluster cluster : clusters) { final var childSessionId = newChildSession(sessionId); ExchangeService.openExchange( diff --git a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java index b76781f76f4a..c2a26845d4e8 100644 --- a/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java +++ b/x-pack/plugin/esql/src/test/java/org/elasticsearch/xpack/esql/optimizer/LogicalPlanOptimizerTests.java @@ -4820,7 +4820,7 @@ private static boolean oneLeaveIsNull(Expression e) { e.forEachUp(node -> { if (node.children().size() == 0) { - result.set(result.get() || Expressions.isNull(node)); + result.set(result.get() || Expressions.isGuaranteedNull(node)); } }); diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java index 86c0128a3e53..1716057cdfe4 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceBaseRestTest.java @@ -21,6 +21,9 @@ import 
org.elasticsearch.test.cluster.ElasticsearchCluster; import org.elasticsearch.test.cluster.local.distribution.DistributionType; import org.elasticsearch.test.rest.ESRestTestCase; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent; import org.junit.ClassRule; @@ -341,10 +344,21 @@ protected Deque streamInferOnMockService(String modelId, TaskTy return callAsync(endpoint, input); } + protected Deque unifiedCompletionInferOnMockService(String modelId, TaskType taskType, List input) + throws Exception { + var endpoint = Strings.format("_inference/%s/%s/_unified", taskType, modelId); + return callAsyncUnified(endpoint, input, "user"); + } + private Deque callAsync(String endpoint, List input) throws Exception { - var responseConsumer = new AsyncInferenceResponseConsumer(); var request = new Request("POST", endpoint); request.setJsonEntity(jsonBody(input, null)); + + return execAsyncCall(request); + } + + private Deque execAsyncCall(Request request) throws Exception { + var responseConsumer = new AsyncInferenceResponseConsumer(); request.setOptions(RequestOptions.DEFAULT.toBuilder().setHttpAsyncResponseConsumerFactory(() -> responseConsumer).build()); var latch = new CountDownLatch(1); client().performRequestAsync(request, new ResponseListener() { @@ -362,6 +376,22 @@ public void onFailure(Exception exception) { return responseConsumer.events(); } + private Deque callAsyncUnified(String endpoint, List input, String role) throws Exception { + var request = new Request("POST", endpoint); + + request.setJsonEntity(createUnifiedJsonBody(input, role)); + return execAsyncCall(request); + } + + private String createUnifiedJsonBody(List input, String role) throws IOException { + var messages = input.stream().map(i -> Map.of("content", i, "role", role)).toList(); + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + builder.startObject(); + builder.field("messages", messages); + builder.endObject(); + return org.elasticsearch.common.Strings.toString(builder); + } + protected Map infer(String modelId, TaskType taskType, List input) throws IOException { var endpoint = Strings.format("_inference/%s/%s", taskType, modelId); return inferInternal(endpoint, input, null, Map.of()); diff --git a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java index 604e1d4f553b..2099ec8287a7 100644 --- a/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java +++ b/x-pack/plugin/inference/qa/inference-service-tests/src/javaRestTest/java/org/elasticsearch/xpack/inference/InferenceCrudIT.java @@ -11,13 +11,18 @@ import org.apache.http.util.EntityUtils; import org.elasticsearch.client.ResponseException; +import org.elasticsearch.common.Strings; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.inference.services.elastic.ElasticInferenceServiceFeature; import java.io.IOException; import java.util.ArrayList; import 
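For reference, the `createUnifiedJsonBody` helper above turns each input string into a single chat message. A minimal sketch of what it produces for one input; field order within a message may differ because `Map.of` is unordered:

[source,java]
--------------------------------------------------
String body = createUnifiedJsonBody(List.of("Hello there"), "user");
// body is equivalent to {"messages":[{"content":"Hello there","role":"user"}]}
// and is POSTed to _inference/completion/<inference_id>/_unified by unifiedCompletionInferOnMockService.
--------------------------------------------------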
java.util.Arrays; +import java.util.Iterator; import java.util.List; import java.util.Map; import java.util.Objects; @@ -481,6 +486,56 @@ public void testSupportedStream() throws Exception { } } + public void testUnifiedCompletionInference() throws Exception { + String modelId = "streaming"; + putModel(modelId, mockCompletionServiceModelConfig(TaskType.COMPLETION)); + var singleModel = getModel(modelId); + assertEquals(modelId, singleModel.get("inference_id")); + assertEquals(TaskType.COMPLETION.toString(), singleModel.get("task_type")); + + var input = IntStream.range(1, 2 + randomInt(8)).mapToObj(i -> randomUUID()).toList(); + try { + var events = unifiedCompletionInferOnMockService(modelId, TaskType.COMPLETION, input); + var expectedResponses = expectedResultsIterator(input); + assertThat(events.size(), equalTo((input.size() + 1) * 2)); + events.forEach(event -> { + switch (event.name()) { + case EVENT -> assertThat(event.value(), equalToIgnoringCase("message")); + case DATA -> assertThat(event.value(), equalTo(expectedResponses.next())); + } + }); + } finally { + deleteModel(modelId); + } + } + + private static Iterator expectedResultsIterator(List input) { + return Stream.concat(input.stream().map(String::toUpperCase).map(InferenceCrudIT::expectedResult), Stream.of("[DONE]")).iterator(); + } + + private static String expectedResult(String input) { + try { + XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); + builder.startObject(); + builder.field("id", "id"); + builder.startArray("choices"); + builder.startObject(); + builder.startObject("delta"); + builder.field("content", input); + builder.endObject(); + builder.field("index", 0); + builder.endObject(); + builder.endArray(); + builder.field("model", "gpt-4o-2024-08-06"); + builder.field("object", "chat.completion.chunk"); + builder.endObject(); + + return Strings.toString(builder); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + public void testGetZeroModels() throws IOException { var models = getModels("_all", TaskType.COMPLETION); assertThat(models, empty()); diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java index ae11a02d312e..f5f682b143a7 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java @@ -31,6 +31,7 @@ import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskSettingsConfiguration; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnifiedCompletionRequest; import org.elasticsearch.inference.configuration.SettingsConfigurationDisplayType; import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; import org.elasticsearch.rest.RestStatus; @@ -132,6 +133,16 @@ public void infer( } } + @Override + public void unifiedCompletionInfer( + Model model, + UnifiedCompletionRequest request, + TimeValue timeout, + ActionListener listener + ) { + listener.onFailure(new UnsupportedOperationException("unifiedCompletionInfer not supported")); + } + @Override public void chunkedInfer( Model model, diff --git 
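Taken together, the new helpers above exercise the `_unified` route against the mock streaming completion service. For a single illustrative message (the inference id `streaming` comes from the test; the message text is a placeholder), the request and the server-sent events the test expects look roughly as follows. The mock echoes each message back upper-cased and always reports the model as `gpt-4o-2024-08-06`.

[source,console]
--------------------------------------------------
POST _inference/completion/streaming/_unified
{
  "messages": [
    {
      "role": "user",
      "content": "hello"
    }
  ]
}
--------------------------------------------------

[source,txt]
--------------------------------------------------
event: message
data: {"id":"id","choices":[{"delta":{"content":"HELLO"},"index":0}],"model":"gpt-4o-2024-08-06","object":"chat.completion.chunk"}

event: message
data: [DONE]
--------------------------------------------------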
a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java index 9320571572f0..fa1e27005c28 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java @@ -29,6 +29,7 @@ import org.elasticsearch.inference.SettingsConfiguration; import org.elasticsearch.inference.TaskSettingsConfiguration; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnifiedCompletionRequest; import org.elasticsearch.inference.configuration.SettingsConfigurationDisplayType; import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; import org.elasticsearch.rest.RestStatus; @@ -120,6 +121,16 @@ public void infer( } } + @Override + public void unifiedCompletionInfer( + Model model, + UnifiedCompletionRequest request, + TimeValue timeout, + ActionListener listener + ) { + listener.onFailure(new UnsupportedOperationException("unifiedCompletionInfer not supported")); + } + @Override public void chunkedInfer( Model model, diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java index fe0223cce032..64569fd8c5c6 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java @@ -29,6 +29,7 @@ import org.elasticsearch.inference.SettingsConfiguration; import org.elasticsearch.inference.TaskSettingsConfiguration; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnifiedCompletionRequest; import org.elasticsearch.inference.configuration.SettingsConfigurationDisplayType; import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; import org.elasticsearch.rest.RestStatus; @@ -123,6 +124,16 @@ public void infer( } } + @Override + public void unifiedCompletionInfer( + Model model, + UnifiedCompletionRequest request, + TimeValue timeout, + ActionListener listener + ) { + throw new UnsupportedOperationException("unifiedCompletionInfer not supported"); + } + @Override public void chunkedInfer( Model model, diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java index 6d7983bc8cb5..f7a05a27354e 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java @@ -30,12 +30,14 @@ import org.elasticsearch.inference.SettingsConfiguration; import 
org.elasticsearch.inference.TaskSettingsConfiguration; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnifiedCompletionRequest; import org.elasticsearch.inference.configuration.SettingsConfigurationDisplayType; import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults; +import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults; import java.io.IOException; import java.util.EnumSet; @@ -121,6 +123,24 @@ public void infer( } } + @Override + public void unifiedCompletionInfer( + Model model, + UnifiedCompletionRequest request, + TimeValue timeout, + ActionListener listener + ) { + switch (model.getConfigurations().getTaskType()) { + case COMPLETION -> listener.onResponse(makeUnifiedResults(request)); + default -> listener.onFailure( + new ElasticsearchStatusException( + TaskType.unsupportedTaskTypeErrorMsg(model.getConfigurations().getTaskType(), name()), + RestStatus.BAD_REQUEST + ) + ); + } + } + private StreamingChatCompletionResults makeResults(List input) { var responseIter = input.stream().map(String::toUpperCase).iterator(); return new StreamingChatCompletionResults(subscriber -> { @@ -152,6 +172,59 @@ private ChunkedToXContent completionChunk(String delta) { ); } + private StreamingUnifiedChatCompletionResults makeUnifiedResults(UnifiedCompletionRequest request) { + var responseIter = request.messages().stream().map(message -> message.content().toString().toUpperCase()).iterator(); + return new StreamingUnifiedChatCompletionResults(subscriber -> { + subscriber.onSubscribe(new Flow.Subscription() { + @Override + public void request(long n) { + if (responseIter.hasNext()) { + subscriber.onNext(unifiedCompletionChunk(responseIter.next())); + } else { + subscriber.onComplete(); + } + } + + @Override + public void cancel() {} + }); + }); + } + + /* + The response format looks like this + { + "id": "chatcmpl-AarrzyuRflye7yzDF4lmVnenGmQCF", + "choices": [ + { + "delta": { + "content": " information" + }, + "index": 0 + } + ], + "model": "gpt-4o-2024-08-06", + "object": "chat.completion.chunk" + } + */ + private ChunkedToXContent unifiedCompletionChunk(String delta) { + return params -> Iterators.concat( + ChunkedToXContentHelper.startObject(), + ChunkedToXContentHelper.field("id", "id"), + ChunkedToXContentHelper.startArray("choices"), + ChunkedToXContentHelper.startObject(), + ChunkedToXContentHelper.startObject("delta"), + ChunkedToXContentHelper.field("content", delta), + ChunkedToXContentHelper.endObject(), + ChunkedToXContentHelper.field("index", 0), + ChunkedToXContentHelper.endObject(), + ChunkedToXContentHelper.endArray(), + ChunkedToXContentHelper.field("model", "gpt-4o-2024-08-06"), + ChunkedToXContentHelper.field("object", "chat.completion.chunk"), + ChunkedToXContentHelper.endObject() + ); + } + @Override public void chunkedInfer( Model model, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java index c82f287792a7..67892dfe7862 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java +++ 
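The publisher built in `makeUnifiedResults` follows a simple pull model: each `request(n)` call emits at most one `chat.completion.chunk`, and `onComplete` fires once the message iterator is exhausted. As a rough illustration only (this subscriber is not part of the change), a consumer of such a publisher could look like this:

[source,java]
--------------------------------------------------
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.Flow;

import org.elasticsearch.common.xcontent.ChunkedToXContent;

// Collects every chunk the mock publisher emits, requesting them one at a time.
class CollectingSubscriber implements Flow.Subscriber<ChunkedToXContent> {
    final List<ChunkedToXContent> chunks = new ArrayList<>();
    private Flow.Subscription subscription;

    @Override
    public void onSubscribe(Flow.Subscription s) {
        subscription = s;
        s.request(1); // the mock emits at most one chunk per request(n) call; the last call completes the stream
    }

    @Override
    public void onNext(ChunkedToXContent chunk) {
        chunks.add(chunk);
        subscription.request(1); // keep pulling until the message iterator is exhausted
    }

    @Override
    public void onError(Throwable t) {
        // the mock never signals an error; a real consumer would propagate it
    }

    @Override
    public void onComplete() {
        // all messages have been echoed back as chat.completion.chunk objects
    }
}
--------------------------------------------------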
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceFeatures.java @@ -33,6 +33,8 @@ public Set getFeatures() { ); } + private static final NodeFeature SEMANTIC_TEXT_HIGHLIGHTER = new NodeFeature("semantic_text.highlighter"); + @Override public Set getTestFeatures() { return Set.of( @@ -40,7 +42,8 @@ public Set getTestFeatures() { SemanticTextFieldMapper.SEMANTIC_TEXT_SINGLE_FIELD_UPDATE_FIX, SemanticTextFieldMapper.SEMANTIC_TEXT_DELETE_FIX, SemanticTextFieldMapper.SEMANTIC_TEXT_ZERO_SIZE_FIX, - SemanticTextFieldMapper.SEMANTIC_TEXT_ALWAYS_EMIT_INFERENCE_ID_FIX + SemanticTextFieldMapper.SEMANTIC_TEXT_ALWAYS_EMIT_INFERENCE_ID_FIX, + SEMANTIC_TEXT_HIGHLIGHTER ); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java index 673b841317a3..a4187f4c4fa9 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferenceNamedWriteablesProvider.java @@ -16,6 +16,7 @@ import org.elasticsearch.inference.SecretSettings; import org.elasticsearch.inference.ServiceSettings; import org.elasticsearch.inference.TaskSettings; +import org.elasticsearch.inference.UnifiedCompletionRequest; import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; import org.elasticsearch.xpack.core.inference.results.ErrorChunkedInferenceResults; import org.elasticsearch.xpack.core.inference.results.InferenceChunkedSparseEmbeddingResults; @@ -137,11 +138,18 @@ public static List getNamedWriteables() { addEisNamedWriteables(namedWriteables); addAlibabaCloudSearchNamedWriteables(namedWriteables); + addUnifiedNamedWriteables(namedWriteables); + namedWriteables.addAll(StreamingTaskManager.namedWriteables()); return namedWriteables; } + private static void addUnifiedNamedWriteables(List namedWriteables) { + var writeables = UnifiedCompletionRequest.getNamedWriteables(); + namedWriteables.addAll(writeables); + } + private static void addAmazonBedrockNamedWriteables(List namedWriteables) { namedWriteables.add( new NamedWriteableRegistry.Entry( diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java index 3c14e51a3c2d..148a78445636 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/InferencePlugin.java @@ -37,6 +37,7 @@ import org.elasticsearch.plugins.SystemIndexPlugin; import org.elasticsearch.rest.RestController; import org.elasticsearch.rest.RestHandler; +import org.elasticsearch.search.fetch.subphase.highlight.Highlighter; import org.elasticsearch.search.rank.RankBuilder; import org.elasticsearch.search.rank.RankDoc; import org.elasticsearch.threadpool.ExecutorBuilder; @@ -50,6 +51,7 @@ import org.elasticsearch.xpack.core.inference.action.GetInferenceServicesAction; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.action.PutInferenceModelAction; +import org.elasticsearch.xpack.core.inference.action.UnifiedCompletionAction; import org.elasticsearch.xpack.core.inference.action.UpdateInferenceModelAction; import 
org.elasticsearch.xpack.inference.action.TransportDeleteInferenceEndpointAction; import org.elasticsearch.xpack.inference.action.TransportGetInferenceDiagnosticsAction; @@ -58,6 +60,7 @@ import org.elasticsearch.xpack.inference.action.TransportInferenceAction; import org.elasticsearch.xpack.inference.action.TransportInferenceUsageAction; import org.elasticsearch.xpack.inference.action.TransportPutInferenceModelAction; +import org.elasticsearch.xpack.inference.action.TransportUnifiedCompletionInferenceAction; import org.elasticsearch.xpack.inference.action.TransportUpdateInferenceModelAction; import org.elasticsearch.xpack.inference.action.filter.ShardBulkInferenceActionFilter; import org.elasticsearch.xpack.inference.common.Truncator; @@ -67,6 +70,7 @@ import org.elasticsearch.xpack.inference.external.http.retry.RetrySettings; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.RequestExecutorServiceSettings; +import org.elasticsearch.xpack.inference.highlight.SemanticTextHighlighter; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; import org.elasticsearch.xpack.inference.mapper.OffsetSourceFieldMapper; import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper; @@ -84,6 +88,7 @@ import org.elasticsearch.xpack.inference.rest.RestInferenceAction; import org.elasticsearch.xpack.inference.rest.RestPutInferenceModelAction; import org.elasticsearch.xpack.inference.rest.RestStreamInferenceAction; +import org.elasticsearch.xpack.inference.rest.RestUnifiedCompletionInferenceAction; import org.elasticsearch.xpack.inference.rest.RestUpdateInferenceModelAction; import org.elasticsearch.xpack.inference.services.ServiceComponents; import org.elasticsearch.xpack.inference.services.alibabacloudsearch.AlibabaCloudSearchService; @@ -157,8 +162,9 @@ public InferencePlugin(Settings settings) { @Override public List> getActions() { - return List.of( + var availableActions = List.of( new ActionHandler<>(InferenceAction.INSTANCE, TransportInferenceAction.class), + new ActionHandler<>(GetInferenceModelAction.INSTANCE, TransportGetInferenceModelAction.class), new ActionHandler<>(PutInferenceModelAction.INSTANCE, TransportPutInferenceModelAction.class), new ActionHandler<>(UpdateInferenceModelAction.INSTANCE, TransportUpdateInferenceModelAction.class), @@ -167,6 +173,13 @@ public InferencePlugin(Settings settings) { new ActionHandler<>(GetInferenceDiagnosticsAction.INSTANCE, TransportGetInferenceDiagnosticsAction.class), new ActionHandler<>(GetInferenceServicesAction.INSTANCE, TransportGetInferenceServicesAction.class) ); + + List> conditionalActions = + UnifiedCompletionFeature.UNIFIED_COMPLETION_FEATURE_FLAG.isEnabled() + ? List.of(new ActionHandler<>(UnifiedCompletionAction.INSTANCE, TransportUnifiedCompletionInferenceAction.class)) + : List.of(); + + return Stream.concat(availableActions.stream(), conditionalActions.stream()).toList(); } @Override @@ -181,7 +194,7 @@ public List getRestHandlers( Supplier nodesInCluster, Predicate clusterSupportsFeature ) { - return List.of( + var availableRestActions = List.of( new RestInferenceAction(), new RestStreamInferenceAction(), new RestGetInferenceModelAction(), @@ -191,6 +204,11 @@ public List getRestHandlers( new RestGetInferenceDiagnosticsAction(), new RestGetInferenceServicesAction() ); + List conditionalRestActions = UnifiedCompletionFeature.UNIFIED_COMPLETION_FEATURE_FLAG.isEnabled() + ? 
List.of(new RestUnifiedCompletionInferenceAction()) + : List.of(); + + return Stream.concat(availableRestActions.stream(), conditionalRestActions.stream()).toList(); } @Override @@ -417,4 +435,9 @@ public List> getRetrievers() { new RetrieverSpec<>(new ParseField(RandomRankBuilder.NAME), RandomRankRetrieverBuilder::fromXContent) ); } + + @Override + public Map getHighlighters() { + return Map.of(SemanticTextHighlighter.NAME, new SemanticTextHighlighter()); + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/UnifiedCompletionFeature.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/UnifiedCompletionFeature.java new file mode 100644 index 000000000000..3e13d0c1e39d --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/UnifiedCompletionFeature.java @@ -0,0 +1,20 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference; + +import org.elasticsearch.common.util.FeatureFlag; + +/** + * Unified Completion feature flag. When the feature is complete, this flag will be removed. + * Enable feature via JVM option: `-Des.inference_unified_feature_flag_enabled=true`. + */ +public class UnifiedCompletionFeature { + public static final FeatureFlag UNIFIED_COMPLETION_FEATURE_FLAG = new FeatureFlag("inference_unified"); + + private UnifiedCompletionFeature() {} +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java new file mode 100644 index 000000000000..2a0e8e177527 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceAction.java @@ -0,0 +1,250 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.action; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.action.support.HandledTransportAction; +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.common.util.concurrent.EsExecutors; +import org.elasticsearch.common.xcontent.ChunkedToXContent; +import org.elasticsearch.core.Nullable; +import org.elasticsearch.inference.InferenceService; +import org.elasticsearch.inference.InferenceServiceRegistry; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnparsedModel; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.tasks.Task; +import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.core.inference.action.BaseInferenceActionRequest; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager; +import org.elasticsearch.xpack.inference.common.DelegatingProcessor; +import org.elasticsearch.xpack.inference.registry.ModelRegistry; +import org.elasticsearch.xpack.inference.telemetry.InferenceStats; +import org.elasticsearch.xpack.inference.telemetry.InferenceTimer; + +import java.util.function.Supplier; +import java.util.stream.Collectors; + +import static org.elasticsearch.core.Strings.format; +import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.modelAttributes; +import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.responseAttributes; + +public abstract class BaseTransportInferenceAction extends HandledTransportAction< + Request, + InferenceAction.Response> { + + private static final Logger log = LogManager.getLogger(BaseTransportInferenceAction.class); + private static final String STREAMING_INFERENCE_TASK_TYPE = "streaming_inference"; + private static final String STREAMING_TASK_ACTION = "xpack/inference/streaming_inference[n]"; + private final ModelRegistry modelRegistry; + private final InferenceServiceRegistry serviceRegistry; + private final InferenceStats inferenceStats; + private final StreamingTaskManager streamingTaskManager; + + public BaseTransportInferenceAction( + String inferenceActionName, + TransportService transportService, + ActionFilters actionFilters, + ModelRegistry modelRegistry, + InferenceServiceRegistry serviceRegistry, + InferenceStats inferenceStats, + StreamingTaskManager streamingTaskManager, + Writeable.Reader requestReader + ) { + super(inferenceActionName, transportService, actionFilters, requestReader, EsExecutors.DIRECT_EXECUTOR_SERVICE); + this.modelRegistry = modelRegistry; + this.serviceRegistry = serviceRegistry; + this.inferenceStats = inferenceStats; + this.streamingTaskManager = streamingTaskManager; + } + + @Override + protected void doExecute(Task task, Request request, ActionListener listener) { + var timer = InferenceTimer.start(); + + var getModelListener = ActionListener.wrap((UnparsedModel unparsedModel) -> { + var service = serviceRegistry.getService(unparsedModel.service()); + try { + validationHelper(service::isEmpty, () -> unknownServiceException(unparsedModel.service(), request.getInferenceEntityId())); + validationHelper( + () -> 
request.getTaskType().isAnyOrSame(unparsedModel.taskType()) == false, + () -> requestModelTaskTypeMismatchException(request.getTaskType(), unparsedModel.taskType()) + ); + validationHelper( + () -> isInvalidTaskTypeForInferenceEndpoint(request, unparsedModel), + () -> createInvalidTaskTypeException(request, unparsedModel) + ); + } catch (Exception e) { + recordMetrics(unparsedModel, timer, e); + listener.onFailure(e); + return; + } + + var model = service.get() + .parsePersistedConfigWithSecrets( + unparsedModel.inferenceEntityId(), + unparsedModel.taskType(), + unparsedModel.settings(), + unparsedModel.secrets() + ); + inferOnServiceWithMetrics(model, request, service.get(), timer, listener); + }, e -> { + try { + inferenceStats.inferenceDuration().record(timer.elapsedMillis(), responseAttributes(e)); + } catch (Exception metricsException) { + log.atDebug().withThrowable(metricsException).log("Failed to record metrics when the model is missing, dropping metrics"); + } + listener.onFailure(e); + }); + + modelRegistry.getModelWithSecrets(request.getInferenceEntityId(), getModelListener); + } + + private static void validationHelper(Supplier validationFailure, Supplier exceptionCreator) { + if (validationFailure.get()) { + throw exceptionCreator.get(); + } + } + + protected abstract boolean isInvalidTaskTypeForInferenceEndpoint(Request request, UnparsedModel unparsedModel); + + protected abstract ElasticsearchStatusException createInvalidTaskTypeException(Request request, UnparsedModel unparsedModel); + + private void recordMetrics(UnparsedModel model, InferenceTimer timer, @Nullable Throwable t) { + try { + inferenceStats.inferenceDuration().record(timer.elapsedMillis(), responseAttributes(model, t)); + } catch (Exception e) { + log.atDebug().withThrowable(e).log("Failed to record metrics with an unparsed model, dropping metrics"); + } + } + + private void inferOnServiceWithMetrics( + Model model, + Request request, + InferenceService service, + InferenceTimer timer, + ActionListener listener + ) { + inferenceStats.requestCount().incrementBy(1, modelAttributes(model)); + inferOnService(model, request, service, ActionListener.wrap(inferenceResults -> { + if (request.isStreaming()) { + var taskProcessor = streamingTaskManager.create(STREAMING_INFERENCE_TASK_TYPE, STREAMING_TASK_ACTION); + inferenceResults.publisher().subscribe(taskProcessor); + + var instrumentedStream = new PublisherWithMetrics(timer, model); + taskProcessor.subscribe(instrumentedStream); + + listener.onResponse(new InferenceAction.Response(inferenceResults, instrumentedStream)); + } else { + recordMetrics(model, timer, null); + listener.onResponse(new InferenceAction.Response(inferenceResults)); + } + }, e -> { + recordMetrics(model, timer, e); + listener.onFailure(e); + })); + } + + private void recordMetrics(Model model, InferenceTimer timer, @Nullable Throwable t) { + try { + inferenceStats.inferenceDuration().record(timer.elapsedMillis(), responseAttributes(model, t)); + } catch (Exception e) { + log.atDebug().withThrowable(e).log("Failed to record metrics with a parsed model, dropping metrics"); + } + } + + private void inferOnService(Model model, Request request, InferenceService service, ActionListener listener) { + if (request.isStreaming() == false || service.canStream(request.getTaskType())) { + doInference(model, request, service, listener); + } else { + listener.onFailure(unsupportedStreamingTaskException(request, service)); + } + } + + protected abstract void doInference( + Model model, + Request request, + 
InferenceService service, + ActionListener listener + ); + + private ElasticsearchStatusException unsupportedStreamingTaskException(Request request, InferenceService service) { + var supportedTasks = service.supportedStreamingTasks(); + if (supportedTasks.isEmpty()) { + return new ElasticsearchStatusException( + format("Streaming is not allowed for service [%s].", service.name()), + RestStatus.METHOD_NOT_ALLOWED + ); + } else { + var validTasks = supportedTasks.stream().map(TaskType::toString).collect(Collectors.joining(",")); + return new ElasticsearchStatusException( + format( + "Streaming is not allowed for service [%s] and task [%s]. Supported tasks: [%s]", + service.name(), + request.getTaskType(), + validTasks + ), + RestStatus.METHOD_NOT_ALLOWED + ); + } + } + + private static ElasticsearchStatusException unknownServiceException(String service, String inferenceId) { + return new ElasticsearchStatusException("Unknown service [{}] for model [{}]. ", RestStatus.BAD_REQUEST, service, inferenceId); + } + + private static ElasticsearchStatusException requestModelTaskTypeMismatchException(TaskType requested, TaskType expected) { + return new ElasticsearchStatusException( + "Incompatible task_type, the requested type [{}] does not match the model type [{}]", + RestStatus.BAD_REQUEST, + requested, + expected + ); + } + + private class PublisherWithMetrics extends DelegatingProcessor { + + private final InferenceTimer timer; + private final Model model; + + private PublisherWithMetrics(InferenceTimer timer, Model model) { + this.timer = timer; + this.model = model; + } + + @Override + protected void next(ChunkedToXContent item) { + downstream().onNext(item); + } + + @Override + public void onError(Throwable throwable) { + recordMetrics(model, timer, throwable); + super.onError(throwable); + } + + @Override + protected void onCancel() { + recordMetrics(model, timer, null); + super.onCancel(); + } + + @Override + public void onComplete() { + recordMetrics(model, timer, null); + super.onComplete(); + } + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java index ba9ab3c13373..08e6d869a553 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java @@ -7,47 +7,22 @@ package org.elasticsearch.xpack.inference.action; -import org.apache.logging.log4j.LogManager; -import org.apache.logging.log4j.Logger; import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.ActionFilters; -import org.elasticsearch.action.support.HandledTransportAction; -import org.elasticsearch.common.util.concurrent.EsExecutors; -import org.elasticsearch.common.xcontent.ChunkedToXContent; -import org.elasticsearch.core.Nullable; import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceRegistry; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.Model; -import org.elasticsearch.inference.TaskType; import org.elasticsearch.inference.UnparsedModel; import org.elasticsearch.injection.guice.Inject; -import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.tasks.Task; import 
org.elasticsearch.transport.TransportService; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager; -import org.elasticsearch.xpack.inference.common.DelegatingProcessor; import org.elasticsearch.xpack.inference.registry.ModelRegistry; import org.elasticsearch.xpack.inference.telemetry.InferenceStats; -import org.elasticsearch.xpack.inference.telemetry.InferenceTimer; -import java.util.stream.Collectors; - -import static org.elasticsearch.core.Strings.format; -import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.modelAttributes; -import static org.elasticsearch.xpack.inference.telemetry.InferenceStats.responseAttributes; - -public class TransportInferenceAction extends HandledTransportAction { - private static final Logger log = LogManager.getLogger(TransportInferenceAction.class); - private static final String STREAMING_INFERENCE_TASK_TYPE = "streaming_inference"; - private static final String STREAMING_TASK_ACTION = "xpack/inference/streaming_inference[n]"; - - private final ModelRegistry modelRegistry; - private final InferenceServiceRegistry serviceRegistry; - private final InferenceStats inferenceStats; - private final StreamingTaskManager streamingTaskManager; +public class TransportInferenceAction extends BaseTransportInferenceAction { @Inject public TransportInferenceAction( @@ -58,184 +33,44 @@ public TransportInferenceAction( InferenceStats inferenceStats, StreamingTaskManager streamingTaskManager ) { - super(InferenceAction.NAME, transportService, actionFilters, InferenceAction.Request::new, EsExecutors.DIRECT_EXECUTOR_SERVICE); - this.modelRegistry = modelRegistry; - this.serviceRegistry = serviceRegistry; - this.inferenceStats = inferenceStats; - this.streamingTaskManager = streamingTaskManager; + super( + InferenceAction.NAME, + transportService, + actionFilters, + modelRegistry, + serviceRegistry, + inferenceStats, + streamingTaskManager, + InferenceAction.Request::new + ); } @Override - protected void doExecute(Task task, InferenceAction.Request request, ActionListener listener) { - var timer = InferenceTimer.start(); - - var getModelListener = ActionListener.wrap((UnparsedModel unparsedModel) -> { - var service = serviceRegistry.getService(unparsedModel.service()); - if (service.isEmpty()) { - var e = unknownServiceException(unparsedModel.service(), request.getInferenceEntityId()); - recordMetrics(unparsedModel, timer, e); - listener.onFailure(e); - return; - } - - if (request.getTaskType().isAnyOrSame(unparsedModel.taskType()) == false) { - // not the wildcard task type and not the model task type - var e = incompatibleTaskTypeException(request.getTaskType(), unparsedModel.taskType()); - recordMetrics(unparsedModel, timer, e); - listener.onFailure(e); - return; - } - - var model = service.get() - .parsePersistedConfigWithSecrets( - unparsedModel.inferenceEntityId(), - unparsedModel.taskType(), - unparsedModel.settings(), - unparsedModel.secrets() - ); - inferOnServiceWithMetrics(model, request, service.get(), timer, listener); - }, e -> { - try { - inferenceStats.inferenceDuration().record(timer.elapsedMillis(), responseAttributes(e)); - } catch (Exception metricsException) { - log.atDebug().withThrowable(metricsException).log("Failed to record metrics when the model is missing, dropping metrics"); - } - listener.onFailure(e); - }); - - modelRegistry.getModelWithSecrets(request.getInferenceEntityId(), getModelListener); - } - - private void 
recordMetrics(UnparsedModel model, InferenceTimer timer, @Nullable Throwable t) { - try { - inferenceStats.inferenceDuration().record(timer.elapsedMillis(), responseAttributes(model, t)); - } catch (Exception e) { - log.atDebug().withThrowable(e).log("Failed to record metrics with an unparsed model, dropping metrics"); - } - } - - private void inferOnServiceWithMetrics( - Model model, - InferenceAction.Request request, - InferenceService service, - InferenceTimer timer, - ActionListener listener - ) { - inferenceStats.requestCount().incrementBy(1, modelAttributes(model)); - inferOnService(model, request, service, ActionListener.wrap(inferenceResults -> { - if (request.isStreaming()) { - var taskProcessor = streamingTaskManager.create(STREAMING_INFERENCE_TASK_TYPE, STREAMING_TASK_ACTION); - inferenceResults.publisher().subscribe(taskProcessor); - - var instrumentedStream = new PublisherWithMetrics(timer, model); - taskProcessor.subscribe(instrumentedStream); - - listener.onResponse(new InferenceAction.Response(inferenceResults, instrumentedStream)); - } else { - recordMetrics(model, timer, null); - listener.onResponse(new InferenceAction.Response(inferenceResults)); - } - }, e -> { - recordMetrics(model, timer, e); - listener.onFailure(e); - })); + protected boolean isInvalidTaskTypeForInferenceEndpoint(InferenceAction.Request request, UnparsedModel unparsedModel) { + return false; } - private void recordMetrics(Model model, InferenceTimer timer, @Nullable Throwable t) { - try { - inferenceStats.inferenceDuration().record(timer.elapsedMillis(), responseAttributes(model, t)); - } catch (Exception e) { - log.atDebug().withThrowable(e).log("Failed to record metrics with a parsed model, dropping metrics"); - } + @Override + protected ElasticsearchStatusException createInvalidTaskTypeException(InferenceAction.Request request, UnparsedModel unparsedModel) { + return null; } - private void inferOnService( + @Override + protected void doInference( Model model, InferenceAction.Request request, InferenceService service, ActionListener listener ) { - if (request.isStreaming() == false || service.canStream(request.getTaskType())) { - service.infer( - model, - request.getQuery(), - request.getInput(), - request.isStreaming(), - request.getTaskSettings(), - request.getInputType(), - request.getInferenceTimeout(), - listener - ); - } else { - listener.onFailure(unsupportedStreamingTaskException(request, service)); - } - } - - private ElasticsearchStatusException unsupportedStreamingTaskException(InferenceAction.Request request, InferenceService service) { - var supportedTasks = service.supportedStreamingTasks(); - if (supportedTasks.isEmpty()) { - return new ElasticsearchStatusException( - format("Streaming is not allowed for service [%s].", service.name()), - RestStatus.METHOD_NOT_ALLOWED - ); - } else { - var validTasks = supportedTasks.stream().map(TaskType::toString).collect(Collectors.joining(",")); - return new ElasticsearchStatusException( - format( - "Streaming is not allowed for service [%s] and task [%s]. Supported tasks: [%s]", - service.name(), - request.getTaskType(), - validTasks - ), - RestStatus.METHOD_NOT_ALLOWED - ); - } - } - - private static ElasticsearchStatusException unknownServiceException(String service, String inferenceId) { - return new ElasticsearchStatusException("Unknown service [{}] for model [{}]. 
", RestStatus.BAD_REQUEST, service, inferenceId); - } - - private static ElasticsearchStatusException incompatibleTaskTypeException(TaskType requested, TaskType expected) { - return new ElasticsearchStatusException( - "Incompatible task_type, the requested type [{}] does not match the model type [{}]", - RestStatus.BAD_REQUEST, - requested, - expected + service.infer( + model, + request.getQuery(), + request.getInput(), + request.isStreaming(), + request.getTaskSettings(), + request.getInputType(), + request.getInferenceTimeout(), + listener ); } - - private class PublisherWithMetrics extends DelegatingProcessor { - private final InferenceTimer timer; - private final Model model; - - private PublisherWithMetrics(InferenceTimer timer, Model model) { - this.timer = timer; - this.model = model; - } - - @Override - protected void next(ChunkedToXContent item) { - downstream().onNext(item); - } - - @Override - public void onError(Throwable throwable) { - recordMetrics(model, timer, throwable); - super.onError(throwable); - } - - @Override - protected void onCancel() { - recordMetrics(model, timer, null); - super.onCancel(); - } - - @Override - public void onComplete() { - recordMetrics(model, timer, null); - super.onComplete(); - } - } - } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionInferenceAction.java new file mode 100644 index 000000000000..f0906231d8f4 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionInferenceAction.java @@ -0,0 +1,77 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.action; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.inference.InferenceService; +import org.elasticsearch.inference.InferenceServiceRegistry; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnparsedModel; +import org.elasticsearch.injection.guice.Inject; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.core.inference.action.UnifiedCompletionAction; +import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager; +import org.elasticsearch.xpack.inference.registry.ModelRegistry; +import org.elasticsearch.xpack.inference.telemetry.InferenceStats; + +public class TransportUnifiedCompletionInferenceAction extends BaseTransportInferenceAction { + + @Inject + public TransportUnifiedCompletionInferenceAction( + TransportService transportService, + ActionFilters actionFilters, + ModelRegistry modelRegistry, + InferenceServiceRegistry serviceRegistry, + InferenceStats inferenceStats, + StreamingTaskManager streamingTaskManager + ) { + super( + UnifiedCompletionAction.NAME, + transportService, + actionFilters, + modelRegistry, + serviceRegistry, + inferenceStats, + streamingTaskManager, + UnifiedCompletionAction.Request::new + ); + } + + @Override + protected boolean isInvalidTaskTypeForInferenceEndpoint(UnifiedCompletionAction.Request request, UnparsedModel unparsedModel) { + return request.getTaskType().isAnyOrSame(TaskType.COMPLETION) == false || unparsedModel.taskType() != TaskType.COMPLETION; + } + + @Override + protected ElasticsearchStatusException createInvalidTaskTypeException( + UnifiedCompletionAction.Request request, + UnparsedModel unparsedModel + ) { + return new ElasticsearchStatusException( + "Incompatible task_type for unified API, the requested type [{}] must be one of [{}]", + RestStatus.BAD_REQUEST, + request.getTaskType(), + TaskType.COMPLETION.toString() + ); + } + + @Override + protected void doInference( + Model model, + UnifiedCompletionAction.Request request, + InferenceService service, + ActionListener listener + ) { + service.unifiedCompletionInfer(model, request.getUnifiedCompletionRequest(), null, listener); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/DelegatingProcessor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/DelegatingProcessor.java index 03e794e42c3a..eda3fc0f3bfd 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/DelegatingProcessor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/DelegatingProcessor.java @@ -9,7 +9,14 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; - +import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent; +import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventField; + +import java.io.IOException; +import java.util.ArrayDeque; +import java.util.Deque; +import java.util.Iterator; import java.util.concurrent.Flow; import java.util.concurrent.atomic.AtomicBoolean; import 
java.util.concurrent.atomic.AtomicLong; @@ -25,6 +32,33 @@ public abstract class DelegatingProcessor implements Flow.Processor private Flow.Subscriber downstream; private Flow.Subscription upstream; + public static Deque parseEvent( + Deque item, + ParseChunkFunction parseFunction, + XContentParserConfiguration parserConfig, + Logger logger + ) throws Exception { + var results = new ArrayDeque(item.size()); + for (ServerSentEvent event : item) { + if (ServerSentEventField.DATA == event.name() && event.hasValue()) { + try { + var delta = parseFunction.apply(parserConfig, event); + delta.forEachRemaining(results::offer); + } catch (Exception e) { + logger.warn("Failed to parse event from inference provider: {}", event); + throw e; + } + } + } + + return results; + } + + @FunctionalInterface + public interface ParseChunkFunction { + Iterator apply(XContentParserConfiguration parserConfig, ServerSentEvent event) throws IOException; + } + @Override public void subscribe(Flow.Subscriber subscriber) { if (downstream != null) { @@ -51,7 +85,7 @@ public void request(long n) { if (isClosed.get()) { downstream.onComplete(); } else if (upstream != null) { - upstream.request(n); + upstreamRequest(n); } else { pendingRequests.accumulateAndGet(n, Long::sum); } @@ -67,6 +101,13 @@ public void cancel() { }; } + /** + * Guaranteed to be called when the upstream is set and this processor had not been closed. + */ + protected void upstreamRequest(long n) { + upstream.request(n); + } + protected void onCancel() {} @Override diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/SingleInputSenderExecutableAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/SingleInputSenderExecutableAction.java index 4e97554b5644..b43e5ab70e2f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/SingleInputSenderExecutableAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/SingleInputSenderExecutableAction.java @@ -12,7 +12,6 @@ import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; import org.elasticsearch.xpack.inference.external.http.sender.RequestManager; import org.elasticsearch.xpack.inference.external.http.sender.Sender; @@ -34,13 +33,7 @@ public SingleInputSenderExecutableAction( @Override public void execute(InferenceInputs inferenceInputs, TimeValue timeout, ActionListener listener) { - if (inferenceInputs instanceof DocumentsOnlyInput == false) { - listener.onFailure(new ElasticsearchStatusException("Invalid inference input type", RestStatus.INTERNAL_SERVER_ERROR)); - return; - } - - var docsOnlyInput = (DocumentsOnlyInput) inferenceInputs; - if (docsOnlyInput.getInputs().size() > 1) { + if (inferenceInputs.inputSize() > 1) { listener.onFailure( new ElasticsearchStatusException(requestTypeForInputValidationError + " only accepts 1 input", RestStatus.BAD_REQUEST) ); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiActionCreator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiActionCreator.java index 9c83264b5581..bd5c53d589df 100644 --- 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiActionCreator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiActionCreator.java @@ -26,7 +26,7 @@ * Provides a way to construct an {@link ExecutableAction} using the visitor pattern based on the openai model type. */ public class OpenAiActionCreator implements OpenAiActionVisitor { - private static final String COMPLETION_ERROR_PREFIX = "OpenAI chat completions"; + public static final String COMPLETION_ERROR_PREFIX = "OpenAI chat completions"; private final Sender sender; private final ServiceComponents serviceComponents; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AlibabaCloudSearchCompletionRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AlibabaCloudSearchCompletionRequestManager.java index a0a44e62f9f7..e7a960f1316f 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AlibabaCloudSearchCompletionRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AlibabaCloudSearchCompletionRequestManager.java @@ -69,7 +69,7 @@ public void execute( Supplier hasRequestCompletedFunction, ActionListener listener ) { - List input = DocumentsOnlyInput.of(inferenceInputs).getInputs(); + List input = inferenceInputs.castTo(ChatCompletionInput.class).getInputs(); AlibabaCloudSearchCompletionRequest request = new AlibabaCloudSearchCompletionRequest(account, input, model); execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AmazonBedrockChatCompletionRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AmazonBedrockChatCompletionRequestManager.java index 69a5c665feb8..3929585a0745 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AmazonBedrockChatCompletionRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AmazonBedrockChatCompletionRequestManager.java @@ -44,10 +44,10 @@ public void execute( Supplier hasRequestCompletedFunction, ActionListener listener ) { - var docsOnly = DocumentsOnlyInput.of(inferenceInputs); - var docsInput = docsOnly.getInputs(); - var stream = docsOnly.stream(); - var requestEntity = AmazonBedrockChatCompletionEntityFactory.createEntity(model, docsInput); + var chatCompletionInput = inferenceInputs.castTo(ChatCompletionInput.class); + var inputs = chatCompletionInput.getInputs(); + var stream = chatCompletionInput.stream(); + var requestEntity = AmazonBedrockChatCompletionEntityFactory.createEntity(model, inputs); var request = new AmazonBedrockChatCompletionRequest(model, requestEntity, timeout, stream); var responseHandler = new AmazonBedrockChatCompletionResponseHandler(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AnthropicCompletionRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AnthropicCompletionRequestManager.java index 5418b3dd9840..6d4aeb9e31ba 100644 --- 
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AnthropicCompletionRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AnthropicCompletionRequestManager.java @@ -46,10 +46,10 @@ public void execute( Supplier hasRequestCompletedFunction, ActionListener listener ) { - var docsOnly = DocumentsOnlyInput.of(inferenceInputs); - var docsInput = docsOnly.getInputs(); - var stream = docsOnly.stream(); - AnthropicChatCompletionRequest request = new AnthropicChatCompletionRequest(docsInput, model, stream); + var chatCompletionInput = inferenceInputs.castTo(ChatCompletionInput.class); + var inputs = chatCompletionInput.getInputs(); + var stream = chatCompletionInput.stream(); + AnthropicChatCompletionRequest request = new AnthropicChatCompletionRequest(inputs, model, stream); execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureAiStudioChatCompletionRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureAiStudioChatCompletionRequestManager.java index 21cec68b14a4..affd2e3a7760 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureAiStudioChatCompletionRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureAiStudioChatCompletionRequestManager.java @@ -41,10 +41,10 @@ public void execute( Supplier hasRequestCompletedFunction, ActionListener listener ) { - var docsOnly = DocumentsOnlyInput.of(inferenceInputs); - var docsInput = docsOnly.getInputs(); - var stream = docsOnly.stream(); - AzureAiStudioChatCompletionRequest request = new AzureAiStudioChatCompletionRequest(model, docsInput, stream); + var chatCompletionInput = inferenceInputs.castTo(ChatCompletionInput.class); + var inputs = chatCompletionInput.getInputs(); + var stream = chatCompletionInput.stream(); + AzureAiStudioChatCompletionRequest request = new AzureAiStudioChatCompletionRequest(model, inputs, stream); execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureOpenAiCompletionRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureOpenAiCompletionRequestManager.java index d036559ec3dc..c2f5f3e9db5e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureOpenAiCompletionRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/AzureOpenAiCompletionRequestManager.java @@ -46,10 +46,10 @@ public void execute( Supplier hasRequestCompletedFunction, ActionListener listener ) { - var docsOnly = DocumentsOnlyInput.of(inferenceInputs); - var docsInput = docsOnly.getInputs(); - var stream = docsOnly.stream(); - AzureOpenAiCompletionRequest request = new AzureOpenAiCompletionRequest(docsInput, model, stream); + var chatCompletionInput = inferenceInputs.castTo(ChatCompletionInput.class); + var inputs = chatCompletionInput.getInputs(); + var stream = chatCompletionInput.stream(); + AzureOpenAiCompletionRequest request = new 
AzureOpenAiCompletionRequest(inputs, model, stream); execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/ChatCompletionInput.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/ChatCompletionInput.java new file mode 100644 index 000000000000..928da95d9c2f --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/ChatCompletionInput.java @@ -0,0 +1,39 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.http.sender; + +import java.util.List; +import java.util.Objects; + +/** + * This class encapsulates the input text passed by the request and indicates whether the response should be streamed. + * The main difference between this class and {@link UnifiedChatInput} is this should only be used for + * {@link org.elasticsearch.inference.TaskType#COMPLETION} originating through the + * {@link org.elasticsearch.inference.InferenceService#infer} code path. These are requests sent to the + * API without using the _unified route. + */ +public class ChatCompletionInput extends InferenceInputs { + private final List input; + + public ChatCompletionInput(List input) { + this(input, false); + } + + public ChatCompletionInput(List input, boolean stream) { + super(stream); + this.input = Objects.requireNonNull(input); + } + + public List getInputs() { + return this.input; + } + + public int inputSize() { + return input.size(); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereCompletionRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereCompletionRequestManager.java index ae46fbe0fef8..40cd03c87664 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereCompletionRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/CohereCompletionRequestManager.java @@ -50,10 +50,10 @@ public void execute( Supplier hasRequestCompletedFunction, ActionListener listener ) { - var docsOnly = DocumentsOnlyInput.of(inferenceInputs); - var docsInput = docsOnly.getInputs(); - var stream = docsOnly.stream(); - CohereCompletionRequest request = new CohereCompletionRequest(docsInput, model, stream); + var chatCompletionInput = inferenceInputs.castTo(ChatCompletionInput.class); + var inputs = chatCompletionInput.getInputs(); + var stream = chatCompletionInput.stream(); + CohereCompletionRequest request = new CohereCompletionRequest(inputs, model, stream); execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/DocumentsOnlyInput.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/DocumentsOnlyInput.java index 8cf411d84c93..3feb79d3de6c 100644 --- 
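For reference, the conversion that the request managers above now rely on can be sketched as follows. This is illustrative only: the input text is a placeholder, and `stream()` is assumed to be the accessor that `InferenceInputs` exposes after this refactor.

[source,java]
--------------------------------------------------
InferenceInputs inferenceInputs = new ChatCompletionInput(List.of("tell me a joke"), true);

// Narrow the generic inputs to the concrete type this request manager expects.
ChatCompletionInput chatCompletionInput = inferenceInputs.castTo(ChatCompletionInput.class);
List<String> inputs = chatCompletionInput.getInputs(); // ["tell me a joke"]
boolean stream = chatCompletionInput.stream();         // true

// Any other InferenceInputs subtype fails fast with an IllegalArgumentException:
// "Unable to convert inference inputs type: [...] to [...]"
--------------------------------------------------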
a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/DocumentsOnlyInput.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/DocumentsOnlyInput.java @@ -14,30 +14,28 @@ public class DocumentsOnlyInput extends InferenceInputs { public static DocumentsOnlyInput of(InferenceInputs inferenceInputs) { if (inferenceInputs instanceof DocumentsOnlyInput == false) { - throw createUnsupportedTypeException(inferenceInputs); + throw createUnsupportedTypeException(inferenceInputs, DocumentsOnlyInput.class); } return (DocumentsOnlyInput) inferenceInputs; } private final List input; - private final boolean stream; public DocumentsOnlyInput(List input) { this(input, false); } public DocumentsOnlyInput(List input, boolean stream) { - super(); + super(stream); this.input = Objects.requireNonNull(input); - this.stream = stream; } public List getInputs() { return this.input; } - public boolean stream() { - return stream; + public int inputSize() { + return input.size(); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleAiStudioCompletionRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleAiStudioCompletionRequestManager.java index abe50c6fae3f..0097f9c08ea2 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleAiStudioCompletionRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/GoogleAiStudioCompletionRequestManager.java @@ -51,7 +51,10 @@ public void execute( Supplier hasRequestCompletedFunction, ActionListener listener ) { - GoogleAiStudioCompletionRequest request = new GoogleAiStudioCompletionRequest(DocumentsOnlyInput.of(inferenceInputs), model); + GoogleAiStudioCompletionRequest request = new GoogleAiStudioCompletionRequest( + inferenceInputs.castTo(ChatCompletionInput.class), + model + ); execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceInputs.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceInputs.java index dd241857ef0c..e85ea6f1d9b3 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceInputs.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceInputs.java @@ -10,7 +10,29 @@ import org.elasticsearch.common.Strings; public abstract class InferenceInputs { - public static IllegalArgumentException createUnsupportedTypeException(InferenceInputs inferenceInputs) { - return new IllegalArgumentException(Strings.format("Unsupported inference inputs type: [%s]", inferenceInputs.getClass())); + private final boolean stream; + + public InferenceInputs(boolean stream) { + this.stream = stream; + } + + public static IllegalArgumentException createUnsupportedTypeException(InferenceInputs inferenceInputs, Class clazz) { + return new IllegalArgumentException( + Strings.format("Unable to convert inference inputs type: [%s] to [%s]", inferenceInputs.getClass(), clazz) + ); } + + public T castTo(Class clazz) { + if (clazz.isInstance(this) == false) { + throw createUnsupportedTypeException(this, clazz); + } 
+ + return clazz.cast(this); + } + + public boolean stream() { + return stream; + } + + public abstract int inputSize(); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiCompletionRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiCompletionRequestManager.java index cea89332e5bf..4d730be6aa6b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiCompletionRequestManager.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiCompletionRequestManager.java @@ -15,7 +15,7 @@ import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; import org.elasticsearch.xpack.inference.external.openai.OpenAiChatCompletionResponseHandler; -import org.elasticsearch.xpack.inference.external.request.openai.OpenAiChatCompletionRequest; +import org.elasticsearch.xpack.inference.external.request.openai.OpenAiUnifiedChatCompletionRequest; import org.elasticsearch.xpack.inference.external.response.openai.OpenAiChatCompletionResponseEntity; import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModel; @@ -25,8 +25,8 @@ public class OpenAiCompletionRequestManager extends OpenAiRequestManager { private static final Logger logger = LogManager.getLogger(OpenAiCompletionRequestManager.class); - private static final ResponseHandler HANDLER = createCompletionHandler(); + static final String USER_ROLE = "user"; public static OpenAiCompletionRequestManager of(OpenAiChatCompletionModel model, ThreadPool threadPool) { return new OpenAiCompletionRequestManager(Objects.requireNonNull(model), Objects.requireNonNull(threadPool)); @@ -35,7 +35,7 @@ public static OpenAiCompletionRequestManager of(OpenAiChatCompletionModel model, private final OpenAiChatCompletionModel model; private OpenAiCompletionRequestManager(OpenAiChatCompletionModel model, ThreadPool threadPool) { - super(threadPool, model, OpenAiChatCompletionRequest::buildDefaultUri); + super(threadPool, model, OpenAiUnifiedChatCompletionRequest::buildDefaultUri); this.model = Objects.requireNonNull(model); } @@ -46,10 +46,8 @@ public void execute( Supplier hasRequestCompletedFunction, ActionListener listener ) { - var docsOnly = DocumentsOnlyInput.of(inferenceInputs); - var docsInput = docsOnly.getInputs(); - var stream = docsOnly.stream(); - OpenAiChatCompletionRequest request = new OpenAiChatCompletionRequest(docsInput, model, stream); + var chatCompletionInputs = inferenceInputs.castTo(ChatCompletionInput.class); + var request = new OpenAiUnifiedChatCompletionRequest(new UnifiedChatInput(chatCompletionInputs, USER_ROLE), model); execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiUnifiedCompletionRequestManager.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiUnifiedCompletionRequestManager.java new file mode 100644 index 000000000000..3b0f770e3e06 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/OpenAiUnifiedCompletionRequestManager.java @@ -0,0 +1,61 @@ +/* + * Copyright Elasticsearch B.V. 
and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.http.sender; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.threadpool.ThreadPool; +import org.elasticsearch.xpack.inference.external.http.retry.RequestSender; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseHandler; +import org.elasticsearch.xpack.inference.external.openai.OpenAiUnifiedChatCompletionResponseHandler; +import org.elasticsearch.xpack.inference.external.request.openai.OpenAiUnifiedChatCompletionRequest; +import org.elasticsearch.xpack.inference.external.response.openai.OpenAiChatCompletionResponseEntity; +import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModel; + +import java.util.Objects; +import java.util.function.Supplier; + +public class OpenAiUnifiedCompletionRequestManager extends OpenAiRequestManager { + + private static final Logger logger = LogManager.getLogger(OpenAiUnifiedCompletionRequestManager.class); + + private static final ResponseHandler HANDLER = createCompletionHandler(); + + public static OpenAiUnifiedCompletionRequestManager of(OpenAiChatCompletionModel model, ThreadPool threadPool) { + return new OpenAiUnifiedCompletionRequestManager(Objects.requireNonNull(model), Objects.requireNonNull(threadPool)); + } + + private final OpenAiChatCompletionModel model; + + private OpenAiUnifiedCompletionRequestManager(OpenAiChatCompletionModel model, ThreadPool threadPool) { + super(threadPool, model, OpenAiUnifiedChatCompletionRequest::buildDefaultUri); + this.model = Objects.requireNonNull(model); + } + + @Override + public void execute( + InferenceInputs inferenceInputs, + RequestSender requestSender, + Supplier hasRequestCompletedFunction, + ActionListener listener + ) { + + OpenAiUnifiedChatCompletionRequest request = new OpenAiUnifiedChatCompletionRequest( + inferenceInputs.castTo(UnifiedChatInput.class), + model + ); + + execute(new ExecutableInferenceRequest(requestSender, logger, request, HANDLER, hasRequestCompletedFunction, listener)); + } + + private static ResponseHandler createCompletionHandler() { + return new OpenAiUnifiedChatCompletionResponseHandler("openai completion", OpenAiChatCompletionResponseEntity::fromResponse); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/QueryAndDocsInputs.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/QueryAndDocsInputs.java index 50bb77b307db..5af5245ac5b4 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/QueryAndDocsInputs.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/QueryAndDocsInputs.java @@ -14,7 +14,7 @@ public class QueryAndDocsInputs extends InferenceInputs { public static QueryAndDocsInputs of(InferenceInputs inferenceInputs) { if (inferenceInputs instanceof QueryAndDocsInputs == false) { - throw createUnsupportedTypeException(inferenceInputs); + throw createUnsupportedTypeException(inferenceInputs, QueryAndDocsInputs.class); } return (QueryAndDocsInputs) inferenceInputs; @@ 
-22,17 +22,15 @@ public static QueryAndDocsInputs of(InferenceInputs inferenceInputs) { private final String query; private final List chunks; - private final boolean stream; public QueryAndDocsInputs(String query, List chunks) { this(query, chunks, false); } public QueryAndDocsInputs(String query, List chunks, boolean stream) { - super(); + super(stream); this.query = Objects.requireNonNull(query); this.chunks = Objects.requireNonNull(chunks); - this.stream = stream; } public String getQuery() { @@ -43,8 +41,7 @@ public List getChunks() { return chunks; } - public boolean stream() { - return stream; + public int inputSize() { + return chunks.size(); } - } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/UnifiedChatInput.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/UnifiedChatInput.java new file mode 100644 index 000000000000..f89fa1ee37a6 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/UnifiedChatInput.java @@ -0,0 +1,62 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.http.sender; + +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.UnifiedCompletionRequest; + +import java.util.List; +import java.util.Objects; + +/** + * This class encapsulates the unified request. + * The main difference between this class and {@link ChatCompletionInput} is this should only be used for + * {@link org.elasticsearch.inference.TaskType#COMPLETION} originating through the + * {@link org.elasticsearch.inference.InferenceService#unifiedCompletionInfer(Model, UnifiedCompletionRequest, TimeValue, ActionListener)} + * code path. These are requests sent to the API with the _unified route. 
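To make the split between the two input types concrete, here is a minimal, self-contained sketch of the `castTo` pattern that the request managers above now rely on. The class names below (`Inputs`, `CompletionInputs`, `CastToDemo`) are simplified stand-ins for illustration, not the production classes.

[source,java]
----
import java.util.List;
import java.util.Objects;

// Minimal stand-ins for InferenceInputs and ChatCompletionInput; simplified, not the production classes.
abstract class Inputs {
    private final boolean stream;

    Inputs(boolean stream) {
        this.stream = stream;
    }

    // Type-safe downcast that fails fast with a descriptive message instead of a raw ClassCastException.
    <T extends Inputs> T castTo(Class<T> clazz) {
        if (clazz.isInstance(this) == false) {
            throw new IllegalArgumentException(
                "Unable to convert inference inputs type: [" + getClass() + "] to [" + clazz + "]"
            );
        }
        return clazz.cast(this);
    }

    boolean stream() {
        return stream;
    }
}

// Plain text passed through the existing infer code path (no _unified route).
class CompletionInputs extends Inputs {
    private final List<String> input;

    CompletionInputs(List<String> input, boolean stream) {
        super(stream);
        this.input = Objects.requireNonNull(input);
    }

    List<String> getInputs() {
        return input;
    }
}

public class CastToDemo {
    public static void main(String[] args) {
        Inputs inputs = new CompletionInputs(List.of("Say hello"), false);

        // A completion request manager casts once up front and then works with the concrete type.
        CompletionInputs completion = inputs.castTo(CompletionInputs.class);
        System.out.println(completion.getInputs() + ", stream=" + completion.stream());

        // Casting to an unrelated subtype would throw IllegalArgumentException naming both types.
    }
}
----

The production version differs mainly in building the error message with `Strings.format` and in each subclass also reporting an `inputSize()`.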
+ */ +public class UnifiedChatInput extends InferenceInputs { + private final UnifiedCompletionRequest request; + + public UnifiedChatInput(UnifiedCompletionRequest request, boolean stream) { + super(stream); + this.request = Objects.requireNonNull(request); + } + + public UnifiedChatInput(ChatCompletionInput completionInput, String roleValue) { + this(completionInput.getInputs(), roleValue, completionInput.stream()); + } + + public UnifiedChatInput(List inputs, String roleValue, boolean stream) { + this(UnifiedCompletionRequest.of(convertToMessages(inputs, roleValue)), stream); + } + + private static List convertToMessages(List inputs, String roleValue) { + return inputs.stream() + .map( + value -> new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentString(value), + roleValue, + null, + null, + null + ) + ) + .toList(); + } + + public UnifiedCompletionRequest getRequest() { + return request; + } + + public int inputSize() { + return request.messages().size(); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiStreamingProcessor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiStreamingProcessor.java index 6e006fe25595..48c8132035b5 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiStreamingProcessor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiStreamingProcessor.java @@ -18,10 +18,8 @@ import org.elasticsearch.xpack.core.inference.results.StreamingChatCompletionResults; import org.elasticsearch.xpack.inference.common.DelegatingProcessor; import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent; -import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventField; import java.io.IOException; -import java.util.ArrayDeque; import java.util.Collections; import java.util.Deque; import java.util.Iterator; @@ -115,19 +113,7 @@ public class OpenAiStreamingProcessor extends DelegatingProcessor item) throws Exception { var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE); - - var results = new ArrayDeque(item.size()); - for (ServerSentEvent event : item) { - if (ServerSentEventField.DATA == event.name() && event.hasValue()) { - try { - var delta = parse(parserConfig, event); - delta.forEachRemaining(results::offer); - } catch (Exception e) { - log.warn("Failed to parse event from inference provider: {}", event); - throw e; - } - } - } + var results = parseEvent(item, OpenAiStreamingProcessor::parse, parserConfig, log); if (results.isEmpty()) { upstream().request(1); @@ -136,7 +122,7 @@ protected void next(Deque item) throws Exception { } } - private Iterator parse(XContentParserConfiguration parserConfig, ServerSentEvent event) + private static Iterator parse(XContentParserConfiguration parserConfig, ServerSentEvent event) throws IOException { if (DONE_MESSAGE.equalsIgnoreCase(event.value())) { return Collections.emptyIterator(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedChatCompletionResponseHandler.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedChatCompletionResponseHandler.java new file mode 100644 index 000000000000..fce2556efc5e --- /dev/null +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedChatCompletionResponseHandler.java @@ -0,0 +1,34 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.openai; + +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults; +import org.elasticsearch.xpack.inference.external.http.HttpResult; +import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser; +import org.elasticsearch.xpack.inference.external.request.Request; +import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventParser; +import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEventProcessor; + +import java.util.concurrent.Flow; + +public class OpenAiUnifiedChatCompletionResponseHandler extends OpenAiChatCompletionResponseHandler { + public OpenAiUnifiedChatCompletionResponseHandler(String requestType, ResponseParser parseFunction) { + super(requestType, parseFunction); + } + + @Override + public InferenceServiceResults parseResult(Request request, Flow.Publisher flow) { + var serverSentEventProcessor = new ServerSentEventProcessor(new ServerSentEventParser()); + var openAiProcessor = new OpenAiUnifiedStreamingProcessor(); + + flow.subscribe(serverSentEventProcessor); + serverSentEventProcessor.subscribe(openAiProcessor); + return new StreamingUnifiedChatCompletionResults(openAiProcessor); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedStreamingProcessor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedStreamingProcessor.java new file mode 100644 index 000000000000..599d71df3dcf --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedStreamingProcessor.java @@ -0,0 +1,287 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
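The `parseResult` override above wires three stages together with the JDK's `java.util.concurrent.Flow` API: the raw HTTP event publisher feeds a server-sent-event parser, which feeds the chunk parser that the returned results object exposes. The sketch below shows the same wiring shape with a trivial transforming processor; `UppercaseProcessor` and the string payloads are illustrative stand-ins, not the real parsers.

[source,java]
----
import java.util.concurrent.Flow;
import java.util.concurrent.SubmissionPublisher;

// A Flow.Processor built on SubmissionPublisher: consumes items from upstream,
// transforms them, and republishes them to its own subscribers.
class UppercaseProcessor extends SubmissionPublisher<String> implements Flow.Processor<String, String> {
    private Flow.Subscription subscription;

    @Override
    public void onSubscribe(Flow.Subscription subscription) {
        this.subscription = subscription;
        subscription.request(1);
    }

    @Override
    public void onNext(String item) {
        submit(item.toUpperCase());
        subscription.request(1);
    }

    @Override
    public void onError(Throwable throwable) {
        closeExceptionally(throwable);
    }

    @Override
    public void onComplete() {
        close();
    }
}

public class ProcessorChainDemo {
    public static void main(String[] args) {
        var source = new SubmissionPublisher<String>();
        var parser = new UppercaseProcessor();

        // Same shape as parseResult: upstream publisher -> parsing processor -> downstream consumer.
        source.subscribe(parser);
        var done = parser.consume(System.out::println);

        source.submit("data: chunk-1");
        source.submit("data: chunk-2");
        source.close(); // completes the whole chain

        done.join(); // wait for the downstream consumer to drain
    }
}
----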
+ */ + +package org.elasticsearch.xpack.inference.external.openai; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.elasticsearch.common.xcontent.ChunkedToXContent; +import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; +import org.elasticsearch.xcontent.ConstructingObjectParser; +import org.elasticsearch.xcontent.ParseField; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults; +import org.elasticsearch.xpack.inference.common.DelegatingProcessor; +import org.elasticsearch.xpack.inference.external.response.streaming.ServerSentEvent; + +import java.io.IOException; +import java.util.ArrayDeque; +import java.util.Collections; +import java.util.Deque; +import java.util.Iterator; +import java.util.List; +import java.util.concurrent.LinkedBlockingDeque; + +import static org.elasticsearch.common.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.elasticsearch.xpack.inference.external.response.XContentUtils.moveToFirstToken; + +public class OpenAiUnifiedStreamingProcessor extends DelegatingProcessor, ChunkedToXContent> { + public static final String FUNCTION_FIELD = "function"; + private static final Logger logger = LogManager.getLogger(OpenAiUnifiedStreamingProcessor.class); + + private static final String CHOICES_FIELD = "choices"; + private static final String DELTA_FIELD = "delta"; + private static final String CONTENT_FIELD = "content"; + private static final String DONE_MESSAGE = "[done]"; + private static final String REFUSAL_FIELD = "refusal"; + private static final String TOOL_CALLS_FIELD = "tool_calls"; + public static final String ROLE_FIELD = "role"; + public static final String FINISH_REASON_FIELD = "finish_reason"; + public static final String INDEX_FIELD = "index"; + public static final String OBJECT_FIELD = "object"; + public static final String MODEL_FIELD = "model"; + public static final String ID_FIELD = "id"; + public static final String CHOICE_FIELD = "choice"; + public static final String USAGE_FIELD = "usage"; + public static final String TYPE_FIELD = "type"; + public static final String NAME_FIELD = "name"; + public static final String ARGUMENTS_FIELD = "arguments"; + public static final String COMPLETION_TOKENS_FIELD = "completion_tokens"; + public static final String PROMPT_TOKENS_FIELD = "prompt_tokens"; + public static final String TOTAL_TOKENS_FIELD = "total_tokens"; + + private final Deque buffer = new LinkedBlockingDeque<>(); + + @Override + protected void upstreamRequest(long n) { + if (buffer.isEmpty()) { + super.upstreamRequest(n); + } else { + downstream().onNext(new StreamingUnifiedChatCompletionResults.Results(singleItem(buffer.poll()))); + } + } + + @Override + protected void next(Deque item) throws Exception { + var parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler(LoggingDeprecationHandler.INSTANCE); + var results = parseEvent(item, OpenAiUnifiedStreamingProcessor::parse, parserConfig, logger); + + if (results.isEmpty()) { + upstream().request(1); + } else if (results.size() == 1) { + downstream().onNext(new StreamingUnifiedChatCompletionResults.Results(results)); + } else { + // results > 1, but openai spec only wants 1 chunk per SSE event + var firstItem = singleItem(results.poll()); + while 
(results.isEmpty() == false) { + buffer.offer(results.poll()); + } + downstream().onNext(new StreamingUnifiedChatCompletionResults.Results(firstItem)); + } + } + + private static Iterator parse( + XContentParserConfiguration parserConfig, + ServerSentEvent event + ) throws IOException { + if (DONE_MESSAGE.equalsIgnoreCase(event.value())) { + return Collections.emptyIterator(); + } + + try (XContentParser jsonParser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, event.value())) { + moveToFirstToken(jsonParser); + + XContentParser.Token token = jsonParser.currentToken(); + ensureExpectedToken(XContentParser.Token.START_OBJECT, token, jsonParser); + + StreamingUnifiedChatCompletionResults.ChatCompletionChunk chunk = ChatCompletionChunkParser.parse(jsonParser); + + return Collections.singleton(chunk).iterator(); + } + } + + public static class ChatCompletionChunkParser { + @SuppressWarnings("unchecked") + private static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>( + "chat_completion_chunk", + true, + args -> new StreamingUnifiedChatCompletionResults.ChatCompletionChunk( + (String) args[0], + (List) args[1], + (String) args[2], + (String) args[3], + (StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Usage) args[4] + ) + ); + + static { + PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField(ID_FIELD)); + PARSER.declareObjectArray( + ConstructingObjectParser.constructorArg(), + (p, c) -> ChatCompletionChunkParser.ChoiceParser.parse(p), + new ParseField(CHOICES_FIELD) + ); + PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField(MODEL_FIELD)); + PARSER.declareString(ConstructingObjectParser.constructorArg(), new ParseField(OBJECT_FIELD)); + PARSER.declareObjectOrNull( + ConstructingObjectParser.optionalConstructorArg(), + (p, c) -> ChatCompletionChunkParser.UsageParser.parse(p), + null, + new ParseField(USAGE_FIELD) + ); + } + + public static StreamingUnifiedChatCompletionResults.ChatCompletionChunk parse(XContentParser parser) throws IOException { + return PARSER.parse(parser, null); + } + + private static class ChoiceParser { + private static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>( + CHOICE_FIELD, + true, + args -> new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice( + (StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta) args[0], + (String) args[1], + (int) args[2] + ) + ); + + static { + PARSER.declareObject( + ConstructingObjectParser.constructorArg(), + (p, c) -> ChatCompletionChunkParser.DeltaParser.parse(p), + new ParseField(DELTA_FIELD) + ); + PARSER.declareStringOrNull(ConstructingObjectParser.optionalConstructorArg(), new ParseField(FINISH_REASON_FIELD)); + PARSER.declareInt(ConstructingObjectParser.constructorArg(), new ParseField(INDEX_FIELD)); + } + + public static StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice parse(XContentParser parser) { + return PARSER.apply(parser, null); + } + } + + private static class DeltaParser { + @SuppressWarnings("unchecked") + private static final ConstructingObjectParser< + StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta, + Void> PARSER = new ConstructingObjectParser<>( + DELTA_FIELD, + true, + args -> new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta( + (String) args[0], + (String) args[1], + (String) args[2], + (List) args[3] + ) + ); + + static { + 
PARSER.declareStringOrNull(ConstructingObjectParser.optionalConstructorArg(), new ParseField(CONTENT_FIELD)); + PARSER.declareStringOrNull(ConstructingObjectParser.optionalConstructorArg(), new ParseField(REFUSAL_FIELD)); + PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField(ROLE_FIELD)); + PARSER.declareObjectArray( + ConstructingObjectParser.optionalConstructorArg(), + (p, c) -> ChatCompletionChunkParser.ToolCallParser.parse(p), + new ParseField(TOOL_CALLS_FIELD) + ); + } + + public static StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta parse(XContentParser parser) + throws IOException { + return PARSER.parse(parser, null); + } + } + + private static class ToolCallParser { + private static final ConstructingObjectParser< + StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall, + Void> PARSER = new ConstructingObjectParser<>( + "tool_call", + true, + args -> new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall( + (int) args[0], + (String) args[1], + (StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall.Function) args[2], + (String) args[3] + ) + ); + + static { + PARSER.declareInt(ConstructingObjectParser.constructorArg(), new ParseField(INDEX_FIELD)); + PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField(ID_FIELD)); + PARSER.declareObject( + ConstructingObjectParser.optionalConstructorArg(), + (p, c) -> ChatCompletionChunkParser.FunctionParser.parse(p), + new ParseField(FUNCTION_FIELD) + ); + PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField(TYPE_FIELD)); + } + + public static StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall parse(XContentParser parser) + throws IOException { + return PARSER.parse(parser, null); + } + } + + private static class FunctionParser { + private static final ConstructingObjectParser< + StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall.Function, + Void> PARSER = new ConstructingObjectParser<>( + FUNCTION_FIELD, + true, + args -> new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall.Function( + (String) args[0], + (String) args[1] + ) + ); + + static { + PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField(ARGUMENTS_FIELD)); + PARSER.declareString(ConstructingObjectParser.optionalConstructorArg(), new ParseField(NAME_FIELD)); + } + + public static StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall.Function parse( + XContentParser parser + ) throws IOException { + return PARSER.parse(parser, null); + } + } + + private static class UsageParser { + private static final ConstructingObjectParser PARSER = + new ConstructingObjectParser<>( + USAGE_FIELD, + true, + args -> new StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Usage((int) args[0], (int) args[1], (int) args[2]) + ); + + static { + PARSER.declareInt(ConstructingObjectParser.constructorArg(), new ParseField(COMPLETION_TOKENS_FIELD)); + PARSER.declareInt(ConstructingObjectParser.constructorArg(), new ParseField(PROMPT_TOKENS_FIELD)); + PARSER.declareInt(ConstructingObjectParser.constructorArg(), new ParseField(TOTAL_TOKENS_FIELD)); + } + + public static StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Usage parse(XContentParser parser) throws IOException { + return PARSER.parse(parser, null); + } + } + } + + private Deque 
singleItem( + StreamingUnifiedChatCompletionResults.ChatCompletionChunk result + ) { + var deque = new ArrayDeque(1); + deque.offer(result); + return deque; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/GoogleAiStudioCompletionRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/GoogleAiStudioCompletionRequest.java index 80770d63ef13..b1af18d03dda 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/GoogleAiStudioCompletionRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/GoogleAiStudioCompletionRequest.java @@ -14,7 +14,7 @@ import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.util.LazyInitializable; import org.elasticsearch.xcontent.XContentType; -import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; import org.elasticsearch.xpack.inference.external.request.HttpRequest; import org.elasticsearch.xpack.inference.external.request.Request; import org.elasticsearch.xpack.inference.services.googleaistudio.completion.GoogleAiStudioCompletionModel; @@ -27,13 +27,13 @@ public class GoogleAiStudioCompletionRequest implements GoogleAiStudioRequest { private static final String ALT_PARAM = "alt"; private static final String SSE_VALUE = "sse"; - private final DocumentsOnlyInput input; + private final ChatCompletionInput input; private final LazyInitializable uri; private final GoogleAiStudioCompletionModel model; - public GoogleAiStudioCompletionRequest(DocumentsOnlyInput input, GoogleAiStudioCompletionModel model) { + public GoogleAiStudioCompletionRequest(ChatCompletionInput input, GoogleAiStudioCompletionModel model) { this.input = Objects.requireNonNull(input); this.model = Objects.requireNonNull(model); this.uri = new LazyInitializable<>(() -> model.uri(input.stream())); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiChatCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiChatCompletionRequestEntity.java deleted file mode 100644 index 867a7ca80cbc..000000000000 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiChatCompletionRequestEntity.java +++ /dev/null @@ -1,79 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. 
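The buffering in `upstreamRequest` and `next` above follows a single rule: the OpenAI streaming format expects one chunk per server-sent event, so when one event happens to parse into several chunks the processor emits the first and parks the rest, and it only asks upstream for another event once that buffer is empty. A plain restatement of the rule outside the Flow API, with hypothetical names (`OnePerRequestBuffer`, `onParsed`) used purely for illustration:

[source,java]
----
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.List;
import java.util.function.Consumer;

// Sketch of the "one item per downstream request" buffering: if a parse step yields several
// results at once, emit the first and hold the rest until the consumer asks again; only go
// back upstream when the buffer is empty.
class OnePerRequestBuffer<T> {
    private final Deque<T> buffer = new ArrayDeque<>();

    void onParsed(List<T> results, Consumer<T> downstream, Runnable requestMoreUpstream) {
        if (results.isEmpty()) {
            requestMoreUpstream.run(); // nothing usable in this event, ask for the next one
            return;
        }
        downstream.accept(results.get(0));
        buffer.addAll(results.subList(1, results.size()));
    }

    void onDownstreamRequest(Consumer<T> downstream, Runnable requestMoreUpstream) {
        T buffered = buffer.poll();
        if (buffered != null) {
            downstream.accept(buffered); // serve from the buffer first
        } else {
            requestMoreUpstream.run();   // buffer drained, pull the next server-sent event
        }
    }
}

public class BufferDemo {
    public static void main(String[] args) {
        var buffer = new OnePerRequestBuffer<String>();
        buffer.onParsed(List.of("chunk-1", "chunk-2", "chunk-3"), System.out::println, () -> {});
        buffer.onDownstreamRequest(System.out::println, () -> {}); // prints chunk-2
        buffer.onDownstreamRequest(System.out::println, () -> {}); // prints chunk-3
        buffer.onDownstreamRequest(System.out::println, () -> System.out.println("request upstream"));
    }
}
----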
- */ - -package org.elasticsearch.xpack.inference.external.request.openai; - -import org.elasticsearch.common.Strings; -import org.elasticsearch.xcontent.ToXContentObject; -import org.elasticsearch.xcontent.XContentBuilder; - -import java.io.IOException; -import java.util.List; -import java.util.Objects; - -public class OpenAiChatCompletionRequestEntity implements ToXContentObject { - - private static final String MESSAGES_FIELD = "messages"; - private static final String MODEL_FIELD = "model"; - - private static final String NUMBER_OF_RETURNED_CHOICES_FIELD = "n"; - - private static final String ROLE_FIELD = "role"; - private static final String USER_FIELD = "user"; - private static final String CONTENT_FIELD = "content"; - private static final String STREAM_FIELD = "stream"; - - private final List messages; - private final String model; - - private final String user; - private final boolean stream; - - public OpenAiChatCompletionRequestEntity(List messages, String model, String user, boolean stream) { - Objects.requireNonNull(messages); - Objects.requireNonNull(model); - - this.messages = messages; - this.model = model; - this.user = user; - this.stream = stream; - } - - @Override - public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { - builder.startObject(); - builder.startArray(MESSAGES_FIELD); - { - for (String message : messages) { - builder.startObject(); - - { - builder.field(ROLE_FIELD, USER_FIELD); - builder.field(CONTENT_FIELD, message); - } - - builder.endObject(); - } - } - builder.endArray(); - - builder.field(MODEL_FIELD, model); - builder.field(NUMBER_OF_RETURNED_CHOICES_FIELD, 1); - - if (Strings.isNullOrEmpty(user) == false) { - builder.field(USER_FIELD, user); - } - - if (stream) { - builder.field(STREAM_FIELD, true); - } - - builder.endObject(); - - return builder; - } -} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiChatCompletionRequest.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequest.java similarity index 80% rename from x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiChatCompletionRequest.java rename to x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequest.java index 99a025e70d00..2e6bdb748fd3 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiChatCompletionRequest.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequest.java @@ -13,6 +13,7 @@ import org.apache.http.entity.ByteArrayEntity; import org.elasticsearch.common.Strings; import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.elasticsearch.xpack.inference.external.openai.OpenAiAccount; import org.elasticsearch.xpack.inference.external.request.HttpRequest; import org.elasticsearch.xpack.inference.external.request.Request; @@ -21,24 +22,21 @@ import java.net.URI; import java.net.URISyntaxException; import java.nio.charset.StandardCharsets; -import java.util.List; import java.util.Objects; import static org.elasticsearch.xpack.inference.external.request.RequestUtils.createAuthBearerHeader; import static 
org.elasticsearch.xpack.inference.external.request.openai.OpenAiUtils.createOrgHeader; -public class OpenAiChatCompletionRequest implements OpenAiRequest { +public class OpenAiUnifiedChatCompletionRequest implements OpenAiRequest { private final OpenAiAccount account; - private final List input; private final OpenAiChatCompletionModel model; - private final boolean stream; + private final UnifiedChatInput unifiedChatInput; - public OpenAiChatCompletionRequest(List input, OpenAiChatCompletionModel model, boolean stream) { - this.account = OpenAiAccount.of(model, OpenAiChatCompletionRequest::buildDefaultUri); - this.input = Objects.requireNonNull(input); + public OpenAiUnifiedChatCompletionRequest(UnifiedChatInput unifiedChatInput, OpenAiChatCompletionModel model) { + this.account = OpenAiAccount.of(model, OpenAiUnifiedChatCompletionRequest::buildDefaultUri); + this.unifiedChatInput = Objects.requireNonNull(unifiedChatInput); this.model = Objects.requireNonNull(model); - this.stream = stream; } @Override @@ -46,9 +44,7 @@ public HttpRequest createHttpRequest() { HttpPost httpPost = new HttpPost(account.uri()); ByteArrayEntity byteEntity = new ByteArrayEntity( - Strings.toString( - new OpenAiChatCompletionRequestEntity(input, model.getServiceSettings().modelId(), model.getTaskSettings().user(), stream) - ).getBytes(StandardCharsets.UTF_8) + Strings.toString(new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInput, model)).getBytes(StandardCharsets.UTF_8) ); httpPost.setEntity(byteEntity); @@ -87,7 +83,7 @@ public String getInferenceEntityId() { @Override public boolean isStreaming() { - return stream; + return unifiedChatInput.stream(); } public static URI buildDefaultUri() throws URISyntaxException { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntity.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntity.java new file mode 100644 index 000000000000..50339bf851f7 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntity.java @@ -0,0 +1,185 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
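Before reading through the serialization code that follows, it helps to see the kind of body it produces. The snippet below holds a representative streaming request; the concrete values are invented for illustration, and which optional fields (`temperature`, `tools`, `stop`, and so on) appear depends entirely on the incoming unified request and the model's task settings.

[source,java]
----
public class UnifiedRequestPayloadExample {
    // A representative payload of the kind the request entity below serializes for a streaming
    // chat completion. Values are illustrative only; "n" is always 1 and "stream_options" is
    // written only when streaming.
    static final String EXAMPLE_BODY = """
        {
          "messages": [
            { "content": "What is the weather like in Boston today?", "role": "user" }
          ],
          "model": "gpt-4o",
          "n": 1,
          "temperature": 0.7,
          "user": "my-elastic-user",
          "stream": true,
          "stream_options": { "include_usage": true }
        }
        """;

    public static void main(String[] args) {
        System.out.println(EXAMPLE_BODY);
    }
}
----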
+ */ + +package org.elasticsearch.xpack.inference.external.request.openai; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.xcontent.ToXContentObject; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModel; + +import java.io.IOException; +import java.util.Objects; + +public class OpenAiUnifiedChatCompletionRequestEntity implements ToXContentObject { + + public static final String NAME_FIELD = "name"; + public static final String TOOL_CALL_ID_FIELD = "tool_call_id"; + public static final String TOOL_CALLS_FIELD = "tool_calls"; + public static final String ID_FIELD = "id"; + public static final String FUNCTION_FIELD = "function"; + public static final String ARGUMENTS_FIELD = "arguments"; + public static final String DESCRIPTION_FIELD = "description"; + public static final String PARAMETERS_FIELD = "parameters"; + public static final String STRICT_FIELD = "strict"; + public static final String TOP_P_FIELD = "top_p"; + public static final String USER_FIELD = "user"; + public static final String STREAM_FIELD = "stream"; + private static final String NUMBER_OF_RETURNED_CHOICES_FIELD = "n"; + private static final String MODEL_FIELD = "model"; + public static final String MESSAGES_FIELD = "messages"; + private static final String ROLE_FIELD = "role"; + private static final String CONTENT_FIELD = "content"; + private static final String MAX_COMPLETION_TOKENS_FIELD = "max_completion_tokens"; + private static final String STOP_FIELD = "stop"; + private static final String TEMPERATURE_FIELD = "temperature"; + private static final String TOOL_CHOICE_FIELD = "tool_choice"; + private static final String TOOL_FIELD = "tools"; + private static final String TEXT_FIELD = "text"; + private static final String TYPE_FIELD = "type"; + private static final String STREAM_OPTIONS_FIELD = "stream_options"; + private static final String INCLUDE_USAGE_FIELD = "include_usage"; + + private final UnifiedCompletionRequest unifiedRequest; + private final boolean stream; + private final OpenAiChatCompletionModel model; + + public OpenAiUnifiedChatCompletionRequestEntity(UnifiedChatInput unifiedChatInput, OpenAiChatCompletionModel model) { + Objects.requireNonNull(unifiedChatInput); + + this.unifiedRequest = unifiedChatInput.getRequest(); + this.stream = unifiedChatInput.stream(); + this.model = Objects.requireNonNull(model); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject(); + builder.startArray(MESSAGES_FIELD); + { + for (UnifiedCompletionRequest.Message message : unifiedRequest.messages()) { + builder.startObject(); + { + switch (message.content()) { + case UnifiedCompletionRequest.ContentString contentString -> builder.field(CONTENT_FIELD, contentString.content()); + case UnifiedCompletionRequest.ContentObjects contentObjects -> { + builder.startArray(CONTENT_FIELD); + for (UnifiedCompletionRequest.ContentObject contentObject : contentObjects.contentObjects()) { + builder.startObject(); + builder.field(TEXT_FIELD, contentObject.text()); + builder.field(TYPE_FIELD, contentObject.type()); + builder.endObject(); + } + builder.endArray(); + } + } + + builder.field(ROLE_FIELD, message.role()); + if (message.name() != null) { + builder.field(NAME_FIELD, message.name()); + } + if 
(message.toolCallId() != null) { + builder.field(TOOL_CALL_ID_FIELD, message.toolCallId()); + } + if (message.toolCalls() != null) { + builder.startArray(TOOL_CALLS_FIELD); + for (UnifiedCompletionRequest.ToolCall toolCall : message.toolCalls()) { + builder.startObject(); + { + builder.field(ID_FIELD, toolCall.id()); + builder.startObject(FUNCTION_FIELD); + { + builder.field(ARGUMENTS_FIELD, toolCall.function().arguments()); + builder.field(NAME_FIELD, toolCall.function().name()); + } + builder.endObject(); + builder.field(TYPE_FIELD, toolCall.type()); + } + builder.endObject(); + } + builder.endArray(); + } + } + builder.endObject(); + } + } + builder.endArray(); + + builder.field(MODEL_FIELD, model.getServiceSettings().modelId()); + if (unifiedRequest.maxCompletionTokens() != null) { + builder.field(MAX_COMPLETION_TOKENS_FIELD, unifiedRequest.maxCompletionTokens()); + } + + builder.field(NUMBER_OF_RETURNED_CHOICES_FIELD, 1); + + if (unifiedRequest.stop() != null && unifiedRequest.stop().isEmpty() == false) { + builder.field(STOP_FIELD, unifiedRequest.stop()); + } + if (unifiedRequest.temperature() != null) { + builder.field(TEMPERATURE_FIELD, unifiedRequest.temperature()); + } + if (unifiedRequest.toolChoice() != null) { + if (unifiedRequest.toolChoice() instanceof UnifiedCompletionRequest.ToolChoiceString) { + builder.field(TOOL_CHOICE_FIELD, ((UnifiedCompletionRequest.ToolChoiceString) unifiedRequest.toolChoice()).value()); + } else if (unifiedRequest.toolChoice() instanceof UnifiedCompletionRequest.ToolChoiceObject) { + builder.startObject(TOOL_CHOICE_FIELD); + { + builder.field(TYPE_FIELD, ((UnifiedCompletionRequest.ToolChoiceObject) unifiedRequest.toolChoice()).type()); + builder.startObject(FUNCTION_FIELD); + { + builder.field( + NAME_FIELD, + ((UnifiedCompletionRequest.ToolChoiceObject) unifiedRequest.toolChoice()).function().name() + ); + } + builder.endObject(); + } + builder.endObject(); + } + } + if (unifiedRequest.tools() != null && unifiedRequest.tools().isEmpty() == false) { + builder.startArray(TOOL_FIELD); + for (UnifiedCompletionRequest.Tool t : unifiedRequest.tools()) { + builder.startObject(); + { + builder.field(TYPE_FIELD, t.type()); + builder.startObject(FUNCTION_FIELD); + { + builder.field(DESCRIPTION_FIELD, t.function().description()); + builder.field(NAME_FIELD, t.function().name()); + builder.field(PARAMETERS_FIELD, t.function().parameters()); + if (t.function().strict() != null) { + builder.field(STRICT_FIELD, t.function().strict()); + } + } + builder.endObject(); + } + builder.endObject(); + } + builder.endArray(); + } + if (unifiedRequest.topP() != null) { + builder.field(TOP_P_FIELD, unifiedRequest.topP()); + } + + if (Strings.isNullOrEmpty(model.getTaskSettings().user()) == false) { + builder.field(USER_FIELD, model.getTaskSettings().user()); + } + + builder.field(STREAM_FIELD, stream); + if (stream) { + builder.startObject(STREAM_OPTIONS_FIELD); + builder.field(INCLUDE_USAGE_FIELD, true); + builder.endObject(); + } + builder.endObject(); + + return builder; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/highlight/SemanticTextHighlighter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/highlight/SemanticTextHighlighter.java new file mode 100644 index 000000000000..f2bfa72ec617 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/highlight/SemanticTextHighlighter.java @@ -0,0 +1,226 @@ +/* + * Copyright Elasticsearch B.V. 
and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.highlight; + +import org.apache.lucene.index.LeafReader; +import org.apache.lucene.index.Term; +import org.apache.lucene.search.BooleanClause; +import org.apache.lucene.search.BooleanQuery; +import org.apache.lucene.search.DocIdSetIterator; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.KnnByteVectorQuery; +import org.apache.lucene.search.KnnFloatVectorQuery; +import org.apache.lucene.search.Query; +import org.apache.lucene.search.QueryVisitor; +import org.apache.lucene.search.ScoreMode; +import org.apache.lucene.search.Scorer; +import org.apache.lucene.search.Weight; +import org.elasticsearch.common.text.Text; +import org.elasticsearch.common.xcontent.support.XContentMapValues; +import org.elasticsearch.index.mapper.MappedFieldType; +import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper.DenseVectorFieldType; +import org.elasticsearch.index.mapper.vectors.SparseVectorFieldMapper.SparseVectorFieldType; +import org.elasticsearch.index.query.SearchExecutionContext; +import org.elasticsearch.search.fetch.subphase.highlight.FieldHighlightContext; +import org.elasticsearch.search.fetch.subphase.highlight.HighlightField; +import org.elasticsearch.search.fetch.subphase.highlight.Highlighter; +import org.elasticsearch.search.vectors.VectorData; +import org.elasticsearch.xpack.core.ml.search.SparseVectorQueryWrapper; +import org.elasticsearch.xpack.inference.mapper.SemanticTextField; +import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Comparator; +import java.util.List; +import java.util.Locale; +import java.util.Map; + +/** + * A {@link Highlighter} designed for the {@link SemanticTextFieldMapper}. + * This highlighter extracts semantic queries and evaluates them against each chunk produced by the semantic text field. + * It returns the top-scoring chunks as snippets, optionally sorted by their scores. 
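The selection logic in `highlight` below can be read as: score every chunk, keep the best `number_of_fragments`, and only re-impose document order when score ordering was not requested. A compact, self-contained restatement, with hypothetical names rather than the production types:

[source,java]
----
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;

// Sketch of the fragment selection performed by the highlighter below: keep the
// numberOfFragments best-scoring chunks, then either leave them ranked by score
// ("order": "score") or restore their original position in the field.
public class FragmentSelectionDemo {
    record Chunk(int offset, float score, String text) {}

    static List<Chunk> select(List<Chunk> chunks, int numberOfFragments, boolean scoreOrdered) {
        List<Chunk> sorted = new ArrayList<>(chunks);
        sorted.sort(Comparator.comparingDouble(Chunk::score).reversed());
        List<Chunk> best = new ArrayList<>(sorted.subList(0, Math.min(numberOfFragments, sorted.size())));
        if (scoreOrdered == false) {
            best.sort(Comparator.comparingInt(Chunk::offset)); // back to document order
        }
        return best;
    }

    public static void main(String[] args) {
        var chunks = List.of(
            new Chunk(0, 0.2f, "intro passage"),
            new Chunk(1, 0.9f, "highly relevant passage"),
            new Chunk(2, 0.6f, "somewhat relevant passage")
        );
        System.out.println(select(chunks, 2, false)); // offsets 1 and 2, in document order
        System.out.println(select(chunks, 2, true));  // best score first
    }
}
----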
+ */ +public class SemanticTextHighlighter implements Highlighter { + public static final String NAME = "semantic"; + + private record OffsetAndScore(int offset, float score) {} + + @Override + public boolean canHighlight(MappedFieldType fieldType) { + if (fieldType instanceof SemanticTextFieldMapper.SemanticTextFieldType) { + return true; + } + return false; + } + + @Override + public HighlightField highlight(FieldHighlightContext fieldContext) throws IOException { + SemanticTextFieldMapper.SemanticTextFieldType fieldType = (SemanticTextFieldMapper.SemanticTextFieldType) fieldContext.fieldType; + if (fieldType.getEmbeddingsField() == null) { + // nothing indexed yet + return null; + } + + final List queries = switch (fieldType.getModelSettings().taskType()) { + case SPARSE_EMBEDDING -> extractSparseVectorQueries( + (SparseVectorFieldType) fieldType.getEmbeddingsField().fieldType(), + fieldContext.query + ); + case TEXT_EMBEDDING -> extractDenseVectorQueries( + (DenseVectorFieldType) fieldType.getEmbeddingsField().fieldType(), + fieldContext.query + ); + default -> throw new IllegalStateException( + "Wrong task type for a semantic text field, got [" + fieldType.getModelSettings().taskType().name() + "]" + ); + }; + if (queries.isEmpty()) { + // nothing to highlight + return null; + } + + int numberOfFragments = fieldContext.field.fieldOptions().numberOfFragments() <= 0 + ? 1 // we return the best fragment by default + : fieldContext.field.fieldOptions().numberOfFragments(); + + List chunks = extractOffsetAndScores( + fieldContext.context.getSearchExecutionContext(), + fieldContext.hitContext.reader(), + fieldType, + fieldContext.hitContext.docId(), + queries + ); + if (chunks.size() == 0) { + return null; + } + + chunks.sort(Comparator.comparingDouble(OffsetAndScore::score).reversed()); + int size = Math.min(chunks.size(), numberOfFragments); + if (fieldContext.field.fieldOptions().scoreOrdered() == false) { + chunks = chunks.subList(0, size); + chunks.sort(Comparator.comparingInt(c -> c.offset)); + } + Text[] snippets = new Text[size]; + List> nestedSources = XContentMapValues.extractNestedSources( + fieldType.getChunksField().fullPath(), + fieldContext.hitContext.source().source() + ); + for (int i = 0; i < size; i++) { + var chunk = chunks.get(i); + if (nestedSources.size() <= chunk.offset) { + throw new IllegalStateException( + String.format( + Locale.ROOT, + "Invalid content detected for field [%s]: the chunks size is [%d], " + + "but a reference to offset [%d] was found in the result.", + fieldType.name(), + nestedSources.size(), + chunk.offset + ) + ); + } + String content = (String) nestedSources.get(chunk.offset).get(SemanticTextField.CHUNKED_TEXT_FIELD); + if (content == null) { + throw new IllegalStateException( + String.format( + Locale.ROOT, + + "Invalid content detected for field [%s]: missing text for the chunk at offset [%d].", + fieldType.name(), + chunk.offset + ) + ); + } + snippets[i] = new Text(content); + } + return new HighlightField(fieldContext.fieldName, snippets); + } + + private List extractOffsetAndScores( + SearchExecutionContext context, + LeafReader reader, + SemanticTextFieldMapper.SemanticTextFieldType fieldType, + int docId, + List leafQueries + ) throws IOException { + var bitSet = context.bitsetFilter(fieldType.getChunksField().parentTypeFilter()).getBitSet(reader.getContext()); + int previousParent = docId > 0 ? 
bitSet.prevSetBit(docId - 1) : -1; + + BooleanQuery.Builder bq = new BooleanQuery.Builder().add(fieldType.getChunksField().nestedTypeFilter(), BooleanClause.Occur.FILTER); + leafQueries.stream().forEach(q -> bq.add(q, BooleanClause.Occur.SHOULD)); + Weight weight = new IndexSearcher(reader).createWeight(bq.build(), ScoreMode.COMPLETE, 1); + Scorer scorer = weight.scorer(reader.getContext()); + if (previousParent != -1) { + if (scorer.iterator().advance(previousParent) == DocIdSetIterator.NO_MORE_DOCS) { + return List.of(); + } + } else if (scorer.iterator().nextDoc() == DocIdSetIterator.NO_MORE_DOCS) { + return List.of(); + } + List results = new ArrayList<>(); + int offset = 0; + while (scorer.docID() < docId) { + results.add(new OffsetAndScore(offset++, scorer.score())); + if (scorer.iterator().nextDoc() == DocIdSetIterator.NO_MORE_DOCS) { + break; + } + } + return results; + } + + private List extractDenseVectorQueries(DenseVectorFieldType fieldType, Query querySection) { + // TODO: Handle knn section when semantic text field can be used. + List queries = new ArrayList<>(); + querySection.visit(new QueryVisitor() { + @Override + public boolean acceptField(String field) { + return fieldType.name().equals(field); + } + + @Override + public void consumeTerms(Query query, Term... terms) { + super.consumeTerms(query, terms); + } + + @Override + public void visitLeaf(Query query) { + if (query instanceof KnnFloatVectorQuery knnQuery) { + queries.add(fieldType.createExactKnnQuery(VectorData.fromFloats(knnQuery.getTargetCopy()), null)); + } else if (query instanceof KnnByteVectorQuery knnQuery) { + queries.add(fieldType.createExactKnnQuery(VectorData.fromBytes(knnQuery.getTargetCopy()), null)); + } + } + }); + return queries; + } + + private List extractSparseVectorQueries(SparseVectorFieldType fieldType, Query querySection) { + List queries = new ArrayList<>(); + querySection.visit(new QueryVisitor() { + @Override + public boolean acceptField(String field) { + return fieldType.name().equals(field); + } + + @Override + public void consumeTerms(Query query, Term... 
terms) { + super.consumeTerms(query, terms); + } + + @Override + public QueryVisitor getSubVisitor(BooleanClause.Occur occur, Query parent) { + if (parent instanceof SparseVectorQueryWrapper sparseVectorQuery) { + queries.add(sparseVectorQuery.getTermsQuery()); + } + return this; + } + }); + return queries; + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java index e60e95b58770..0f26f6577860 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextField.java @@ -61,7 +61,7 @@ public record SemanticTextField(String fieldName, List originalValues, I static final String SEARCH_INFERENCE_ID_FIELD = "search_inference_id"; static final String CHUNKS_FIELD = "chunks"; static final String CHUNKED_EMBEDDINGS_FIELD = "embeddings"; - static final String CHUNKED_TEXT_FIELD = "text"; + public static final String CHUNKED_TEXT_FIELD = "text"; static final String MODEL_SETTINGS_FIELD = "model_settings"; static final String TASK_TYPE_FIELD = "task_type"; static final String DIMENSIONS_FIELD = "dimensions"; diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java index 3744bf2a6dbe..683bb5a53028 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapper.java @@ -46,7 +46,6 @@ import org.elasticsearch.index.query.MatchNoneQueryBuilder; import org.elasticsearch.index.query.NestedQueryBuilder; import org.elasticsearch.index.query.QueryBuilder; -import org.elasticsearch.index.query.QueryBuilders; import org.elasticsearch.index.query.SearchExecutionContext; import org.elasticsearch.inference.InferenceResults; import org.elasticsearch.inference.SimilarityMeasure; @@ -57,6 +56,7 @@ import org.elasticsearch.xcontent.XContentParserConfiguration; import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults; import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; +import org.elasticsearch.xpack.core.ml.search.SparseVectorQueryBuilder; import java.io.IOException; import java.util.ArrayList; @@ -529,17 +529,15 @@ public QueryBuilder semanticQuery(InferenceResults inferenceResults, Integer req ); } - // TODO: Use WeightedTokensQueryBuilder TextExpansionResults textExpansionResults = (TextExpansionResults) inferenceResults; - var boolQuery = QueryBuilders.boolQuery(); - for (var weightedToken : textExpansionResults.getWeightedTokens()) { - boolQuery.should( - QueryBuilders.termQuery(inferenceResultsFieldName, weightedToken.token()).boost(weightedToken.weight()) - ); - } - boolQuery.minimumShouldMatch(1); - - yield boolQuery; + yield new SparseVectorQueryBuilder( + inferenceResultsFieldName, + textExpansionResults.getWeightedTokens(), + null, + null, + null, + null + ); } case TEXT_EMBEDDING -> { if (inferenceResults instanceof MlTextEmbeddingResults == false) { diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/BaseInferenceAction.java 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/BaseInferenceAction.java index e72e68052f64..d911158e8229 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/BaseInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/BaseInferenceAction.java @@ -9,6 +9,7 @@ import org.elasticsearch.action.ActionListener; import org.elasticsearch.client.internal.node.NodeClient; +import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.TaskType; import org.elasticsearch.rest.BaseRestHandler; import org.elasticsearch.rest.RestChannel; @@ -21,27 +22,32 @@ import static org.elasticsearch.xpack.inference.rest.Paths.TASK_TYPE_OR_INFERENCE_ID; abstract class BaseInferenceAction extends BaseRestHandler { - @Override - protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException { - String inferenceEntityId; - TaskType taskType; + static Params parseParams(RestRequest restRequest) { if (restRequest.hasParam(INFERENCE_ID)) { - inferenceEntityId = restRequest.param(INFERENCE_ID); - taskType = TaskType.fromStringOrStatusException(restRequest.param(TASK_TYPE_OR_INFERENCE_ID)); + var inferenceEntityId = restRequest.param(INFERENCE_ID); + var taskType = TaskType.fromStringOrStatusException(restRequest.param(TASK_TYPE_OR_INFERENCE_ID)); + return new Params(inferenceEntityId, taskType); } else { - inferenceEntityId = restRequest.param(TASK_TYPE_OR_INFERENCE_ID); - taskType = TaskType.ANY; + return new Params(restRequest.param(TASK_TYPE_OR_INFERENCE_ID), TaskType.ANY); } + } + + record Params(String inferenceEntityId, TaskType taskType) {} + + static TimeValue parseTimeout(RestRequest restRequest) { + return restRequest.paramAsTime(InferenceAction.Request.TIMEOUT.getPreferredName(), InferenceAction.Request.DEFAULT_TIMEOUT); + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException { + var params = parseParams(restRequest); InferenceAction.Request.Builder requestBuilder; try (var parser = restRequest.contentParser()) { - requestBuilder = InferenceAction.Request.parseRequest(inferenceEntityId, taskType, parser); + requestBuilder = InferenceAction.Request.parseRequest(params.inferenceEntityId(), params.taskType(), parser); } - var inferTimeout = restRequest.paramAsTime( - InferenceAction.Request.TIMEOUT.getPreferredName(), - InferenceAction.Request.DEFAULT_TIMEOUT - ); + var inferTimeout = parseTimeout(restRequest); requestBuilder.setInferenceTimeout(inferTimeout); var request = prepareInferenceRequest(requestBuilder); return channel -> client.execute(InferenceAction.INSTANCE, request, listener(channel)); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/Paths.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/Paths.java index 55d6443b43c0..c46f211bb26a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/Paths.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/Paths.java @@ -30,6 +30,12 @@ public final class Paths { + "}/{" + INFERENCE_ID + "}/_stream"; + static final String UNIFIED_INFERENCE_ID_PATH = "_inference/{" + TASK_TYPE_OR_INFERENCE_ID + "}/_unified"; + static final String UNIFIED_TASK_TYPE_INFERENCE_ID_PATH = "_inference/{" + + TASK_TYPE_OR_INFERENCE_ID + + "}/{" + + INFERENCE_ID + + "}/_unified"; private Paths() { diff 
--git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestUnifiedCompletionInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestUnifiedCompletionInferenceAction.java new file mode 100644 index 000000000000..5c71b560a6b9 --- /dev/null +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/rest/RestUnifiedCompletionInferenceAction.java @@ -0,0 +1,49 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.rest; + +import org.elasticsearch.client.internal.node.NodeClient; +import org.elasticsearch.rest.BaseRestHandler; +import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.rest.Scope; +import org.elasticsearch.rest.ServerlessScope; +import org.elasticsearch.xpack.core.inference.action.UnifiedCompletionAction; + +import java.io.IOException; +import java.util.List; + +import static org.elasticsearch.rest.RestRequest.Method.POST; +import static org.elasticsearch.xpack.inference.rest.Paths.UNIFIED_INFERENCE_ID_PATH; +import static org.elasticsearch.xpack.inference.rest.Paths.UNIFIED_TASK_TYPE_INFERENCE_ID_PATH; + +@ServerlessScope(Scope.PUBLIC) +public class RestUnifiedCompletionInferenceAction extends BaseRestHandler { + @Override + public String getName() { + return "unified_inference_action"; + } + + @Override + public List routes() { + return List.of(new Route(POST, UNIFIED_INFERENCE_ID_PATH), new Route(POST, UNIFIED_TASK_TYPE_INFERENCE_ID_PATH)); + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest restRequest, NodeClient client) throws IOException { + var params = BaseInferenceAction.parseParams(restRequest); + + var inferTimeout = BaseInferenceAction.parseTimeout(restRequest); + + UnifiedCompletionAction.Request request; + try (var parser = restRequest.contentParser()) { + request = UnifiedCompletionAction.Request.parseRequest(params.inferenceEntityId(), params.taskType(), inferTimeout, parser); + } + + return channel -> client.execute(UnifiedCompletionAction.INSTANCE, request, new ServerSentEventsRestActionListener(channel)); + } +} diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java index 8e2dac1ef9db..e9b75e9ec779 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java @@ -7,9 +7,11 @@ package org.elasticsearch.xpack.inference.services; +import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.core.IOUtils; import org.elasticsearch.core.Nullable; +import org.elasticsearch.core.Strings; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.ChunkedInferenceServiceResults; import org.elasticsearch.inference.InferenceService; @@ -17,11 +19,15 @@ import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.rest.RestStatus; +import 
org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; import org.elasticsearch.xpack.inference.external.http.sender.QueryAndDocsInputs; import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import java.io.IOException; import java.util.EnumSet; @@ -61,11 +67,31 @@ public void infer( ActionListener listener ) { init(); - if (query != null) { - doInfer(model, new QueryAndDocsInputs(query, input, stream), taskSettings, inputType, timeout, listener); - } else { - doInfer(model, new DocumentsOnlyInput(input, stream), taskSettings, inputType, timeout, listener); - } + var inferenceInput = createInput(model, input, query, stream); + doInfer(model, inferenceInput, taskSettings, inputType, timeout, listener); + } + + private static InferenceInputs createInput(Model model, List input, @Nullable String query, boolean stream) { + return switch (model.getTaskType()) { + case COMPLETION -> new ChatCompletionInput(input, stream); + case RERANK -> new QueryAndDocsInputs(query, input, stream); + case TEXT_EMBEDDING, SPARSE_EMBEDDING -> new DocumentsOnlyInput(input, stream); + default -> throw new ElasticsearchStatusException( + Strings.format("Invalid task type received when determining input type: [%s]", model.getTaskType().toString()), + RestStatus.BAD_REQUEST + ); + }; + } + + @Override + public void unifiedCompletionInfer( + Model model, + UnifiedCompletionRequest request, + TimeValue timeout, + ActionListener listener + ) { + init(); + doUnifiedCompletionInfer(model, new UnifiedChatInput(request, true), timeout, listener); } @Override @@ -92,6 +118,13 @@ protected abstract void doInfer( ActionListener listener ); + protected abstract void doUnifiedCompletionInfer( + Model model, + UnifiedChatInput inputs, + TimeValue timeout, + ActionListener listener + ); + protected abstract void doChunkedInfer( Model model, DocumentsOnlyInput inputs, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java index ec4b8d9bb4d3..7d05bac363fb 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java @@ -776,5 +776,9 @@ public static T nonNullOrDefault(@Nullable T requestValue, @Nullable T origi return requestValue == null ? 
originalSettingsValue : requestValue; } + public static void throwUnsupportedUnifiedCompletionOperation(String serviceName) { + throw new UnsupportedOperationException(Strings.format("The %s service does not support unified completion", serviceName)); + } + private ServiceUtils() {} } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java index 5adc2a11b19d..ffd26b9ac534 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java @@ -37,6 +37,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.elasticsearch.xpack.inference.external.request.alibabacloudsearch.AlibabaCloudSearchUtils; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.SenderService; @@ -57,14 +58,13 @@ import java.util.Map; import java.util.stream.Stream; -import static org.elasticsearch.inference.TaskType.SPARSE_EMBEDDING; -import static org.elasticsearch.inference.TaskType.TEXT_EMBEDDING; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; import static org.elasticsearch.xpack.inference.services.ServiceUtils.parsePersistedConfigErrorMsg; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation; import static org.elasticsearch.xpack.inference.services.alibabacloudsearch.AlibabaCloudSearchServiceFields.EMBEDDING_MAX_BATCH_SIZE; import static org.elasticsearch.xpack.inference.services.alibabacloudsearch.AlibabaCloudSearchServiceSettings.HOST; import static org.elasticsearch.xpack.inference.services.alibabacloudsearch.AlibabaCloudSearchServiceSettings.HTTP_SCHEMA_NAME; @@ -261,6 +261,16 @@ public AlibabaCloudSearchModel parsePersistedConfig(String inferenceEntityId, Ta ); } + @Override + protected void doUnifiedCompletionInfer( + Model model, + UnifiedChatInput inputs, + TimeValue timeout, + ActionListener listener + ) { + throwUnsupportedUnifiedCompletionOperation(NAME); + } + @Override public void doInfer( Model model, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java index 48b3c3df03e1..d224e50bb650 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockService.java @@ -40,6 +40,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.SenderService; import org.elasticsearch.xpack.inference.services.ServiceComponents; @@ -64,6 +65,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation; import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.MODEL_FIELD; import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.PROVIDER_FIELD; import static org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockConstants.REGION_FIELD; @@ -89,6 +91,16 @@ public AmazonBedrockService( this.amazonBedrockSender = amazonBedrockFactory.createSender(); } + @Override + protected void doUnifiedCompletionInfer( + Model model, + UnifiedChatInput inputs, + TimeValue timeout, + ActionListener listener + ) { + throwUnsupportedUnifiedCompletionOperation(NAME); + } + @Override protected void doInfer( Model model, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java index b3d503de8e3e..f1840af18779 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicService.java @@ -32,6 +32,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.SenderService; import org.elasticsearch.xpack.inference.services.ServiceComponents; @@ -52,6 +53,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation; public class AnthropicService extends SenderService { public static final String NAME = "anthropic"; @@ -192,6 +194,16 @@ public EnumSet supportedTaskTypes() { return supportedTaskTypes; } + @Override + protected void doUnifiedCompletionInfer( + Model model, + UnifiedChatInput inputs, 
+ TimeValue timeout, + ActionListener listener + ) { + throwUnsupportedUnifiedCompletionOperation(NAME); + } + @Override public void doInfer( Model model, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java index bba331fc0b5d..f8ea11e4b15a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioService.java @@ -38,6 +38,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.SenderService; import org.elasticsearch.xpack.inference.services.ServiceComponents; @@ -63,6 +64,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation; import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.ENDPOINT_TYPE_FIELD; import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.PROVIDER_FIELD; import static org.elasticsearch.xpack.inference.services.azureaistudio.AzureAiStudioConstants.TARGET_FIELD; @@ -81,6 +83,16 @@ public AzureAiStudioService(HttpRequestSender.Factory factory, ServiceComponents super(factory, serviceComponents); } + @Override + protected void doUnifiedCompletionInfer( + Model model, + UnifiedChatInput inputs, + TimeValue timeout, + ActionListener listener + ) { + throwUnsupportedUnifiedCompletionOperation(NAME); + } + @Override protected void doInfer( Model model, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java index 16c94dfa9ad9..a38c265d2613 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiService.java @@ -36,6 +36,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.SenderService; import org.elasticsearch.xpack.inference.services.ServiceComponents; @@ -58,6 +59,7 @@ import static 
org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation; import static org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiServiceFields.API_VERSION; import static org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiServiceFields.DEPLOYMENT_ID; import static org.elasticsearch.xpack.inference.services.azureopenai.AzureOpenAiServiceFields.RESOURCE_NAME; @@ -233,6 +235,16 @@ public EnumSet supportedTaskTypes() { return supportedTaskTypes; } + @Override + protected void doUnifiedCompletionInfer( + Model model, + UnifiedChatInput inputs, + TimeValue timeout, + ActionListener listener + ) { + throwUnsupportedUnifiedCompletionOperation(NAME); + } + @Override protected void doInfer( Model model, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java index b3d8b3b6efce..ccb8d79dacd6 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/cohere/CohereService.java @@ -34,6 +34,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.SenderService; import org.elasticsearch.xpack.inference.services.ServiceComponents; @@ -58,6 +59,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation; import static org.elasticsearch.xpack.inference.services.cohere.CohereServiceFields.EMBEDDING_MAX_BATCH_SIZE; public class CohereService extends SenderService { @@ -232,6 +234,16 @@ public EnumSet supportedTaskTypes() { return supportedTaskTypes; } + @Override + protected void doUnifiedCompletionInfer( + Model model, + UnifiedChatInput inputs, + TimeValue timeout, + ActionListener listener + ) { + throwUnsupportedUnifiedCompletionOperation(NAME); + } + @Override public void doInfer( Model model, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java index b256861e7dd2..fe8ee52eb881 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java +++ 
b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceService.java @@ -38,6 +38,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.SenderService; import org.elasticsearch.xpack.inference.services.ServiceComponents; @@ -57,6 +58,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation; public class ElasticInferenceService extends SenderService { @@ -76,6 +78,16 @@ public ElasticInferenceService( this.elasticInferenceServiceComponents = elasticInferenceServiceComponents; } + @Override + protected void doUnifiedCompletionInfer( + Model model, + UnifiedChatInput inputs, + TimeValue timeout, + ActionListener listener + ) { + throwUnsupportedUnifiedCompletionOperation(NAME); + } + @Override protected void doInfer( Model model, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java index 0e64842f873d..5f613d6be586 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java @@ -31,6 +31,7 @@ import org.elasticsearch.inference.SettingsConfiguration; import org.elasticsearch.inference.TaskSettingsConfiguration; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnifiedCompletionRequest; import org.elasticsearch.inference.configuration.SettingsConfigurationDisplayType; import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; import org.elasticsearch.inference.configuration.SettingsConfigurationSelectOption; @@ -77,6 +78,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation; import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalServiceSettings.MODEL_ID; import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalServiceSettings.NUM_ALLOCATIONS; import static org.elasticsearch.xpack.inference.services.elasticsearch.ElasticsearchInternalServiceSettings.NUM_THREADS; @@ -578,6 +580,16 @@ private static CustomElandEmbeddingModel updateModelWithEmbeddingDetails(CustomE ); } + @Override + public void unifiedCompletionInfer( 
+ Model model, + UnifiedCompletionRequest request, + TimeValue timeout, + ActionListener listener + ) { + throwUnsupportedUnifiedCompletionOperation(NAME); + } + @Override public void infer( Model model, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java index 57a8a66a3f3a..b681722a8213 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioService.java @@ -39,6 +39,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.GoogleAiStudioEmbeddingsRequestManager; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.SenderService; import org.elasticsearch.xpack.inference.services.ServiceComponents; @@ -64,6 +65,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation; import static org.elasticsearch.xpack.inference.services.googleaistudio.GoogleAiStudioServiceFields.EMBEDDING_MAX_BATCH_SIZE; public class GoogleAiStudioService extends SenderService { @@ -282,9 +284,8 @@ protected void doInfer( ) { if (model instanceof GoogleAiStudioCompletionModel completionModel) { var requestManager = new GoogleAiStudioCompletionRequestManager(completionModel, getServiceComponents().threadPool()); - var docsOnly = DocumentsOnlyInput.of(inputs); var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage( - completionModel.uri(docsOnly.stream()), + completionModel.uri(inputs.stream()), "Google AI Studio completion" ); var action = new SingleInputSenderExecutableAction( @@ -308,6 +309,16 @@ protected void doInfer( } } + @Override + protected void doUnifiedCompletionInfer( + Model model, + UnifiedChatInput inputs, + TimeValue timeout, + ActionListener listener + ) { + throwUnsupportedUnifiedCompletionOperation(NAME); + } + @Override protected void doChunkedInfer( Model model, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java index 857d475499aa..87a2d98dca92 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/googlevertexai/GoogleVertexAiService.java @@ -35,6 +35,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import 
org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.SenderService; import org.elasticsearch.xpack.inference.services.ServiceComponents; @@ -57,6 +58,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation; import static org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiServiceFields.EMBEDDING_MAX_BATCH_SIZE; import static org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiServiceFields.LOCATION; import static org.elasticsearch.xpack.inference.services.googlevertexai.GoogleVertexAiServiceFields.PROJECT_ID; @@ -206,6 +208,16 @@ protected void doInfer( action.execute(inputs, timeout, listener); } + @Override + protected void doUnifiedCompletionInfer( + Model model, + UnifiedChatInput inputs, + TimeValue timeout, + ActionListener listener + ) { + throwUnsupportedUnifiedCompletionOperation(NAME); + } + @Override protected void doChunkedInfer( Model model, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java index 51cca72f2605..b74ec01cd76e 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceService.java @@ -18,6 +18,7 @@ import org.elasticsearch.inference.ChunkingSettings; import org.elasticsearch.inference.EmptySettingsConfiguration; import org.elasticsearch.inference.InferenceServiceConfiguration; +import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.SettingsConfiguration; @@ -31,6 +32,7 @@ import org.elasticsearch.xpack.inference.external.action.huggingface.HuggingFaceActionCreator; import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.ServiceComponents; import org.elasticsearch.xpack.inference.services.ServiceUtils; @@ -47,6 +49,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceFields.URL; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation; public class HuggingFaceService extends HuggingFaceBaseService { public static final String NAME = "hugging_face"; @@ -139,6 +142,16 @@ protected void doChunkedInfer( } } + @Override + protected void doUnifiedCompletionInfer( + Model model, + 
UnifiedChatInput inputs, + TimeValue timeout, + ActionListener listener + ) { + throwUnsupportedUnifiedCompletionOperation(NAME); + } + @Override public InferenceServiceConfiguration getConfiguration() { return Configuration.get(); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java index 75920efa251f..5b038781b96a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/huggingface/elser/HuggingFaceElserService.java @@ -36,6 +36,7 @@ import org.elasticsearch.xpack.core.ml.inference.results.ErrorInferenceResults; import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.ServiceComponents; import org.elasticsearch.xpack.inference.services.huggingface.HuggingFaceBaseService; @@ -49,6 +50,7 @@ import java.util.Map; import static org.elasticsearch.xpack.core.inference.results.ResultUtils.createInvalidChunkedResultException; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation; import static org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserServiceSettings.URL; public class HuggingFaceElserService extends HuggingFaceBaseService { @@ -81,6 +83,16 @@ protected HuggingFaceModel createModel( }; } + @Override + protected void doUnifiedCompletionInfer( + Model model, + UnifiedChatInput inputs, + TimeValue timeout, + ActionListener listener + ) { + throwUnsupportedUnifiedCompletionOperation(NAME); + } + @Override protected void doChunkedInfer( Model model, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java index 981a3e95808e..cc66d5fd7ee7 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxService.java @@ -37,6 +37,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.SenderService; import org.elasticsearch.xpack.inference.services.ServiceComponents; @@ -57,6 +58,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; +import static 
org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation; import static org.elasticsearch.xpack.inference.services.huggingface.elser.HuggingFaceElserServiceSettings.URL; import static org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxServiceFields.API_VERSION; import static org.elasticsearch.xpack.inference.services.ibmwatsonx.IbmWatsonxServiceFields.EMBEDDING_MAX_BATCH_SIZE; @@ -276,6 +278,16 @@ protected void doInfer( action.execute(input, timeout, listener); } + @Override + protected void doUnifiedCompletionInfer( + Model model, + UnifiedChatInput inputs, + TimeValue timeout, + ActionListener listener + ) { + throwUnsupportedUnifiedCompletionOperation(NAME); + } + @Override protected void doChunkedInfer( Model model, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java index fe0edb851902..881e7d36f2a2 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/mistral/MistralService.java @@ -36,6 +36,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.SenderService; import org.elasticsearch.xpack.inference.services.ServiceComponents; @@ -58,6 +59,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty; import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull; import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap; +import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation; import static org.elasticsearch.xpack.inference.services.mistral.MistralConstants.MODEL_FIELD; public class MistralService extends SenderService { @@ -88,6 +90,16 @@ protected void doInfer( } } + @Override + protected void doUnifiedCompletionInfer( + Model model, + UnifiedChatInput inputs, + TimeValue timeout, + ActionListener listener + ) { + throwUnsupportedUnifiedCompletionOperation(NAME); + } + @Override protected void doChunkedInfer( Model model, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java index 20ff1c617d21..7b51b068708c 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/OpenAiService.java @@ -32,10 +32,13 @@ import org.elasticsearch.rest.RestStatus; import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder; import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker; +import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; import 
org.elasticsearch.xpack.inference.external.action.openai.OpenAiActionCreator; import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; +import org.elasticsearch.xpack.inference.external.http.sender.OpenAiUnifiedCompletionRequestManager; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.elasticsearch.xpack.inference.services.ConfigurationParseContext; import org.elasticsearch.xpack.inference.services.SenderService; import org.elasticsearch.xpack.inference.services.ServiceComponents; @@ -53,6 +56,8 @@ import java.util.Map; import java.util.Set; +import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage; +import static org.elasticsearch.xpack.inference.external.action.openai.OpenAiActionCreator.COMPLETION_ERROR_PREFIX; import static org.elasticsearch.xpack.inference.services.ServiceFields.MODEL_ID; import static org.elasticsearch.xpack.inference.services.ServiceFields.URL; import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException; @@ -257,6 +262,28 @@ public void doInfer( action.execute(inputs, timeout, listener); } + @Override + public void doUnifiedCompletionInfer( + Model model, + UnifiedChatInput inputs, + TimeValue timeout, + ActionListener listener + ) { + if (model instanceof OpenAiChatCompletionModel == false) { + listener.onFailure(createInvalidModelException(model)); + return; + } + + OpenAiChatCompletionModel openAiModel = (OpenAiChatCompletionModel) model; + + var overriddenModel = OpenAiChatCompletionModel.of(openAiModel, inputs.getRequest()); + var requestCreator = OpenAiUnifiedCompletionRequestManager.of(overriddenModel, getServiceComponents().threadPool()); + var errorMessage = constructFailedToSendRequestMessage(overriddenModel.getServiceSettings().uri(), COMPLETION_ERROR_PREFIX); + var action = new SenderExecutableAction(getSender(), requestCreator, errorMessage); + + action.execute(inputs, timeout, listener); + } + @Override protected void doChunkedInfer( Model model, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionModel.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionModel.java index e721cd2955cf..7d79d64b3a77 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionModel.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionModel.java @@ -13,6 +13,7 @@ import org.elasticsearch.inference.ModelSecrets; import org.elasticsearch.inference.SettingsConfiguration; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnifiedCompletionRequest; import org.elasticsearch.inference.configuration.SettingsConfigurationDisplayType; import org.elasticsearch.inference.configuration.SettingsConfigurationFieldType; import org.elasticsearch.xpack.inference.external.action.ExecutableAction; @@ -24,6 +25,7 @@ import java.util.Collections; import java.util.HashMap; import java.util.Map; +import java.util.Objects; import static org.elasticsearch.xpack.inference.services.openai.OpenAiServiceFields.USER; @@ -38,6 +40,26 @@ public static 
OpenAiChatCompletionModel of(OpenAiChatCompletionModel model, Map< return new OpenAiChatCompletionModel(model, OpenAiChatCompletionTaskSettings.of(model.getTaskSettings(), requestTaskSettings)); } + public static OpenAiChatCompletionModel of(OpenAiChatCompletionModel model, UnifiedCompletionRequest request) { + var originalModelServiceSettings = model.getServiceSettings(); + var overriddenServiceSettings = new OpenAiChatCompletionServiceSettings( + Objects.requireNonNullElse(request.model(), originalModelServiceSettings.modelId()), + originalModelServiceSettings.uri(), + originalModelServiceSettings.organizationId(), + originalModelServiceSettings.maxInputTokens(), + originalModelServiceSettings.rateLimitSettings() + ); + + return new OpenAiChatCompletionModel( + model.getInferenceEntityId(), + model.getTaskType(), + model.getConfigurations().getService(), + overriddenServiceSettings, + model.getTaskSettings(), + model.getSecretSettings() + ); + } + public OpenAiChatCompletionModel( String inferenceEntityId, TaskType taskType, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionRequestTaskSettings.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionRequestTaskSettings.java index 8029d8579bab..7ef7f85d71a6 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionRequestTaskSettings.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionRequestTaskSettings.java @@ -48,5 +48,4 @@ public static OpenAiChatCompletionRequestTaskSettings fromMap(Map TaskType.fromStringOrStatusException(null)); + assertThat(exception.getMessage(), Matchers.is("Task type must not be null")); + + exception = expectThrows(ElasticsearchStatusException.class, () -> TaskType.fromStringOrStatusException("blah")); + assertThat(exception.getMessage(), Matchers.is("Unknown task_type [blah]")); + + assertThat(TaskType.fromStringOrStatusException("any"), Matchers.is(TaskType.ANY)); + } + +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/Utils.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/Utils.java index 5abb9000f4d0..9395ae222e9b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/Utils.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/Utils.java @@ -19,6 +19,7 @@ import org.elasticsearch.inference.ModelConfigurations; import org.elasticsearch.inference.ModelSecrets; import org.elasticsearch.inference.SimilarityMeasure; +import org.elasticsearch.inference.TaskType; import org.elasticsearch.threadpool.ScalingExecutorBuilder; import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; import org.elasticsearch.xpack.inference.common.Truncator; @@ -160,9 +161,11 @@ public static Model getInvalidModel(String inferenceEntityId, String serviceName var mockConfigs = mock(ModelConfigurations.class); when(mockConfigs.getInferenceEntityId()).thenReturn(inferenceEntityId); when(mockConfigs.getService()).thenReturn(serviceName); + when(mockConfigs.getTaskType()).thenReturn(TaskType.TEXT_EMBEDDING); var mockModel = mock(Model.class); when(mockModel.getConfigurations()).thenReturn(mockConfigs); + when(mockModel.getTaskType()).thenReturn(TaskType.TEXT_EMBEDDING); return mockModel; } diff 
--git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceActionTestCase.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceActionTestCase.java new file mode 100644 index 000000000000..47f3a0e0b57a --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/BaseTransportInferenceActionTestCase.java @@ -0,0 +1,364 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.action; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.common.xcontent.ChunkedToXContent; +import org.elasticsearch.inference.InferenceService; +import org.elasticsearch.inference.InferenceServiceRegistry; +import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.inference.Model; +import org.elasticsearch.inference.ModelConfigurations; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnparsedModel; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.core.inference.action.BaseInferenceActionRequest; +import org.elasticsearch.xpack.core.inference.action.InferenceAction; +import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager; +import org.elasticsearch.xpack.inference.registry.ModelRegistry; +import org.elasticsearch.xpack.inference.telemetry.InferenceStats; +import org.junit.Before; +import org.mockito.ArgumentCaptor; + +import java.util.Map; +import java.util.Optional; +import java.util.Set; +import java.util.concurrent.Flow; +import java.util.function.Consumer; + +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.isA; +import static org.hamcrest.Matchers.nullValue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyBoolean; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.ArgumentMatchers.assertArg; +import static org.mockito.ArgumentMatchers.same; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public abstract class BaseTransportInferenceActionTestCase extends ESTestCase { + private ModelRegistry modelRegistry; + private StreamingTaskManager streamingTaskManager; + private BaseTransportInferenceAction action; + + protected static final String serviceId = "serviceId"; + protected static final TaskType taskType = TaskType.COMPLETION; + protected static final String inferenceId = "inferenceEntityId"; + protected InferenceServiceRegistry serviceRegistry; + protected InferenceStats inferenceStats; + + @Before + public void setUp() throws Exception { + super.setUp(); + TransportService transportService = mock(); + ActionFilters actionFilters = mock(); + modelRegistry = mock(); + serviceRegistry = mock(); + inferenceStats = new InferenceStats(mock(), mock()); + streamingTaskManager = mock(); + action = createAction(transportService, actionFilters, modelRegistry, serviceRegistry, 
inferenceStats, streamingTaskManager); + } + + protected abstract BaseTransportInferenceAction createAction( + TransportService transportService, + ActionFilters actionFilters, + ModelRegistry modelRegistry, + InferenceServiceRegistry serviceRegistry, + InferenceStats inferenceStats, + StreamingTaskManager streamingTaskManager + ); + + protected abstract Request createRequest(); + + public void testMetricsAfterModelRegistryError() { + var expectedException = new IllegalStateException("hello"); + var expectedError = expectedException.getClass().getSimpleName(); + + doAnswer(ans -> { + ActionListener listener = ans.getArgument(1); + listener.onFailure(expectedException); + return null; + }).when(modelRegistry).getModelWithSecrets(any(), any()); + + var listener = doExecute(taskType); + verify(listener).onFailure(same(expectedException)); + + verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> { + assertThat(attributes.get("service"), nullValue()); + assertThat(attributes.get("task_type"), nullValue()); + assertThat(attributes.get("model_id"), nullValue()); + assertThat(attributes.get("status_code"), nullValue()); + assertThat(attributes.get("error.type"), is(expectedError)); + })); + } + + protected ActionListener doExecute(TaskType taskType) { + return doExecute(taskType, false); + } + + protected ActionListener doExecute(TaskType taskType, boolean stream) { + Request request = createRequest(); + when(request.getInferenceEntityId()).thenReturn(inferenceId); + when(request.getTaskType()).thenReturn(taskType); + when(request.isStreaming()).thenReturn(stream); + ActionListener listener = mock(); + action.doExecute(mock(), request, listener); + return listener; + } + + public void testMetricsAfterMissingService() { + mockModelRegistry(taskType); + + when(serviceRegistry.getService(any())).thenReturn(Optional.empty()); + + var listener = doExecute(taskType); + + verify(listener).onFailure(assertArg(e -> { + assertThat(e, isA(ElasticsearchStatusException.class)); + assertThat(e.getMessage(), is("Unknown service [" + serviceId + "] for model [" + inferenceId + "]. 
")); + assertThat(((ElasticsearchStatusException) e).status(), is(RestStatus.BAD_REQUEST)); + })); + verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> { + assertThat(attributes.get("service"), is(serviceId)); + assertThat(attributes.get("task_type"), is(taskType.toString())); + assertThat(attributes.get("model_id"), nullValue()); + assertThat(attributes.get("status_code"), is(RestStatus.BAD_REQUEST.getStatus())); + assertThat(attributes.get("error.type"), is(String.valueOf(RestStatus.BAD_REQUEST.getStatus()))); + })); + } + + protected void mockModelRegistry(TaskType expectedTaskType) { + var unparsedModel = new UnparsedModel(inferenceId, expectedTaskType, serviceId, Map.of(), Map.of()); + doAnswer(ans -> { + ActionListener listener = ans.getArgument(1); + listener.onResponse(unparsedModel); + return null; + }).when(modelRegistry).getModelWithSecrets(any(), any()); + } + + public void testMetricsAfterUnknownTaskType() { + var modelTaskType = TaskType.RERANK; + var requestTaskType = TaskType.SPARSE_EMBEDDING; + mockModelRegistry(modelTaskType); + when(serviceRegistry.getService(any())).thenReturn(Optional.of(mock())); + + var listener = doExecute(requestTaskType); + + verify(listener).onFailure(assertArg(e -> { + assertThat(e, isA(ElasticsearchStatusException.class)); + assertThat( + e.getMessage(), + is( + "Incompatible task_type, the requested type [" + + requestTaskType + + "] does not match the model type [" + + modelTaskType + + "]" + ) + ); + assertThat(((ElasticsearchStatusException) e).status(), is(RestStatus.BAD_REQUEST)); + })); + verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> { + assertThat(attributes.get("service"), is(serviceId)); + assertThat(attributes.get("task_type"), is(modelTaskType.toString())); + assertThat(attributes.get("model_id"), nullValue()); + assertThat(attributes.get("status_code"), is(RestStatus.BAD_REQUEST.getStatus())); + assertThat(attributes.get("error.type"), is(String.valueOf(RestStatus.BAD_REQUEST.getStatus()))); + })); + } + + public void testMetricsAfterInferError() { + var expectedException = new IllegalStateException("hello"); + var expectedError = expectedException.getClass().getSimpleName(); + mockService(listener -> listener.onFailure(expectedException)); + + var listener = doExecute(taskType); + + verify(listener).onFailure(same(expectedException)); + verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> { + assertThat(attributes.get("service"), is(serviceId)); + assertThat(attributes.get("task_type"), is(taskType.toString())); + assertThat(attributes.get("model_id"), nullValue()); + assertThat(attributes.get("status_code"), nullValue()); + assertThat(attributes.get("error.type"), is(expectedError)); + })); + } + + public void testMetricsAfterStreamUnsupported() { + var expectedStatus = RestStatus.METHOD_NOT_ALLOWED; + var expectedError = String.valueOf(expectedStatus.getStatus()); + mockService(l -> {}); + + var listener = doExecute(taskType, true); + + verify(listener).onFailure(assertArg(e -> { + assertThat(e, isA(ElasticsearchStatusException.class)); + var ese = (ElasticsearchStatusException) e; + assertThat(ese.getMessage(), is("Streaming is not allowed for service [" + serviceId + "].")); + assertThat(ese.status(), is(expectedStatus)); + })); + verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> { + assertThat(attributes.get("service"), is(serviceId)); + assertThat(attributes.get("task_type"), 
is(taskType.toString())); + assertThat(attributes.get("model_id"), nullValue()); + assertThat(attributes.get("status_code"), is(expectedStatus.getStatus())); + assertThat(attributes.get("error.type"), is(expectedError)); + })); + } + + public void testMetricsAfterInferSuccess() { + mockService(listener -> listener.onResponse(mock())); + + var listener = doExecute(taskType); + + verify(listener).onResponse(any()); + verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> { + assertThat(attributes.get("service"), is(serviceId)); + assertThat(attributes.get("task_type"), is(taskType.toString())); + assertThat(attributes.get("model_id"), nullValue()); + assertThat(attributes.get("status_code"), is(200)); + assertThat(attributes.get("error.type"), nullValue()); + })); + } + + public void testMetricsAfterStreamInferSuccess() { + mockStreamResponse(Flow.Subscriber::onComplete); + verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> { + assertThat(attributes.get("service"), is(serviceId)); + assertThat(attributes.get("task_type"), is(taskType.toString())); + assertThat(attributes.get("model_id"), nullValue()); + assertThat(attributes.get("status_code"), is(200)); + assertThat(attributes.get("error.type"), nullValue()); + })); + } + + public void testMetricsAfterStreamInferFailure() { + var expectedException = new IllegalStateException("hello"); + var expectedError = expectedException.getClass().getSimpleName(); + mockStreamResponse(subscriber -> { + subscriber.subscribe(mock()); + subscriber.onError(expectedException); + }); + verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> { + assertThat(attributes.get("service"), is(serviceId)); + assertThat(attributes.get("task_type"), is(taskType.toString())); + assertThat(attributes.get("model_id"), nullValue()); + assertThat(attributes.get("status_code"), nullValue()); + assertThat(attributes.get("error.type"), is(expectedError)); + })); + } + + public void testMetricsAfterStreamCancel() { + var response = mockStreamResponse(s -> s.onSubscribe(mock())); + response.subscribe(new Flow.Subscriber<>() { + @Override + public void onSubscribe(Flow.Subscription subscription) { + subscription.cancel(); + } + + @Override + public void onNext(ChunkedToXContent item) { + + } + + @Override + public void onError(Throwable throwable) { + + } + + @Override + public void onComplete() { + + } + }); + + verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> { + assertThat(attributes.get("service"), is(serviceId)); + assertThat(attributes.get("task_type"), is(taskType.toString())); + assertThat(attributes.get("model_id"), nullValue()); + assertThat(attributes.get("status_code"), is(200)); + assertThat(attributes.get("error.type"), nullValue()); + })); + } + + protected Flow.Publisher mockStreamResponse(Consumer> action) { + mockService(true, Set.of(), listener -> { + Flow.Processor taskProcessor = mock(); + doAnswer(innerAns -> { + action.accept(innerAns.getArgument(0)); + return null; + }).when(taskProcessor).subscribe(any()); + when(streamingTaskManager.create(any(), any())).thenReturn(taskProcessor); + var inferenceServiceResults = mock(InferenceServiceResults.class); + when(inferenceServiceResults.publisher()).thenReturn(mock()); + listener.onResponse(inferenceServiceResults); + }); + + var listener = doExecute(taskType, true); + var captor = ArgumentCaptor.forClass(InferenceAction.Response.class); + verify(listener).onResponse(captor.capture()); 
+ assertTrue(captor.getValue().isStreaming()); + assertNotNull(captor.getValue().publisher()); + return captor.getValue().publisher(); + } + + protected void mockService(Consumer> listenerAction) { + mockService(false, Set.of(), listenerAction); + } + + protected void mockService( + boolean stream, + Set supportedStreamingTasks, + Consumer> listenerAction + ) { + InferenceService service = mock(); + Model model = mockModel(); + when(service.parsePersistedConfigWithSecrets(any(), any(), any(), any())).thenReturn(model); + when(service.name()).thenReturn(serviceId); + + when(service.canStream(any())).thenReturn(stream); + when(service.supportedStreamingTasks()).thenReturn(supportedStreamingTasks); + doAnswer(ans -> { + listenerAction.accept(ans.getArgument(7)); + return null; + }).when(service).infer(any(), any(), any(), anyBoolean(), any(), any(), any(), any()); + doAnswer(ans -> { + listenerAction.accept(ans.getArgument(3)); + return null; + }).when(service).unifiedCompletionInfer(any(), any(), any(), any()); + mockModelAndServiceRegistry(service); + } + + protected Model mockModel() { + Model model = mock(); + ModelConfigurations modelConfigurations = mock(); + when(modelConfigurations.getService()).thenReturn(serviceId); + when(model.getConfigurations()).thenReturn(modelConfigurations); + when(model.getTaskType()).thenReturn(taskType); + when(model.getServiceSettings()).thenReturn(mock()); + return model; + } + + protected void mockModelAndServiceRegistry(InferenceService service) { + var unparsedModel = new UnparsedModel(inferenceId, taskType, serviceId, Map.of(), Map.of()); + doAnswer(ans -> { + ActionListener listener = ans.getArgument(1); + listener.onResponse(unparsedModel); + return null; + }).when(modelRegistry).getModelWithSecrets(any(), any()); + + when(serviceRegistry.getService(any())).thenReturn(Optional.of(service)); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionTests.java index 0ed9cbf56b3f..e54175cb2700 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportInferenceActionTests.java @@ -7,66 +7,28 @@ package org.elasticsearch.xpack.inference.action; -import org.elasticsearch.ElasticsearchStatusException; -import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.ActionFilters; -import org.elasticsearch.common.xcontent.ChunkedToXContent; -import org.elasticsearch.inference.InferenceService; import org.elasticsearch.inference.InferenceServiceRegistry; -import org.elasticsearch.inference.InferenceServiceResults; -import org.elasticsearch.inference.Model; -import org.elasticsearch.inference.ModelConfigurations; -import org.elasticsearch.inference.TaskType; -import org.elasticsearch.inference.UnparsedModel; -import org.elasticsearch.rest.RestStatus; -import org.elasticsearch.test.ESTestCase; import org.elasticsearch.transport.TransportService; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager; import org.elasticsearch.xpack.inference.registry.ModelRegistry; import org.elasticsearch.xpack.inference.telemetry.InferenceStats; -import org.junit.Before; -import org.mockito.ArgumentCaptor; -import 
java.util.Map; -import java.util.Optional; -import java.util.Set; -import java.util.concurrent.Flow; -import java.util.function.Consumer; - -import static org.hamcrest.Matchers.is; -import static org.hamcrest.Matchers.isA; -import static org.hamcrest.Matchers.nullValue; -import static org.mockito.ArgumentMatchers.any; -import static org.mockito.ArgumentMatchers.anyBoolean; -import static org.mockito.ArgumentMatchers.anyLong; -import static org.mockito.ArgumentMatchers.assertArg; -import static org.mockito.ArgumentMatchers.same; -import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; -public class TransportInferenceActionTests extends ESTestCase { - private static final String serviceId = "serviceId"; - private static final TaskType taskType = TaskType.COMPLETION; - private static final String inferenceId = "inferenceEntityId"; - private ModelRegistry modelRegistry; - private InferenceServiceRegistry serviceRegistry; - private InferenceStats inferenceStats; - private StreamingTaskManager streamingTaskManager; - private TransportInferenceAction action; +public class TransportInferenceActionTests extends BaseTransportInferenceActionTestCase { - @Before - public void setUp() throws Exception { - super.setUp(); - TransportService transportService = mock(); - ActionFilters actionFilters = mock(); - modelRegistry = mock(); - serviceRegistry = mock(); - inferenceStats = new InferenceStats(mock(), mock()); - streamingTaskManager = mock(); - action = new TransportInferenceAction( + @Override + protected BaseTransportInferenceAction createAction( + TransportService transportService, + ActionFilters actionFilters, + ModelRegistry modelRegistry, + InferenceServiceRegistry serviceRegistry, + InferenceStats inferenceStats, + StreamingTaskManager streamingTaskManager + ) { + return new TransportInferenceAction( transportService, actionFilters, modelRegistry, @@ -76,279 +38,8 @@ public void setUp() throws Exception { ); } - public void testMetricsAfterModelRegistryError() { - var expectedException = new IllegalStateException("hello"); - var expectedError = expectedException.getClass().getSimpleName(); - - doAnswer(ans -> { - ActionListener listener = ans.getArgument(1); - listener.onFailure(expectedException); - return null; - }).when(modelRegistry).getModelWithSecrets(any(), any()); - - var listener = doExecute(taskType); - verify(listener).onFailure(same(expectedException)); - - verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> { - assertThat(attributes.get("service"), nullValue()); - assertThat(attributes.get("task_type"), nullValue()); - assertThat(attributes.get("model_id"), nullValue()); - assertThat(attributes.get("status_code"), nullValue()); - assertThat(attributes.get("error.type"), is(expectedError)); - })); - } - - private ActionListener doExecute(TaskType taskType) { - return doExecute(taskType, false); - } - - private ActionListener doExecute(TaskType taskType, boolean stream) { - InferenceAction.Request request = mock(); - when(request.getInferenceEntityId()).thenReturn(inferenceId); - when(request.getTaskType()).thenReturn(taskType); - when(request.isStreaming()).thenReturn(stream); - ActionListener listener = mock(); - action.doExecute(mock(), request, listener); - return listener; - } - - public void testMetricsAfterMissingService() { - mockModelRegistry(taskType); - - when(serviceRegistry.getService(any())).thenReturn(Optional.empty()); - - 
var listener = doExecute(taskType); - - verify(listener).onFailure(assertArg(e -> { - assertThat(e, isA(ElasticsearchStatusException.class)); - assertThat(e.getMessage(), is("Unknown service [" + serviceId + "] for model [" + inferenceId + "]. ")); - assertThat(((ElasticsearchStatusException) e).status(), is(RestStatus.BAD_REQUEST)); - })); - verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> { - assertThat(attributes.get("service"), is(serviceId)); - assertThat(attributes.get("task_type"), is(taskType.toString())); - assertThat(attributes.get("model_id"), nullValue()); - assertThat(attributes.get("status_code"), is(RestStatus.BAD_REQUEST.getStatus())); - assertThat(attributes.get("error.type"), is(String.valueOf(RestStatus.BAD_REQUEST.getStatus()))); - })); - } - - private void mockModelRegistry(TaskType expectedTaskType) { - var unparsedModel = new UnparsedModel(inferenceId, expectedTaskType, serviceId, Map.of(), Map.of()); - doAnswer(ans -> { - ActionListener listener = ans.getArgument(1); - listener.onResponse(unparsedModel); - return null; - }).when(modelRegistry).getModelWithSecrets(any(), any()); - } - - public void testMetricsAfterUnknownTaskType() { - var modelTaskType = TaskType.RERANK; - var requestTaskType = TaskType.SPARSE_EMBEDDING; - mockModelRegistry(modelTaskType); - when(serviceRegistry.getService(any())).thenReturn(Optional.of(mock())); - - var listener = doExecute(requestTaskType); - - verify(listener).onFailure(assertArg(e -> { - assertThat(e, isA(ElasticsearchStatusException.class)); - assertThat( - e.getMessage(), - is( - "Incompatible task_type, the requested type [" - + requestTaskType - + "] does not match the model type [" - + modelTaskType - + "]" - ) - ); - assertThat(((ElasticsearchStatusException) e).status(), is(RestStatus.BAD_REQUEST)); - })); - verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> { - assertThat(attributes.get("service"), is(serviceId)); - assertThat(attributes.get("task_type"), is(modelTaskType.toString())); - assertThat(attributes.get("model_id"), nullValue()); - assertThat(attributes.get("status_code"), is(RestStatus.BAD_REQUEST.getStatus())); - assertThat(attributes.get("error.type"), is(String.valueOf(RestStatus.BAD_REQUEST.getStatus()))); - })); - } - - public void testMetricsAfterInferError() { - var expectedException = new IllegalStateException("hello"); - var expectedError = expectedException.getClass().getSimpleName(); - mockService(listener -> listener.onFailure(expectedException)); - - var listener = doExecute(taskType); - - verify(listener).onFailure(same(expectedException)); - verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> { - assertThat(attributes.get("service"), is(serviceId)); - assertThat(attributes.get("task_type"), is(taskType.toString())); - assertThat(attributes.get("model_id"), nullValue()); - assertThat(attributes.get("status_code"), nullValue()); - assertThat(attributes.get("error.type"), is(expectedError)); - })); - } - - public void testMetricsAfterStreamUnsupported() { - var expectedStatus = RestStatus.METHOD_NOT_ALLOWED; - var expectedError = String.valueOf(expectedStatus.getStatus()); - mockService(l -> {}); - - var listener = doExecute(taskType, true); - - verify(listener).onFailure(assertArg(e -> { - assertThat(e, isA(ElasticsearchStatusException.class)); - var ese = (ElasticsearchStatusException) e; - assertThat(ese.getMessage(), is("Streaming is not allowed for service [" + serviceId + "].")); - 
assertThat(ese.status(), is(expectedStatus)); - })); - verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> { - assertThat(attributes.get("service"), is(serviceId)); - assertThat(attributes.get("task_type"), is(taskType.toString())); - assertThat(attributes.get("model_id"), nullValue()); - assertThat(attributes.get("status_code"), is(expectedStatus.getStatus())); - assertThat(attributes.get("error.type"), is(expectedError)); - })); - } - - public void testMetricsAfterInferSuccess() { - mockService(listener -> listener.onResponse(mock())); - - var listener = doExecute(taskType); - - verify(listener).onResponse(any()); - verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> { - assertThat(attributes.get("service"), is(serviceId)); - assertThat(attributes.get("task_type"), is(taskType.toString())); - assertThat(attributes.get("model_id"), nullValue()); - assertThat(attributes.get("status_code"), is(200)); - assertThat(attributes.get("error.type"), nullValue()); - })); - } - - public void testMetricsAfterStreamInferSuccess() { - mockStreamResponse(Flow.Subscriber::onComplete); - verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> { - assertThat(attributes.get("service"), is(serviceId)); - assertThat(attributes.get("task_type"), is(taskType.toString())); - assertThat(attributes.get("model_id"), nullValue()); - assertThat(attributes.get("status_code"), is(200)); - assertThat(attributes.get("error.type"), nullValue()); - })); - } - - public void testMetricsAfterStreamInferFailure() { - var expectedException = new IllegalStateException("hello"); - var expectedError = expectedException.getClass().getSimpleName(); - mockStreamResponse(subscriber -> { - subscriber.subscribe(mock()); - subscriber.onError(expectedException); - }); - verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> { - assertThat(attributes.get("service"), is(serviceId)); - assertThat(attributes.get("task_type"), is(taskType.toString())); - assertThat(attributes.get("model_id"), nullValue()); - assertThat(attributes.get("status_code"), nullValue()); - assertThat(attributes.get("error.type"), is(expectedError)); - })); - } - - public void testMetricsAfterStreamCancel() { - var response = mockStreamResponse(s -> s.onSubscribe(mock())); - response.subscribe(new Flow.Subscriber<>() { - @Override - public void onSubscribe(Flow.Subscription subscription) { - subscription.cancel(); - } - - @Override - public void onNext(ChunkedToXContent item) { - - } - - @Override - public void onError(Throwable throwable) { - - } - - @Override - public void onComplete() { - - } - }); - - verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> { - assertThat(attributes.get("service"), is(serviceId)); - assertThat(attributes.get("task_type"), is(taskType.toString())); - assertThat(attributes.get("model_id"), nullValue()); - assertThat(attributes.get("status_code"), is(200)); - assertThat(attributes.get("error.type"), nullValue()); - })); - } - - private Flow.Publisher mockStreamResponse(Consumer> action) { - mockService(true, Set.of(), listener -> { - Flow.Processor taskProcessor = mock(); - doAnswer(innerAns -> { - action.accept(innerAns.getArgument(0)); - return null; - }).when(taskProcessor).subscribe(any()); - when(streamingTaskManager.create(any(), any())).thenReturn(taskProcessor); - var inferenceServiceResults = mock(InferenceServiceResults.class); - 
when(inferenceServiceResults.publisher()).thenReturn(mock()); - listener.onResponse(inferenceServiceResults); - }); - - var listener = doExecute(taskType, true); - var captor = ArgumentCaptor.forClass(InferenceAction.Response.class); - verify(listener).onResponse(captor.capture()); - assertTrue(captor.getValue().isStreaming()); - assertNotNull(captor.getValue().publisher()); - return captor.getValue().publisher(); - } - - private void mockService(Consumer> listenerAction) { - mockService(false, Set.of(), listenerAction); - } - - private void mockService( - boolean stream, - Set supportedStreamingTasks, - Consumer> listenerAction - ) { - InferenceService service = mock(); - Model model = mockModel(); - when(service.parsePersistedConfigWithSecrets(any(), any(), any(), any())).thenReturn(model); - when(service.name()).thenReturn(serviceId); - - when(service.canStream(any())).thenReturn(stream); - when(service.supportedStreamingTasks()).thenReturn(supportedStreamingTasks); - doAnswer(ans -> { - listenerAction.accept(ans.getArgument(7)); - return null; - }).when(service).infer(any(), any(), any(), anyBoolean(), any(), any(), any(), any()); - mockModelAndServiceRegistry(service); - } - - private Model mockModel() { - Model model = mock(); - ModelConfigurations modelConfigurations = mock(); - when(modelConfigurations.getService()).thenReturn(serviceId); - when(model.getConfigurations()).thenReturn(modelConfigurations); - when(model.getTaskType()).thenReturn(taskType); - when(model.getServiceSettings()).thenReturn(mock()); - return model; - } - - private void mockModelAndServiceRegistry(InferenceService service) { - var unparsedModel = new UnparsedModel(inferenceId, taskType, serviceId, Map.of(), Map.of()); - doAnswer(ans -> { - ActionListener listener = ans.getArgument(1); - listener.onResponse(unparsedModel); - return null; - }).when(modelRegistry).getModelWithSecrets(any(), any()); - - when(serviceRegistry.getService(any())).thenReturn(Optional.of(service)); + @Override + protected InferenceAction.Request createRequest() { + return mock(); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionActionTests.java new file mode 100644 index 000000000000..4c943599ce52 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/TransportUnifiedCompletionActionTests.java @@ -0,0 +1,124 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.action; + +import org.elasticsearch.ElasticsearchStatusException; +import org.elasticsearch.action.support.ActionFilters; +import org.elasticsearch.inference.InferenceServiceRegistry; +import org.elasticsearch.inference.TaskType; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.transport.TransportService; +import org.elasticsearch.xpack.core.inference.action.UnifiedCompletionAction; +import org.elasticsearch.xpack.inference.action.task.StreamingTaskManager; +import org.elasticsearch.xpack.inference.registry.ModelRegistry; +import org.elasticsearch.xpack.inference.telemetry.InferenceStats; + +import java.util.Optional; + +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.isA; +import static org.hamcrest.Matchers.nullValue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.ArgumentMatchers.assertArg; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class TransportUnifiedCompletionActionTests extends BaseTransportInferenceActionTestCase { + + @Override + protected BaseTransportInferenceAction createAction( + TransportService transportService, + ActionFilters actionFilters, + ModelRegistry modelRegistry, + InferenceServiceRegistry serviceRegistry, + InferenceStats inferenceStats, + StreamingTaskManager streamingTaskManager + ) { + return new TransportUnifiedCompletionInferenceAction( + transportService, + actionFilters, + modelRegistry, + serviceRegistry, + inferenceStats, + streamingTaskManager + ); + } + + @Override + protected UnifiedCompletionAction.Request createRequest() { + return mock(); + } + + public void testThrows_IncompatibleTaskTypeException_WhenUsingATextEmbeddingInferenceEndpoint() { + var modelTaskType = TaskType.TEXT_EMBEDDING; + var requestTaskType = TaskType.TEXT_EMBEDDING; + mockModelRegistry(modelTaskType); + when(serviceRegistry.getService(any())).thenReturn(Optional.of(mock())); + + var listener = doExecute(requestTaskType); + + verify(listener).onFailure(assertArg(e -> { + assertThat(e, isA(ElasticsearchStatusException.class)); + assertThat( + e.getMessage(), + is("Incompatible task_type for unified API, the requested type [" + requestTaskType + "] must be one of [completion]") + ); + assertThat(((ElasticsearchStatusException) e).status(), is(RestStatus.BAD_REQUEST)); + })); + verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> { + assertThat(attributes.get("service"), is(serviceId)); + assertThat(attributes.get("task_type"), is(modelTaskType.toString())); + assertThat(attributes.get("model_id"), nullValue()); + assertThat(attributes.get("status_code"), is(RestStatus.BAD_REQUEST.getStatus())); + assertThat(attributes.get("error.type"), is(String.valueOf(RestStatus.BAD_REQUEST.getStatus()))); + })); + } + + public void testThrows_IncompatibleTaskTypeException_WhenRequestIsTextEmbedding_ModelIsAny() { + var modelTaskType = TaskType.ANY; + var requestTaskType = TaskType.TEXT_EMBEDDING; + mockModelRegistry(modelTaskType); + when(serviceRegistry.getService(any())).thenReturn(Optional.of(mock())); + + var listener = doExecute(requestTaskType); + + verify(listener).onFailure(assertArg(e -> { + assertThat(e, isA(ElasticsearchStatusException.class)); + assertThat( + e.getMessage(), + is("Incompatible task_type for unified API, the requested type [" + requestTaskType + "] must be one 
of [completion]") + ); + assertThat(((ElasticsearchStatusException) e).status(), is(RestStatus.BAD_REQUEST)); + })); + verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> { + assertThat(attributes.get("service"), is(serviceId)); + assertThat(attributes.get("task_type"), is(modelTaskType.toString())); + assertThat(attributes.get("model_id"), nullValue()); + assertThat(attributes.get("status_code"), is(RestStatus.BAD_REQUEST.getStatus())); + assertThat(attributes.get("error.type"), is(String.valueOf(RestStatus.BAD_REQUEST.getStatus()))); + })); + } + + public void testMetricsAfterUnifiedInferSuccess_WithRequestTaskTypeAny() { + mockModelRegistry(TaskType.COMPLETION); + mockService(listener -> listener.onResponse(mock())); + + var listener = doExecute(TaskType.ANY); + + verify(listener).onResponse(any()); + verify(inferenceStats.inferenceDuration()).record(anyLong(), assertArg(attributes -> { + assertThat(attributes.get("service"), is(serviceId)); + assertThat(attributes.get("task_type"), is(taskType.toString())); + assertThat(attributes.get("model_id"), nullValue()); + assertThat(attributes.get("status_code"), is(200)); + assertThat(attributes.get("error.type"), nullValue()); + })); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/SingleInputSenderExecutableActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/SingleInputSenderExecutableActionTests.java index d4ab9b1f1e19..9e7c58b0ca79 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/SingleInputSenderExecutableActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/SingleInputSenderExecutableActionTests.java @@ -61,25 +61,11 @@ public void testOneInputIsValid() { assertTrue("Test failed to call listener.", testRan.get()); } - public void testInvalidInputType() { - var badInput = mock(InferenceInputs.class); - var actualException = new AtomicReference(); - - executableAction.execute( - badInput, - mock(TimeValue.class), - ActionListener.wrap(shouldNotSucceed -> fail("Test failed."), actualException::set) - ); - - assertThat(actualException.get(), notNullValue()); - assertThat(actualException.get().getMessage(), is("Invalid inference input type")); - assertThat(actualException.get(), instanceOf(ElasticsearchStatusException.class)); - assertThat(((ElasticsearchStatusException) actualException.get()).status(), is(RestStatus.INTERNAL_SERVER_ERROR)); - } - public void testMoreThanOneInput() { var badInput = mock(DocumentsOnlyInput.class); - when(badInput.getInputs()).thenReturn(List.of("one", "two")); + var input = List.of("one", "two"); + when(badInput.getInputs()).thenReturn(input); + when(badInput.inputSize()).thenReturn(input.size()); var actualException = new AtomicReference(); executableAction.execute( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/amazonbedrock/AmazonBedrockActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/amazonbedrock/AmazonBedrockActionCreatorTests.java index 87d3a82b4aae..e7543aa6ba9e 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/amazonbedrock/AmazonBedrockActionCreatorTests.java +++ 
b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/amazonbedrock/AmazonBedrockActionCreatorTests.java @@ -17,6 +17,7 @@ import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults; import org.elasticsearch.xpack.inference.external.amazonbedrock.AmazonBedrockMockRequestSender; +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.services.ServiceComponentsTests; import org.elasticsearch.xpack.inference.services.amazonbedrock.AmazonBedrockProvider; @@ -130,7 +131,7 @@ public void testCompletionRequestAction() throws IOException { ); var action = creator.create(model, Map.of()); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); assertThat(result.asMap(), is(buildExpectationCompletion(List.of("test input string")))); @@ -163,7 +164,7 @@ public void testChatCompletionRequestAction_HandlesException() throws IOExceptio ); var action = creator.create(model, Map.of()); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); assertThat(sender.sendCount(), is(1)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/anthropic/AnthropicActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/anthropic/AnthropicActionCreatorTests.java index a3114300c5dd..f0de37ceaaf9 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/anthropic/AnthropicActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/anthropic/AnthropicActionCreatorTests.java @@ -20,7 +20,7 @@ import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; -import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; import org.elasticsearch.xpack.inference.external.request.anthropic.AnthropicRequestUtils; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; @@ -49,6 +49,7 @@ import static org.mockito.Mockito.mock; public class AnthropicActionCreatorTests extends ESTestCase { + private static final TimeValue TIMEOUT = new TimeValue(30, TimeUnit.SECONDS); private final MockWebServer webServer = new MockWebServer(); private ThreadPool threadPool; @@ -103,7 +104,7 @@ public void testCreate_ChatCompletionModel() throws IOException { var action = actionCreator.create(model, overriddenTaskSettings); PlainActionFuture listener = new 
PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -168,7 +169,7 @@ public void testCreate_ChatCompletionModel_FailsFromInvalidResponseFormat() thro var action = actionCreator.create(model, overriddenTaskSettings); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); assertThat( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/anthropic/AnthropicChatCompletionActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/anthropic/AnthropicChatCompletionActionTests.java index fca2e316af17..2065a726b758 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/anthropic/AnthropicChatCompletionActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/anthropic/AnthropicChatCompletionActionTests.java @@ -27,7 +27,7 @@ import org.elasticsearch.xpack.inference.external.action.SingleInputSenderExecutableAction; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.AnthropicCompletionRequestManager; -import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; import org.elasticsearch.xpack.inference.external.http.sender.Sender; @@ -113,7 +113,7 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException { var action = createAction(getUrl(webServer), "secret", "model", 1, sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -149,7 +149,7 @@ public void testExecute_ThrowsElasticsearchException() { var action = createAction(getUrl(webServer), "secret", "model", 1, sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -170,7 +170,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled var action = createAction(getUrl(webServer), "secret", "model", 1, sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new 
ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -187,7 +187,7 @@ public void testExecute_ThrowsException() { var action = createAction(getUrl(webServer), "secret", "model", 1, sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -229,7 +229,7 @@ public void testExecute_ThrowsException_WhenInputIsGreaterThanOne() throws IOExc var action = createAction(getUrl(webServer), "secret", "model", 1, sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc", "def")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc", "def")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureaistudio/AzureAiStudioActionAndCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureaistudio/AzureAiStudioActionAndCreatorTests.java index 8792234102a9..210fab457de1 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureaistudio/AzureAiStudioActionAndCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureaistudio/AzureAiStudioActionAndCreatorTests.java @@ -21,6 +21,7 @@ import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.inference.common.TruncatorTests; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; @@ -160,7 +161,7 @@ public void testChatCompletionRequestAction() throws IOException { var action = creator.create(model, Map.of()); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiActionCreatorTests.java index 45a2fb0954c7..7e1e3e55caed 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiActionCreatorTests.java @@ -24,6 +24,7 @@ import 
org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.results.ChatCompletionResults; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; import org.elasticsearch.xpack.inference.external.request.azureopenai.AzureOpenAiUtils; @@ -475,7 +476,7 @@ public void testInfer_AzureOpenAiCompletion_WithOverriddenUser() throws IOExcept var action = actionCreator.create(model, taskSettingsWithUserOverride); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of(completionInput)), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of(completionInput)), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -531,7 +532,7 @@ public void testInfer_AzureOpenAiCompletionModel_WithoutUser() throws IOExceptio var action = actionCreator.create(model, requestTaskSettingsWithoutUser); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of(completionInput)), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of(completionInput)), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -589,7 +590,7 @@ public void testInfer_AzureOpenAiCompletionModel_FailsFromInvalidResponseFormat( var action = actionCreator.create(model, requestTaskSettingsWithoutUser); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of(completionInput)), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of(completionInput)), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); assertThat( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiCompletionActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiCompletionActionTests.java index 4c7683c88281..dca12dfda9c9 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiCompletionActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/azureopenai/AzureOpenAiCompletionActionTests.java @@ -26,7 +26,7 @@ import org.elasticsearch.xpack.inference.external.action.SingleInputSenderExecutableAction; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.AzureOpenAiCompletionRequestManager; -import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; import org.elasticsearch.xpack.inference.external.http.sender.Sender; import org.elasticsearch.xpack.inference.external.request.azureopenai.AzureOpenAiUtils; @@ -111,7 +111,7 @@ public void testExecute_ReturnsSuccessfulResponse() 
throws IOException { var action = createAction("resource", "deployment", "apiversion", user, apiKey, sender, "id"); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of(completionInput)), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of(completionInput)), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -142,7 +142,7 @@ public void testExecute_ThrowsElasticsearchException() { var action = createAction("resource", "deployment", "apiVersion", "user", "apikey", sender, "id"); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -163,7 +163,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled var action = createAction("resource", "deployment", "apiVersion", "user", "apikey", sender, "id"); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -177,7 +177,7 @@ public void testExecute_ThrowsException() { var action = createAction("resource", "deployment", "apiVersion", "user", "apikey", sender, "id"); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreatorTests.java index 9ec34e7d8e5c..3a512de25a39 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereActionCreatorTests.java @@ -20,6 +20,7 @@ import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; @@ -197,7 +198,7 @@ public void testCreate_CohereCompletionModel_WithModelSpecified() throws IOExcep var action = actionCreator.create(model, Map.of()); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new 
ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -257,7 +258,7 @@ public void testCreate_CohereCompletionModel_WithoutModelSpecified() throws IOEx var action = actionCreator.create(model, Map.of()); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereCompletionActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereCompletionActionTests.java index ba839e0d7c5e..c5871adb3486 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereCompletionActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/cohere/CohereCompletionActionTests.java @@ -26,8 +26,8 @@ import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import org.elasticsearch.xpack.inference.external.action.SingleInputSenderExecutableAction; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; import org.elasticsearch.xpack.inference.external.http.sender.CohereCompletionRequestManager; -import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; import org.elasticsearch.xpack.inference.external.http.sender.Sender; import org.elasticsearch.xpack.inference.external.request.cohere.CohereUtils; @@ -120,7 +120,7 @@ public void testExecute_ReturnsSuccessfulResponse_WithModelSpecified() throws IO var action = createAction(getUrl(webServer), "secret", "model", sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -181,7 +181,7 @@ public void testExecute_ReturnsSuccessfulResponse_WithoutModelSpecified() throws var action = createAction(getUrl(webServer), "secret", null, sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -214,7 +214,7 @@ public void testExecute_ThrowsElasticsearchException() { var action = createAction(getUrl(webServer), "secret", "model", sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -235,7 +235,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled var action = 
createAction(getUrl(webServer), "secret", "model", sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -256,7 +256,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled var action = createAction(null, "secret", "model", sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -270,7 +270,7 @@ public void testExecute_ThrowsException() { var action = createAction(getUrl(webServer), "secret", "model", sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -284,7 +284,7 @@ public void testExecute_ThrowsExceptionWithNullUrl() { var action = createAction(null, "secret", "model", sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -334,7 +334,7 @@ public void testExecute_ThrowsException_WhenInputIsGreaterThanOne() throws IOExc var action = createAction(getUrl(webServer), "secret", "model", sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc", "def")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc", "def")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/googleaistudio/GoogleAiStudioCompletionActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/googleaistudio/GoogleAiStudioCompletionActionTests.java index 72b5ffa45a0d..ff17bbf66e02 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/googleaistudio/GoogleAiStudioCompletionActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/googleaistudio/GoogleAiStudioCompletionActionTests.java @@ -25,7 +25,7 @@ import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import org.elasticsearch.xpack.inference.external.action.SingleInputSenderExecutableAction; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; -import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; 
+import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; import org.elasticsearch.xpack.inference.external.http.sender.GoogleAiStudioCompletionRequestManager; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; import org.elasticsearch.xpack.inference.external.http.sender.Sender; @@ -128,7 +128,7 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException { var action = createAction(getUrl(webServer), "secret", "model", sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("input")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("input")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -159,7 +159,7 @@ public void testExecute_ThrowsElasticsearchException() { var action = createAction(getUrl(webServer), "secret", "model", sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -180,7 +180,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled var action = createAction(getUrl(webServer), "secret", "model", sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -197,7 +197,7 @@ public void testExecute_ThrowsException() { var action = createAction(getUrl(webServer), "secret", "model", sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -260,7 +260,7 @@ public void testExecute_ThrowsException_WhenInputIsGreaterThanOne() throws IOExc var action = createAction(getUrl(webServer), "secret", "model", sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc", "def")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc", "def")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiActionCreatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiActionCreatorTests.java index b6d7eb673b7f..fe076eb721ea 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiActionCreatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiActionCreatorTests.java 
@@ -21,6 +21,7 @@ import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; @@ -330,7 +331,7 @@ public void testCreate_OpenAiChatCompletionModel() throws IOException { var action = actionCreator.create(model, overriddenTaskSettings); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -345,11 +346,12 @@ public void testCreate_OpenAiChatCompletionModel() throws IOException { assertThat(request.getHeader(ORGANIZATION_HEADER), equalTo("org")); var requestMap = entityAsMap(webServer.requests().get(0).getBody()); - assertThat(requestMap.size(), is(4)); + assertThat(requestMap.size(), is(5)); assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "abc")))); assertThat(requestMap.get("model"), is("model")); assertThat(requestMap.get("user"), is("overridden_user")); assertThat(requestMap.get("n"), is(1)); + assertThat(requestMap.get("stream"), is(false)); } } @@ -393,7 +395,7 @@ public void testCreate_OpenAiChatCompletionModel_WithoutUser() throws IOExceptio var action = actionCreator.create(model, overriddenTaskSettings); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -408,10 +410,11 @@ public void testCreate_OpenAiChatCompletionModel_WithoutUser() throws IOExceptio assertThat(request.getHeader(ORGANIZATION_HEADER), equalTo("org")); var requestMap = entityAsMap(webServer.requests().get(0).getBody()); - assertThat(requestMap.size(), is(3)); + assertThat(requestMap.size(), is(4)); assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "abc")))); assertThat(requestMap.get("model"), is("model")); assertThat(requestMap.get("n"), is(1)); + assertThat(requestMap.get("stream"), is(false)); } } @@ -455,7 +458,7 @@ public void testCreate_OpenAiChatCompletionModel_WithoutOrganization() throws IO var action = actionCreator.create(model, overriddenTaskSettings); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -470,11 +473,12 @@ public void testCreate_OpenAiChatCompletionModel_WithoutOrganization() throws IO assertNull(request.getHeader(ORGANIZATION_HEADER)); var requestMap = entityAsMap(webServer.requests().get(0).getBody()); - assertThat(requestMap.size(), is(4)); + assertThat(requestMap.size(), is(5)); assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "abc")))); 
assertThat(requestMap.get("model"), is("model")); assertThat(requestMap.get("user"), is("overridden_user")); assertThat(requestMap.get("n"), is(1)); + assertThat(requestMap.get("stream"), is(false)); } } @@ -523,7 +527,7 @@ public void testCreate_OpenAiChatCompletionModel_FailsFromInvalidResponseFormat( var action = actionCreator.create(model, overriddenTaskSettings); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); assertThat( @@ -542,11 +546,12 @@ public void testCreate_OpenAiChatCompletionModel_FailsFromInvalidResponseFormat( assertNull(webServer.requests().get(0).getHeader(ORGANIZATION_HEADER)); var requestMap = entityAsMap(webServer.requests().get(0).getBody()); - assertThat(requestMap.size(), is(4)); + assertThat(requestMap.size(), is(5)); assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "abc")))); assertThat(requestMap.get("model"), is("model")); assertThat(requestMap.get("user"), is("overridden_user")); assertThat(requestMap.get("n"), is(1)); + assertThat(requestMap.get("stream"), is(false)); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiChatCompletionActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiChatCompletionActionTests.java index d84b2b5bb324..ba74d2ab42c2 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiChatCompletionActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/openai/OpenAiChatCompletionActionTests.java @@ -27,7 +27,7 @@ import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import org.elasticsearch.xpack.inference.external.action.SingleInputSenderExecutableAction; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; -import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; import org.elasticsearch.xpack.inference.external.http.sender.OpenAiCompletionRequestManager; @@ -119,7 +119,7 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException { var action = createAction(getUrl(webServer), "org", "secret", "model", "user", sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var result = listener.actionGet(TIMEOUT); @@ -134,11 +134,12 @@ public void testExecute_ReturnsSuccessfulResponse() throws IOException { assertThat(request.getHeader(ORGANIZATION_HEADER), equalTo("org")); var requestMap = entityAsMap(request.getBody()); - assertThat(requestMap.size(), is(4)); + assertThat(requestMap.size(), is(5)); assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "abc")))); 
assertThat(requestMap.get("model"), is("model")); assertThat(requestMap.get("user"), is("user")); assertThat(requestMap.get("n"), is(1)); + assertThat(requestMap.get("stream"), is(false)); } } @@ -159,7 +160,7 @@ public void testExecute_ThrowsElasticsearchException() { var action = createAction(getUrl(webServer), "org", "secret", "model", "user", sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -180,7 +181,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled var action = createAction(getUrl(webServer), "org", "secret", "model", "user", sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -201,7 +202,7 @@ public void testExecute_ThrowsElasticsearchException_WhenSenderOnFailureIsCalled var action = createAction(null, "org", "secret", "model", "user", sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -215,7 +216,7 @@ public void testExecute_ThrowsException() { var action = createAction(getUrl(webServer), "org", "secret", "model", "user", sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -229,7 +230,7 @@ public void testExecute_ThrowsExceptionWithNullUrl() { var action = createAction(null, "org", "secret", "model", "user", sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); @@ -273,7 +274,7 @@ public void testExecute_ThrowsException_WhenInputIsGreaterThanOne() throws IOExc var action = createAction(getUrl(webServer), "org", "secret", "model", "user", sender); PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of("abc", "def")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); + action.execute(new ChatCompletionInput(List.of("abc", "def")), InferenceAction.Request.DEFAULT_TIMEOUT, listener); var thrownException = expectThrows(ElasticsearchStatusException.class, () -> listener.actionGet(TIMEOUT)); diff --git 
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockMockRequestSender.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockMockRequestSender.java index e68beaf4c1eb..929aefeeef6b 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockMockRequestSender.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockMockRequestSender.java @@ -12,6 +12,7 @@ import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.InferenceServiceResults; +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; import org.elasticsearch.xpack.inference.external.http.sender.RequestManager; @@ -67,8 +68,15 @@ public void send( ActionListener<InferenceServiceResults> listener ) { sendCounter++; - var docsInput = (DocumentsOnlyInput) inferenceInputs; - inputs.add(docsInput.getInputs()); + if (inferenceInputs instanceof DocumentsOnlyInput docsInput) { + inputs.add(docsInput.getInputs()); + } else if (inferenceInputs instanceof ChatCompletionInput chatCompletionInput) { + inputs.add(chatCompletionInput.getInputs()); + } else { + throw new IllegalArgumentException( + "Invalid inference inputs received in mock sender: " + inferenceInputs.getClass().getSimpleName() + ); + } if (results.isEmpty()) { listener.onFailure(new ElasticsearchException("No results found")); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockRequestSenderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockRequestSenderTests.java index 7fa8a09d5bf1..a8f37aedcece 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockRequestSenderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/amazonbedrock/AmazonBedrockRequestSenderTests.java @@ -15,6 +15,7 @@ import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xpack.inference.external.http.sender.AmazonBedrockChatCompletionRequestManager; import org.elasticsearch.xpack.inference.external.http.sender.AmazonBedrockEmbeddingsRequestManager; +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; import org.elasticsearch.xpack.inference.external.http.sender.Sender; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; @@ -107,7 +108,7 @@ public void testCreateSender_SendsCompletionRequestAndReceivesResponse() throws PlainActionFuture<InferenceServiceResults> listener = new PlainActionFuture<>(); var requestManager = new AmazonBedrockChatCompletionRequestManager(model, threadPool, new TimeValue(30, TimeUnit.SECONDS)); - sender.send(requestManager, new DocumentsOnlyInput(List.of("abc")), null, listener); + sender.send(requestManager, new ChatCompletionInput(List.of("abc")), null, listener); var result = listener.actionGet(TIMEOUT); assertThat(result.asMap(), is(buildExpectationCompletion(List.of("test response text")))); diff --git
a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceInputsTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceInputsTests.java new file mode 100644 index 000000000000..f0da67a98237 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/InferenceInputsTests.java @@ -0,0 +1,40 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.http.sender; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.test.ESTestCase; +import org.hamcrest.Matchers; + +import java.util.List; + +public class InferenceInputsTests extends ESTestCase { + public void testCastToSucceeds() { + InferenceInputs inputs = new DocumentsOnlyInput(List.of(), false); + assertThat(inputs.castTo(DocumentsOnlyInput.class), Matchers.instanceOf(DocumentsOnlyInput.class)); + + var emptyRequest = new UnifiedCompletionRequest(List.of(), null, null, null, null, null, null, null); + assertThat(new UnifiedChatInput(emptyRequest, false).castTo(UnifiedChatInput.class), Matchers.instanceOf(UnifiedChatInput.class)); + assertThat( + new QueryAndDocsInputs("hello", List.of(), false).castTo(QueryAndDocsInputs.class), + Matchers.instanceOf(QueryAndDocsInputs.class) + ); + } + + public void testCastToFails() { + InferenceInputs inputs = new DocumentsOnlyInput(List.of(), false); + var exception = expectThrows(IllegalArgumentException.class, () -> inputs.castTo(QueryAndDocsInputs.class)); + assertThat( + exception.getMessage(), + Matchers.containsString( + Strings.format("Unable to convert inference inputs type: [%s] to [%s]", DocumentsOnlyInput.class, QueryAndDocsInputs.class) + ) + ); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/UnifiedChatInputTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/UnifiedChatInputTests.java new file mode 100644 index 000000000000..42e1b18168ae --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/http/sender/UnifiedChatInputTests.java @@ -0,0 +1,46 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.external.http.sender; + +import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.test.ESTestCase; +import org.hamcrest.Matchers; + +import java.util.List; + +public class UnifiedChatInputTests extends ESTestCase { + + public void testConvertsStringInputToMessages() { + var a = new UnifiedChatInput(List.of("hello", "awesome"), "a role", true); + + assertThat(a.inputSize(), Matchers.is(2)); + assertThat( + a.getRequest(), + Matchers.is( + UnifiedCompletionRequest.of( + List.of( + new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentString("hello"), + "a role", + null, + null, + null + ), + new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentString("awesome"), + "a role", + null, + null, + null + ) + ) + ) + ) + ); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedStreamingProcessorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedStreamingProcessorTests.java new file mode 100644 index 000000000000..0f127998f9c5 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/openai/OpenAiUnifiedStreamingProcessorTests.java @@ -0,0 +1,383 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.external.openai; + +import org.elasticsearch.common.Strings; +import org.elasticsearch.common.xcontent.LoggingDeprecationHandler; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.XContentFactory; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xcontent.XContentParserConfiguration; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.results.StreamingUnifiedChatCompletionResults; + +import java.io.IOException; +import java.util.List; + +public class OpenAiUnifiedStreamingProcessorTests extends ESTestCase { + + public void testJsonLiteral() { + String json = """ + { + "id": "example_id", + "choices": [ + { + "delta": { + "content": "example_content", + "refusal": null, + "role": "assistant", + "tool_calls": [ + { + "index": 1, + "id": "tool_call_id", + "function": { + "arguments": "example_arguments", + "name": "example_function_name" + }, + "type": "function" + } + ] + }, + "finish_reason": "stop", + "index": 0 + } + ], + "model": "example_model", + "object": "chat.completion.chunk", + "usage": { + "completion_tokens": 50, + "prompt_tokens": 20, + "total_tokens": 70 + } + } + """; + // Parse the JSON + XContentParserConfiguration parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler( + LoggingDeprecationHandler.INSTANCE + ); + try (XContentParser parser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, json)) { + StreamingUnifiedChatCompletionResults.ChatCompletionChunk chunk = OpenAiUnifiedStreamingProcessor.ChatCompletionChunkParser + .parse(parser); + + // Assertions to verify the parsed object + assertEquals("example_id", chunk.getId()); + assertEquals("example_model", chunk.getModel()); + assertEquals("chat.completion.chunk", chunk.getObject()); + assertNotNull(chunk.getUsage()); + assertEquals(50, 
chunk.getUsage().completionTokens()); + assertEquals(20, chunk.getUsage().promptTokens()); + assertEquals(70, chunk.getUsage().totalTokens()); + + List choices = chunk.getChoices(); + assertEquals(1, choices.size()); + StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice choice = choices.get(0); + assertEquals("example_content", choice.delta().getContent()); + assertNull(choice.delta().getRefusal()); + assertEquals("assistant", choice.delta().getRole()); + assertEquals("stop", choice.finishReason()); + assertEquals(0, choice.index()); + + List toolCalls = choice.delta().getToolCalls(); + assertEquals(1, toolCalls.size()); + StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall toolCall = toolCalls.get(0); + assertEquals(1, toolCall.getIndex()); + assertEquals("tool_call_id", toolCall.getId()); + assertEquals("example_function_name", toolCall.getFunction().getName()); + assertEquals("example_arguments", toolCall.getFunction().getArguments()); + assertEquals("function", toolCall.getType()); + } catch (IOException e) { + fail(); + } + } + + public void testJsonLiteralCornerCases() { + String json = """ + { + "id": "example_id", + "choices": [ + { + "delta": { + "content": null, + "refusal": null, + "role": "assistant", + "tool_calls": [] + }, + "finish_reason": null, + "index": 0 + }, + { + "delta": { + "content": "example_content", + "refusal": "example_refusal", + "role": "user", + "tool_calls": [ + { + "index": 1, + "function": { + "name": "example_function_name" + }, + "type": "function" + } + ] + }, + "finish_reason": "stop", + "index": 1 + } + ], + "model": "example_model", + "object": "chat.completion.chunk", + "usage": null + } + """; + // Parse the JSON + XContentParserConfiguration parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler( + LoggingDeprecationHandler.INSTANCE + ); + try (XContentParser parser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, json)) { + StreamingUnifiedChatCompletionResults.ChatCompletionChunk chunk = OpenAiUnifiedStreamingProcessor.ChatCompletionChunkParser + .parse(parser); + + // Assertions to verify the parsed object + assertEquals("example_id", chunk.getId()); + assertEquals("example_model", chunk.getModel()); + assertEquals("chat.completion.chunk", chunk.getObject()); + assertNull(chunk.getUsage()); + + List choices = chunk.getChoices(); + assertEquals(2, choices.size()); + + // First choice assertions + StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice firstChoice = choices.get(0); + assertNull(firstChoice.delta().getContent()); + assertNull(firstChoice.delta().getRefusal()); + assertEquals("assistant", firstChoice.delta().getRole()); + assertTrue(firstChoice.delta().getToolCalls().isEmpty()); + assertNull(firstChoice.finishReason()); + assertEquals(0, firstChoice.index()); + + // Second choice assertions + StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice secondChoice = choices.get(1); + assertEquals("example_content", secondChoice.delta().getContent()); + assertEquals("example_refusal", secondChoice.delta().getRefusal()); + assertEquals("user", secondChoice.delta().getRole()); + assertEquals("stop", secondChoice.finishReason()); + assertEquals(1, secondChoice.index()); + + List toolCalls = secondChoice.delta() + .getToolCalls(); + assertEquals(1, toolCalls.size()); + StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall toolCall = toolCalls.get(0); + assertEquals(1, toolCall.getIndex()); + 
assertNull(toolCall.getId()); + assertEquals("example_function_name", toolCall.getFunction().getName()); + assertNull(toolCall.getFunction().getArguments()); + assertEquals("function", toolCall.getType()); + } catch (IOException e) { + fail(); + } + } + + public void testOpenAiUnifiedStreamingProcessorParsing() throws IOException { + // Generate random values for the JSON fields + int toolCallIndex = randomIntBetween(0, 10); + String toolCallId = randomAlphaOfLength(5); + String toolCallFunctionName = randomAlphaOfLength(8); + String toolCallFunctionArguments = randomAlphaOfLength(10); + String toolCallType = "function"; + String toolCallJson = createToolCallJson(toolCallIndex, toolCallId, toolCallFunctionName, toolCallFunctionArguments, toolCallType); + + String choiceContent = randomAlphaOfLength(10); + String choiceRole = randomFrom("system", "user", "assistant", "tool"); + String choiceFinishReason = randomFrom("stop", "length", "tool_calls", "content_filter", "function_call", null); + int choiceIndex = randomIntBetween(0, 10); + String choiceJson = createChoiceJson(choiceContent, null, choiceRole, toolCallJson, choiceFinishReason, choiceIndex); + + int usageCompletionTokens = randomIntBetween(1, 100); + int usagePromptTokens = randomIntBetween(1, 100); + int usageTotalTokens = randomIntBetween(1, 200); + String usageJson = createUsageJson(usageCompletionTokens, usagePromptTokens, usageTotalTokens); + + String chatCompletionChunkId = randomAlphaOfLength(10); + String chatCompletionChunkModel = randomAlphaOfLength(5); + String chatCompletionChunkJson = createChatCompletionChunkJson( + chatCompletionChunkId, + choiceJson, + chatCompletionChunkModel, + "chat.completion.chunk", + usageJson + ); + + // Parse the JSON + XContentParserConfiguration parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler( + LoggingDeprecationHandler.INSTANCE + ); + try (XContentParser parser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, chatCompletionChunkJson)) { + StreamingUnifiedChatCompletionResults.ChatCompletionChunk chunk = OpenAiUnifiedStreamingProcessor.ChatCompletionChunkParser + .parse(parser); + + // Assertions to verify the parsed object + assertEquals(chatCompletionChunkId, chunk.getId()); + assertEquals(chatCompletionChunkModel, chunk.getModel()); + assertEquals("chat.completion.chunk", chunk.getObject()); + assertNotNull(chunk.getUsage()); + assertEquals(usageCompletionTokens, chunk.getUsage().completionTokens()); + assertEquals(usagePromptTokens, chunk.getUsage().promptTokens()); + assertEquals(usageTotalTokens, chunk.getUsage().totalTokens()); + + List choices = chunk.getChoices(); + assertEquals(1, choices.size()); + StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice choice = choices.get(0); + assertEquals(choiceContent, choice.delta().getContent()); + assertNull(choice.delta().getRefusal()); + assertEquals(choiceRole, choice.delta().getRole()); + assertEquals(choiceFinishReason, choice.finishReason()); + assertEquals(choiceIndex, choice.index()); + + List toolCalls = choice.delta().getToolCalls(); + assertEquals(1, toolCalls.size()); + StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice.Delta.ToolCall toolCall = toolCalls.get(0); + assertEquals(toolCallIndex, toolCall.getIndex()); + assertEquals(toolCallId, toolCall.getId()); + assertEquals(toolCallFunctionName, toolCall.getFunction().getName()); + assertEquals(toolCallFunctionArguments, toolCall.getFunction().getArguments()); + assertEquals(toolCallType, 
toolCall.getType()); + } + } + + public void testOpenAiUnifiedStreamingProcessorParsingWithNullFields() throws IOException { + // JSON with null fields + int choiceIndex = randomIntBetween(0, 10); + String choiceJson = createChoiceJson(null, null, null, "", null, choiceIndex); + + String chatCompletionChunkId = randomAlphaOfLength(10); + String chatCompletionChunkModel = randomAlphaOfLength(5); + String chatCompletionChunkJson = createChatCompletionChunkJson( + chatCompletionChunkId, + choiceJson, + chatCompletionChunkModel, + "chat.completion.chunk", + null + ); + + // Parse the JSON + XContentParserConfiguration parserConfig = XContentParserConfiguration.EMPTY.withDeprecationHandler( + LoggingDeprecationHandler.INSTANCE + ); + try (XContentParser parser = XContentFactory.xContent(XContentType.JSON).createParser(parserConfig, chatCompletionChunkJson)) { + StreamingUnifiedChatCompletionResults.ChatCompletionChunk chunk = OpenAiUnifiedStreamingProcessor.ChatCompletionChunkParser + .parse(parser); + + // Assertions to verify the parsed object + assertEquals(chatCompletionChunkId, chunk.getId()); + assertEquals(chatCompletionChunkModel, chunk.getModel()); + assertEquals("chat.completion.chunk", chunk.getObject()); + assertNull(chunk.getUsage()); + + List choices = chunk.getChoices(); + assertEquals(1, choices.size()); + StreamingUnifiedChatCompletionResults.ChatCompletionChunk.Choice choice = choices.get(0); + assertNull(choice.delta().getContent()); + assertNull(choice.delta().getRefusal()); + assertNull(choice.delta().getRole()); + assertNull(choice.finishReason()); + assertEquals(choiceIndex, choice.index()); + assertTrue(choice.delta().getToolCalls().isEmpty()); + } + } + + private String createToolCallJson(int index, String id, String functionName, String functionArguments, String type) { + return Strings.format(""" + { + "index": %d, + "id": "%s", + "function": { + "name": "%s", + "arguments": "%s" + }, + "type": "%s" + } + """, index, id, functionName, functionArguments, type); + } + + private String createChoiceJson(String content, String refusal, String role, String toolCallsJson, String finishReason, int index) { + if (role == null) { + return Strings.format( + """ + { + "delta": { + "content": %s, + "refusal": %s, + "tool_calls": [%s] + }, + "finish_reason": %s, + "index": %d + } + """, + content != null ? "\"" + content + "\"" : "null", + refusal != null ? "\"" + refusal + "\"" : "null", + toolCallsJson, + finishReason != null ? "\"" + finishReason + "\"" : "null", + index + ); + } else { + return Strings.format( + """ + { + "delta": { + "content": %s, + "refusal": %s, + "role": %s, + "tool_calls": [%s] + }, + "finish_reason": %s, + "index": %d + } + """, + content != null ? "\"" + content + "\"" : "null", + refusal != null ? "\"" + refusal + "\"" : "null", + role != null ? "\"" + role + "\"" : "null", + toolCallsJson, + finishReason != null ? 
"\"" + finishReason + "\"" : "null", + index + ); + } + } + + private String createChatCompletionChunkJson(String id, String choicesJson, String model, String object, String usageJson) { + if (usageJson != null) { + return Strings.format(""" + { + "id": "%s", + "choices": [%s], + "model": "%s", + "object": "%s", + "usage": %s + } + """, id, choicesJson, model, object, usageJson); + } else { + return Strings.format(""" + { + "id": "%s", + "choices": [%s], + "model": "%s", + "object": "%s" + } + """, id, choicesJson, model, object); + } + } + + private String createUsageJson(int completionTokens, int promptTokens, int totalTokens) { + return Strings.format(""" + { + "completion_tokens": %d, + "prompt_tokens": %d, + "total_tokens": %d + } + """, completionTokens, promptTokens, totalTokens); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/completion/GoogleAiStudioCompletionRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/completion/GoogleAiStudioCompletionRequestTests.java index 7ffa8940ad6b..065dfee577a8 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/completion/GoogleAiStudioCompletionRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/googleaistudio/completion/GoogleAiStudioCompletionRequestTests.java @@ -10,7 +10,7 @@ import org.apache.http.client.methods.HttpPost; import org.elasticsearch.common.Strings; import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; +import org.elasticsearch.xpack.inference.external.http.sender.ChatCompletionInput; import org.elasticsearch.xpack.inference.external.request.googleaistudio.GoogleAiStudioCompletionRequest; import org.elasticsearch.xpack.inference.services.googleaistudio.completion.GoogleAiStudioCompletionModelTests; @@ -72,7 +72,7 @@ public void testTruncationInfo_ReturnsNull() { assertNull(request.getTruncationInfo()); } - private static DocumentsOnlyInput listOf(String... input) { - return new DocumentsOnlyInput(List.of(input)); + private static ChatCompletionInput listOf(String... input) { + return new ChatCompletionInput(List.of(input)); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiChatCompletionRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiChatCompletionRequestEntityTests.java deleted file mode 100644 index 9d5492f9e951..000000000000 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiChatCompletionRequestEntityTests.java +++ /dev/null @@ -1,53 +0,0 @@ -/* - * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one - * or more contributor license agreements. Licensed under the Elastic License - * 2.0; you may not use this file except in compliance with the Elastic License - * 2.0. 
- */ - -package org.elasticsearch.xpack.inference.external.request.openai; - -import org.elasticsearch.common.Strings; -import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.xcontent.XContentBuilder; -import org.elasticsearch.xcontent.XContentFactory; -import org.elasticsearch.xcontent.XContentType; - -import java.io.IOException; -import java.util.List; - -import static org.hamcrest.CoreMatchers.is; - -public class OpenAiChatCompletionRequestEntityTests extends ESTestCase { - - public void testXContent_WritesUserWhenDefined() throws IOException { - var entity = new OpenAiChatCompletionRequestEntity(List.of("abc"), "model", "user", false); - - XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); - entity.toXContent(builder, null); - String xContentResult = Strings.toString(builder); - - assertThat(xContentResult, is(""" - {"messages":[{"role":"user","content":"abc"}],"model":"model","n":1,"user":"user"}""")); - - } - - public void testXContent_DoesNotWriteUserWhenItIsNull() throws IOException { - var entity = new OpenAiChatCompletionRequestEntity(List.of("abc"), "model", null, false); - - XContentBuilder builder = XContentFactory.contentBuilder(XContentType.JSON); - entity.toXContent(builder, null); - String xContentResult = Strings.toString(builder); - - assertThat(xContentResult, is(""" - {"messages":[{"role":"user","content":"abc"}],"model":"model","n":1}""")); - } - - public void testXContent_ThrowsIfModelIsNull() { - assertThrows(NullPointerException.class, () -> new OpenAiChatCompletionRequestEntity(List.of("abc"), null, "user", false)); - } - - public void testXContent_ThrowsIfMessagesAreNull() { - assertThrows(NullPointerException.class, () -> new OpenAiChatCompletionRequestEntity(null, "model", "user", false)); - } -} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntityTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntityTests.java new file mode 100644 index 000000000000..f945c154ea23 --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestEntityTests.java @@ -0,0 +1,856 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.inference.external.request.openai; + +import org.elasticsearch.common.Randomness; +import org.elasticsearch.common.Strings; +import org.elasticsearch.inference.UnifiedCompletionRequest; +import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.xcontent.ToXContent; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.json.JsonXContent; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; +import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModel; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.Locale; +import java.util.Map; +import java.util.Random; + +import static org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModelTests.createChatCompletionModel; +import static org.hamcrest.Matchers.equalTo; + +public class OpenAiUnifiedChatCompletionRequestEntityTests extends ESTestCase { + + // 1. Basic Serialization + // Test with minimal required fields to ensure basic serialization works. + public void testBasicSerialization() throws IOException { + UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentString("Hello, world!"), + OpenAiUnifiedChatCompletionRequestEntity.USER_FIELD, + null, + null, + null + ); + var messageList = new ArrayList(); + messageList.add(message); + UnifiedCompletionRequest unifiedRequest = new UnifiedCompletionRequest(messageList, null, null, null, null, null, null, null); + + UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); + OpenAiChatCompletionModel model = createChatCompletionModel("test-url", "organizationId", "api-key", "test-endpoint", null); + + OpenAiUnifiedChatCompletionRequestEntity entity = new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInput, model); + + XContentBuilder builder = JsonXContent.contentBuilder(); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + + String jsonString = Strings.toString(builder); + String expectedJson = """ + { + "messages": [ + { + "content": "Hello, world!", + "role": "user" + } + ], + "model": "test-endpoint", + "n": 1, + "stream": true, + "stream_options": { + "include_usage": true + } + } + """; + assertJsonEquals(jsonString, expectedJson); + } + + // 2. Serialization with All Fields + // Test with all possible fields populated to ensure complete serialization. 
+ public void testSerializationWithAllFields() throws IOException { + // Create a message with all fields populated + UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentString("Hello, world!"), + OpenAiUnifiedChatCompletionRequestEntity.USER_FIELD, + "name", + "tool_call_id", + Collections.singletonList( + new UnifiedCompletionRequest.ToolCall( + "id", + new UnifiedCompletionRequest.ToolCall.FunctionField("arguments", "function_name"), + "type" + ) + ) + ); + + // Create a tool with all fields populated + UnifiedCompletionRequest.Tool tool = new UnifiedCompletionRequest.Tool( + "type", + new UnifiedCompletionRequest.Tool.FunctionField( + "Fetches the weather in the given location", + "get_weather", + createParameters(), + true + ) + ); + var messageList = new ArrayList(); + messageList.add(message); + // Create the unified request with all fields populated + UnifiedCompletionRequest unifiedRequest = new UnifiedCompletionRequest( + messageList, + "model", + 100L, // maxCompletionTokens + Collections.singletonList("stop"), + 0.9f, // temperature + new UnifiedCompletionRequest.ToolChoiceString("tool_choice"), + Collections.singletonList(tool), + 0.8f // topP + ); + + // Create the unified chat input + UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); + + OpenAiChatCompletionModel model = createChatCompletionModel("test-endpoint", "organizationId", "api-key", "model-name", null); + + // Create the entity + OpenAiUnifiedChatCompletionRequestEntity entity = new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInput, model); + + // Serialize to XContent + XContentBuilder builder = JsonXContent.contentBuilder(); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + + // Convert to string and verify + String jsonString = Strings.toString(builder); + String expectedJson = """ + { + "messages": [ + { + "content": "Hello, world!", + "role": "user", + "name": "name", + "tool_call_id": "tool_call_id", + "tool_calls": [ + { + "id": "id", + "function": { + "arguments": "arguments", + "name": "function_name" + }, + "type": "type" + } + ] + } + ], + "model": "model-name", + "max_completion_tokens": 100, + "n": 1, + "stop": ["stop"], + "temperature": 0.9, + "tool_choice": "tool_choice", + "tools": [ + { + "type": "type", + "function": { + "description": "Fetches the weather in the given location", + "name": "get_weather", + "parameters": { + "type": "object", + "properties": { + "location": { + "description": "The location to get the weather for", + "type": "string" + }, + "unit": { + "description": "The unit to return the temperature in", + "type": "string", + "enum": ["F", "C"] + } + }, + "additionalProperties": false, + "required": ["location", "unit"] + }, + "strict": true + } + } + ], + "top_p": 0.8, + "stream": true, + "stream_options": { + "include_usage": true + } + } + """; + assertJsonEquals(jsonString, expectedJson); + + } + + // 3. Serialization with Null Optional Fields + // Test with optional fields set to null to ensure they are correctly omitted from the output. 
+ public void testSerializationWithNullOptionalFields() throws IOException { + // Create a message with minimal required fields + UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentString("Hello, world!"), + OpenAiUnifiedChatCompletionRequestEntity.USER_FIELD, + null, + null, + null + ); + var messageList = new ArrayList(); + messageList.add(message); + + // Create the unified request with optional fields set to null + UnifiedCompletionRequest unifiedRequest = new UnifiedCompletionRequest( + messageList, + null, // model + null, // maxCompletionTokens + null, // stop + null, // temperature + null, // toolChoice + null, // tools + null // topP + ); + + // Create the unified chat input + UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); + + OpenAiChatCompletionModel model = createChatCompletionModel("test-endpoint", "organizationId", "api-key", "model-name", null); + + // Create the entity + OpenAiUnifiedChatCompletionRequestEntity entity = new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInput, model); + + // Serialize to XContent + XContentBuilder builder = JsonXContent.contentBuilder(); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + + // Convert to string and verify + String jsonString = Strings.toString(builder); + String expectedJson = """ + { + "messages": [ + { + "content": "Hello, world!", + "role": "user" + } + ], + "model": "model-name", + "n": 1, + "stream": true, + "stream_options": { + "include_usage": true + } + } + """; + assertJsonEquals(jsonString, expectedJson); + } + + // 4. Serialization with Empty Lists + // Test with fields that are lists set to empty lists to ensure they are correctly serialized. + public void testSerializationWithEmptyLists() throws IOException { + // Create a message with minimal required fields + UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentString("Hello, world!"), + OpenAiUnifiedChatCompletionRequestEntity.USER_FIELD, + null, + null, + Collections.emptyList() // empty toolCalls list + ); + var messageList = new ArrayList(); + messageList.add(message); + // Create the unified request with empty lists + UnifiedCompletionRequest unifiedRequest = new UnifiedCompletionRequest( + messageList, + null, // model + null, // maxCompletionTokens + Collections.emptyList(), // empty stop list + null, // temperature + null, // toolChoice + Collections.emptyList(), // empty tools list + null // topP + ); + + // Create the unified chat input + UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); + + OpenAiChatCompletionModel model = createChatCompletionModel("test-endpoint", "organizationId", "api-key", "model-name", null); + + // Create the entity + OpenAiUnifiedChatCompletionRequestEntity entity = new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInput, model); + + // Serialize to XContent + XContentBuilder builder = JsonXContent.contentBuilder(); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + + // Convert to string and verify + String jsonString = Strings.toString(builder); + String expectedJson = """ + { + "messages": [ + { + "content": "Hello, world!", + "role": "user", + "tool_calls": [] + } + ], + "model": "model-name", + "n": 1, + "stream": true, + "stream_options": { + "include_usage": true + } + } + """; + assertJsonEquals(jsonString, expectedJson); + } + + // 5. 
Serialization with Nested Objects + // Test with nested objects (e.g., toolCalls, toolChoice, tool) to ensure they are correctly serialized. + public void testSerializationWithNestedObjects() throws IOException { + Random random = Randomness.get(); + + // Generate random values + String randomContent = "Hello, world! " + random.nextInt(1000); + String randomName = "name" + random.nextInt(1000); + String randomToolCallId = "tool_call_id" + random.nextInt(1000); + String randomArguments = "arguments" + random.nextInt(1000); + String randomFunctionName = "function_name" + random.nextInt(1000); + String randomType = "type" + random.nextInt(1000); + String randomModel = "model" + random.nextInt(1000); + String randomStop = "stop" + random.nextInt(1000); + float randomTemperature = (float) ((float) Math.round(0.5d + (double) random.nextFloat() * 0.5d * 100000d) / 100000d); + float randomTopP = (float) ((float) Math.round(0.5d + (double) random.nextFloat() * 0.5d * 100000d) / 100000d); + + // Create a message with nested toolCalls + UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentString(randomContent), + OpenAiUnifiedChatCompletionRequestEntity.USER_FIELD, + randomName, + randomToolCallId, + Collections.singletonList( + new UnifiedCompletionRequest.ToolCall( + "id", + new UnifiedCompletionRequest.ToolCall.FunctionField(randomArguments, randomFunctionName), + randomType + ) + ) + ); + + // Create a tool with nested function fields + UnifiedCompletionRequest.Tool tool = new UnifiedCompletionRequest.Tool( + randomType, + new UnifiedCompletionRequest.Tool.FunctionField( + "Fetches the weather in the given location", + "get_weather", + createParameters(), + true + ) + ); + var messageList = new ArrayList(); + messageList.add(message); + // Create the unified request with nested objects + UnifiedCompletionRequest unifiedRequest = new UnifiedCompletionRequest( + messageList, + randomModel, + 100L, // maxCompletionTokens + Collections.singletonList(randomStop), + randomTemperature, // temperature + new UnifiedCompletionRequest.ToolChoiceObject( + randomType, + new UnifiedCompletionRequest.ToolChoiceObject.FunctionField(randomFunctionName) + ), + Collections.singletonList(tool), + randomTopP // topP + ); + + // Create the unified chat input + UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); + + OpenAiChatCompletionModel model = createChatCompletionModel("test-endpoint", "organizationId", "api-key", randomModel, null); + + // Create the entity + OpenAiUnifiedChatCompletionRequestEntity entity = new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInput, model); + + // Serialize to XContent + XContentBuilder builder = JsonXContent.contentBuilder(); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + + // Convert to string and verify + String jsonString = Strings.toString(builder); + // Expected JSON should be dynamically generated based on random values + String expectedJson = String.format( + Locale.US, + """ + { + "messages": [ + { + "content": "%s", + "role": "user", + "name": "%s", + "tool_call_id": "%s", + "tool_calls": [ + { + "id": "id", + "function": { + "arguments": "%s", + "name": "%s" + }, + "type": "%s" + } + ] + } + ], + "model": "%s", + "max_completion_tokens": 100, + "n": 1, + "stop": ["%s"], + "temperature": %.5f, + "tool_choice": { + "type": "%s", + "function": { + "name": "%s" + } + }, + "tools": [ + { + "type": "%s", + "function": { + "description": "Fetches the weather in the 
given location", + "name": "get_weather", + "parameters": { + "type": "object", + "properties": { + "unit": { + "description": "The unit to return the temperature in", + "type": "string", + "enum": ["F", "C"] + }, + "location": { + "description": "The location to get the weather for", + "type": "string" + } + }, + "additionalProperties": false, + "required": ["location", "unit"] + }, + "strict": true + } + } + ], + "top_p": %.5f, + "stream": true, + "stream_options": { + "include_usage": true + } + } + """, + randomContent, + randomName, + randomToolCallId, + randomArguments, + randomFunctionName, + randomType, + randomModel, + randomStop, + randomTemperature, + randomType, + randomFunctionName, + randomType, + randomTopP + ); + assertJsonEquals(jsonString, expectedJson); + } + + // 6. Serialization with Different Content Types + // Test with different content types in messages (e.g., ContentString, ContentObjects) to ensure they are correctly serialized. + public void testSerializationWithDifferentContentTypes() throws IOException { + Random random = Randomness.get(); + + // Generate random values for ContentString + String randomContentString = "Hello, world! " + random.nextInt(1000); + + // Generate random values for ContentObjects + String randomText = "Random text " + random.nextInt(1000); + String randomType = "type" + random.nextInt(1000); + UnifiedCompletionRequest.ContentObject contentObject = new UnifiedCompletionRequest.ContentObject(randomText, randomType); + + var contentObjectsList = new ArrayList(); + contentObjectsList.add(contentObject); + UnifiedCompletionRequest.ContentObjects contentObjects = new UnifiedCompletionRequest.ContentObjects(contentObjectsList); + + // Create messages with different content types + UnifiedCompletionRequest.Message messageWithString = new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentString(randomContentString), + OpenAiUnifiedChatCompletionRequestEntity.USER_FIELD, + null, + null, + null + ); + + UnifiedCompletionRequest.Message messageWithObjects = new UnifiedCompletionRequest.Message( + contentObjects, + OpenAiUnifiedChatCompletionRequestEntity.USER_FIELD, + null, + null, + null + ); + var messageList = new ArrayList(); + messageList.add(messageWithString); + messageList.add(messageWithObjects); + + // Create the unified request with both types of messages + UnifiedCompletionRequest unifiedRequest = UnifiedCompletionRequest.of(messageList); + + // Create the unified chat input + UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); + + OpenAiChatCompletionModel model = createChatCompletionModel("test-endpoint", "organizationId", "api-key", "model-name", null); + + // Create the entity + OpenAiUnifiedChatCompletionRequestEntity entity = new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInput, model); + + // Serialize to XContent + XContentBuilder builder = JsonXContent.contentBuilder(); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + + // Convert to string and verify + String jsonString = Strings.toString(builder); + String expectedJson = String.format(Locale.US, """ + { + "messages": [ + { + "content": "%s", + "role": "user" + }, + { + "content": [ + { + "text": "%s", + "type": "%s" + } + ], + "role": "user" + } + ], + "model": "model-name", + "n": 1, + "stream": true, + "stream_options": { + "include_usage": true + } + } + """, randomContentString, randomText, randomType); + assertJsonEquals(jsonString, expectedJson); + } + + // 7. 
Serialization with Special Characters + // Test with special characters in string fields to ensure they are correctly escaped and serialized. + public void testSerializationWithSpecialCharacters() throws IOException { + // Create a message with special characters + UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentString("Hello, world! \n \"Special\" characters: \t \\ /"), + OpenAiUnifiedChatCompletionRequestEntity.USER_FIELD, + "name\nwith\nnewlines", + "tool_call_id\twith\ttabs", + Collections.singletonList( + new UnifiedCompletionRequest.ToolCall( + "id\\with\\backslashes", + new UnifiedCompletionRequest.ToolCall.FunctionField("arguments\"with\"quotes", "function_name/with/slashes"), + "type" + ) + ) + ); + var messageList = new ArrayList(); + messageList.add(message); + // Create the unified request + UnifiedCompletionRequest unifiedRequest = new UnifiedCompletionRequest( + messageList, + null, // model + null, // maxCompletionTokens + null, // stop + null, // temperature + null, // toolChoice + null, // tools + null // topP + ); + + // Create the unified chat input + UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); + + OpenAiChatCompletionModel model = createChatCompletionModel("test-endpoint", "organizationId", "api-key", "model-name", null); + + // Create the entity + OpenAiUnifiedChatCompletionRequestEntity entity = new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInput, model); + + // Serialize to XContent + XContentBuilder builder = JsonXContent.contentBuilder(); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + + // Convert to string and verify + String jsonString = Strings.toString(builder); + String expectedJson = """ + { + "messages": [ + { + "content": "Hello, world! \\n \\"Special\\" characters: \\t \\\\ /", + "role": "user", + "name": "name\\nwith\\nnewlines", + "tool_call_id": "tool_call_id\\twith\\ttabs", + "tool_calls": [ + { + "id": "id\\\\with\\\\backslashes", + "function": { + "arguments": "arguments\\"with\\"quotes", + "name": "function_name/with/slashes" + }, + "type": "type" + } + ] + } + ], + "model": "model-name", + "n": 1, + "stream": true, + "stream_options": { + "include_usage": true + } + } + """; + assertJsonEquals(jsonString, expectedJson); + } + + // 8. Serialization with Boolean Fields + // Test with boolean fields (stream) set to both true and false to ensure they are correctly serialized. 
+ public void testSerializationWithBooleanFields() throws IOException { + // Create a message with minimal required fields + UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentString("Hello, world!"), + OpenAiUnifiedChatCompletionRequestEntity.USER_FIELD, + null, + null, + null + ); + var messageList = new ArrayList(); + messageList.add(message); + // Create the unified request + UnifiedCompletionRequest unifiedRequest = new UnifiedCompletionRequest( + messageList, + null, // model + null, // maxCompletionTokens + null, // stop + null, // temperature + null, // toolChoice + null, // tools + null // topP + ); + + OpenAiChatCompletionModel model = createChatCompletionModel("test-endpoint", "organizationId", "api-key", "model-name", null); + + // Test with stream set to true + UnifiedChatInput unifiedChatInputTrue = new UnifiedChatInput(unifiedRequest, true); + OpenAiUnifiedChatCompletionRequestEntity entityTrue = new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInputTrue, model); + + XContentBuilder builderTrue = JsonXContent.contentBuilder(); + entityTrue.toXContent(builderTrue, ToXContent.EMPTY_PARAMS); + + String jsonStringTrue = Strings.toString(builderTrue); + String expectedJsonTrue = """ + { + "messages": [ + { + "content": "Hello, world!", + "role": "user" + } + ], + "model": "model-name", + "n": 1, + "stream": true, + "stream_options": { + "include_usage": true + } + } + """; + assertJsonEquals(expectedJsonTrue, jsonStringTrue); + + // Test with stream set to false + UnifiedChatInput unifiedChatInputFalse = new UnifiedChatInput(unifiedRequest, false); + OpenAiUnifiedChatCompletionRequestEntity entityFalse = new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInputFalse, model); + + XContentBuilder builderFalse = JsonXContent.contentBuilder(); + entityFalse.toXContent(builderFalse, ToXContent.EMPTY_PARAMS); + + String jsonStringFalse = Strings.toString(builderFalse); + String expectedJsonFalse = """ + { + "messages": [ + { + "content": "Hello, world!", + "role": "user" + } + ], + "model": "model-name", + "n": 1, + "stream": false + } + """; + assertJsonEquals(expectedJsonFalse, jsonStringFalse); + } + + // 9. Serialization with Missing Required Fields + // Test with missing required fields to ensure appropriate exceptions are thrown. 
+ public void testSerializationWithMissingRequiredFields() { + // Create a message with missing content (required field) + UnifiedCompletionRequest.Message message = new UnifiedCompletionRequest.Message( + null, // missing content + OpenAiUnifiedChatCompletionRequestEntity.USER_FIELD, + null, + null, + null + ); + var messageList = new ArrayList(); + messageList.add(message); + // Create the unified request + UnifiedCompletionRequest unifiedRequest = new UnifiedCompletionRequest( + messageList, + null, // model + null, // maxCompletionTokens + null, // stop + null, // temperature + null, // toolChoice + null, // tools + null // topP + ); + + // Create the unified chat input + UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); + + OpenAiChatCompletionModel model = createChatCompletionModel("test-endpoint", "organizationId", "api-key", "model-name", null); + + // Create the entity + OpenAiUnifiedChatCompletionRequestEntity entity = new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInput, model); + + // Attempt to serialize to XContent and expect an exception + try { + XContentBuilder builder = JsonXContent.contentBuilder(); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + fail("Expected an exception due to missing required fields"); + } catch (NullPointerException | IOException e) { + // Expected exception + } + } + + // 10. Serialization with Mixed Valid and Invalid Data + // Test with a mix of valid and invalid data to ensure the serializer handles it gracefully. + public void testSerializationWithMixedValidAndInvalidData() throws IOException { + // Create a valid message + UnifiedCompletionRequest.Message validMessage = new UnifiedCompletionRequest.Message( + new UnifiedCompletionRequest.ContentString("Valid content"), + OpenAiUnifiedChatCompletionRequestEntity.USER_FIELD, + "validName", + "validToolCallId", + Collections.singletonList( + new UnifiedCompletionRequest.ToolCall( + "validId", + new UnifiedCompletionRequest.ToolCall.FunctionField("validArguments", "validFunctionName"), + "validType" + ) + ) + ); + + // Create an invalid message with null content + UnifiedCompletionRequest.Message invalidMessage = new UnifiedCompletionRequest.Message( + null, // invalid content + OpenAiUnifiedChatCompletionRequestEntity.USER_FIELD, + "invalidName", + "invalidToolCallId", + Collections.singletonList( + new UnifiedCompletionRequest.ToolCall( + "invalidId", + new UnifiedCompletionRequest.ToolCall.FunctionField("invalidArguments", "invalidFunctionName"), + "invalidType" + ) + ) + ); + var messageList = new ArrayList(); + messageList.add(validMessage); + messageList.add(invalidMessage); + // Create the unified request with both valid and invalid messages + UnifiedCompletionRequest unifiedRequest = new UnifiedCompletionRequest( + messageList, + "model-name", + 100L, // maxCompletionTokens + Collections.singletonList("stop"), + 0.9f, // temperature + new UnifiedCompletionRequest.ToolChoiceString("tool_choice"), + Collections.singletonList( + new UnifiedCompletionRequest.Tool( + "type", + new UnifiedCompletionRequest.Tool.FunctionField( + "Fetches the weather in the given location", + "get_weather", + createParameters(), + true + ) + ) + ), + 0.8f // topP + ); + + // Create the unified chat input + UnifiedChatInput unifiedChatInput = new UnifiedChatInput(unifiedRequest, true); + + OpenAiChatCompletionModel model = createChatCompletionModel("test-endpoint", "organizationId", "api-key", "model-name", null); + + // Create the entity + 
OpenAiUnifiedChatCompletionRequestEntity entity = new OpenAiUnifiedChatCompletionRequestEntity(unifiedChatInput, model); + + // Serialize to XContent and verify + try { + XContentBuilder builder = JsonXContent.contentBuilder(); + entity.toXContent(builder, ToXContent.EMPTY_PARAMS); + fail("Expected an exception due to invalid data"); + } catch (NullPointerException | IOException e) { + // Expected exception + } + } + + public static Map createParameters() { + Map parameters = new LinkedHashMap<>(); + parameters.put("type", "object"); + + Map properties = new HashMap<>(); + + Map location = new HashMap<>(); + location.put("type", "string"); + location.put("description", "The location to get the weather for"); + properties.put("location", location); + + Map unit = new HashMap<>(); + unit.put("type", "string"); + unit.put("description", "The unit to return the temperature in"); + unit.put("enum", new String[] { "F", "C" }); + properties.put("unit", unit); + + parameters.put("properties", properties); + parameters.put("additionalProperties", false); + parameters.put("required", new String[] { "location", "unit" }); + + return parameters; + } + + private void assertJsonEquals(String actual, String expected) throws IOException { + try ( + var actualParser = createParser(JsonXContent.jsonXContent, actual); + var expectedParser = createParser(JsonXContent.jsonXContent, expected) + ) { + assertThat(actualParser.mapOrdered(), equalTo(expectedParser.mapOrdered())); + } + } + +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiChatCompletionRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestTests.java similarity index 75% rename from x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiChatCompletionRequestTests.java rename to x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestTests.java index b6ebfd02941f..2be12c9b12e0 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiChatCompletionRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/request/openai/OpenAiUnifiedChatCompletionRequestTests.java @@ -12,6 +12,7 @@ import org.elasticsearch.core.Nullable; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionModelTests; import java.io.IOException; @@ -20,16 +21,16 @@ import java.util.Map; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; -import static org.elasticsearch.xpack.inference.external.request.openai.OpenAiChatCompletionRequest.buildDefaultUri; +import static org.elasticsearch.xpack.inference.external.request.openai.OpenAiUnifiedChatCompletionRequest.buildDefaultUri; import static org.elasticsearch.xpack.inference.external.request.openai.OpenAiUtils.ORGANIZATION_HEADER; import static org.hamcrest.Matchers.aMapWithSize; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; -public class OpenAiChatCompletionRequestTests extends ESTestCase { +public class OpenAiUnifiedChatCompletionRequestTests extends ESTestCase { public void 
testCreateRequest_WithUrlOrganizationUserDefined() throws IOException { - var request = createRequest("www.google.com", "org", "secret", "abc", "model", "user"); + var request = createRequest("www.google.com", "org", "secret", "abc", "model", "user", true); var httpRequest = request.createHttpRequest(); assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); @@ -41,15 +42,27 @@ public void testCreateRequest_WithUrlOrganizationUserDefined() throws IOExceptio assertThat(httpPost.getLastHeader(ORGANIZATION_HEADER).getValue(), is("org")); var requestMap = entityAsMap(httpPost.getEntity().getContent()); - assertThat(requestMap, aMapWithSize(4)); + assertRequestMapWithUser(requestMap, "user"); + } + + private void assertRequestMapWithoutUser(Map requestMap) { + assertRequestMapWithUser(requestMap, null); + } + + private void assertRequestMapWithUser(Map requestMap, @Nullable String user) { + assertThat(requestMap, aMapWithSize(user != null ? 6 : 5)); assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "abc")))); assertThat(requestMap.get("model"), is("model")); - assertThat(requestMap.get("user"), is("user")); + if (user != null) { + assertThat(requestMap.get("user"), is(user)); + } assertThat(requestMap.get("n"), is(1)); + assertTrue((Boolean) requestMap.get("stream")); + assertThat(requestMap.get("stream_options"), is(Map.of("include_usage", true))); } public void testCreateRequest_WithDefaultUrl() throws URISyntaxException, IOException { - var request = createRequest(null, "org", "secret", "abc", "model", "user"); + var request = createRequest(null, "org", "secret", "abc", "model", "user", true); var httpRequest = request.createHttpRequest(); assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); @@ -61,33 +74,27 @@ public void testCreateRequest_WithDefaultUrl() throws URISyntaxException, IOExce assertThat(httpPost.getLastHeader(ORGANIZATION_HEADER).getValue(), is("org")); var requestMap = entityAsMap(httpPost.getEntity().getContent()); - assertThat(requestMap, aMapWithSize(4)); - assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "abc")))); - assertThat(requestMap.get("model"), is("model")); - assertThat(requestMap.get("user"), is("user")); - assertThat(requestMap.get("n"), is(1)); + assertRequestMapWithUser(requestMap, "user"); + } public void testCreateRequest_WithDefaultUrlAndWithoutUserOrganization() throws URISyntaxException, IOException { - var request = createRequest(null, null, "secret", "abc", "model", null); + var request = createRequest(null, null, "secret", "abc", "model", null, true); var httpRequest = request.createHttpRequest(); assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); var httpPost = (HttpPost) httpRequest.httpRequestBase(); - assertThat(httpPost.getURI().toString(), is(OpenAiChatCompletionRequest.buildDefaultUri().toString())); + assertThat(httpPost.getURI().toString(), is(OpenAiUnifiedChatCompletionRequest.buildDefaultUri().toString())); assertThat(httpPost.getLastHeader(HttpHeaders.CONTENT_TYPE).getValue(), is(XContentType.JSON.mediaType())); assertThat(httpPost.getLastHeader(HttpHeaders.AUTHORIZATION).getValue(), is("Bearer secret")); assertNull(httpPost.getLastHeader(ORGANIZATION_HEADER)); var requestMap = entityAsMap(httpPost.getEntity().getContent()); - assertThat(requestMap, aMapWithSize(3)); - assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "abc")))); - assertThat(requestMap.get("model"), is("model")); - 
assertThat(requestMap.get("n"), is(1)); + assertRequestMapWithoutUser(requestMap); } - public void testCreateRequest_WithStreaming() throws URISyntaxException, IOException { + public void testCreateRequest_WithStreaming() throws IOException { var request = createRequest(null, null, "secret", "abc", "model", null, true); var httpRequest = request.createHttpRequest(); @@ -99,29 +106,31 @@ public void testCreateRequest_WithStreaming() throws URISyntaxException, IOExcep } public void testTruncate_DoesNotReduceInputTextSize() throws URISyntaxException, IOException { - var request = createRequest(null, null, "secret", "abcd", "model", null); + var request = createRequest(null, null, "secret", "abcd", "model", null, true); var truncatedRequest = request.truncate(); - assertThat(request.getURI().toString(), is(OpenAiChatCompletionRequest.buildDefaultUri().toString())); + assertThat(request.getURI().toString(), is(OpenAiUnifiedChatCompletionRequest.buildDefaultUri().toString())); var httpRequest = truncatedRequest.createHttpRequest(); assertThat(httpRequest.httpRequestBase(), instanceOf(HttpPost.class)); var httpPost = (HttpPost) httpRequest.httpRequestBase(); var requestMap = entityAsMap(httpPost.getEntity().getContent()); - assertThat(requestMap, aMapWithSize(3)); + assertThat(requestMap, aMapWithSize(5)); // We do not truncate for OpenAi chat completions assertThat(requestMap.get("messages"), is(List.of(Map.of("role", "user", "content", "abcd")))); assertThat(requestMap.get("model"), is("model")); assertThat(requestMap.get("n"), is(1)); + assertTrue((Boolean) requestMap.get("stream")); + assertThat(requestMap.get("stream_options"), is(Map.of("include_usage", true))); } public void testTruncationInfo_ReturnsNull() { - var request = createRequest(null, null, "secret", "abcd", "model", null); + var request = createRequest(null, null, "secret", "abcd", "model", null, true); assertNull(request.getTruncationInfo()); } - public static OpenAiChatCompletionRequest createRequest( + public static OpenAiUnifiedChatCompletionRequest createRequest( @Nullable String url, @Nullable String org, String apiKey, @@ -132,7 +141,7 @@ public static OpenAiChatCompletionRequest createRequest( return createRequest(url, org, apiKey, input, model, user, false); } - public static OpenAiChatCompletionRequest createRequest( + public static OpenAiUnifiedChatCompletionRequest createRequest( @Nullable String url, @Nullable String org, String apiKey, @@ -142,7 +151,7 @@ public static OpenAiChatCompletionRequest createRequest( boolean stream ) { var chatCompletionModel = OpenAiChatCompletionModelTests.createChatCompletionModel(url, org, apiKey, model, user); - return new OpenAiChatCompletionRequest(List.of(input), chatCompletionModel, stream); + return new OpenAiUnifiedChatCompletionRequest(new UnifiedChatInput(List.of(input), "user", stream), chatCompletionModel); } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/highlight/SemanticTextHighlighterTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/highlight/SemanticTextHighlighterTests.java new file mode 100644 index 000000000000..7dc4d99e06ac --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/highlight/SemanticTextHighlighterTests.java @@ -0,0 +1,288 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. 
Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.highlight; + +import org.apache.lucene.analysis.standard.StandardAnalyzer; +import org.apache.lucene.index.DirectoryReader; +import org.apache.lucene.index.IndexWriterConfig; +import org.apache.lucene.search.IndexSearcher; +import org.apache.lucene.search.Sort; +import org.apache.lucene.search.TopDocs; +import org.apache.lucene.search.join.ScoreMode; +import org.apache.lucene.store.Directory; +import org.apache.lucene.tests.index.RandomIndexWriter; +import org.apache.lucene.util.BytesRef; +import org.elasticsearch.action.OriginalIndices; +import org.elasticsearch.action.search.SearchRequest; +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.common.bytes.BytesReference; +import org.elasticsearch.common.io.Streams; +import org.elasticsearch.common.lucene.search.Queries; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.xcontent.XContentHelper; +import org.elasticsearch.index.IndexVersion; +import org.elasticsearch.index.mapper.MapperService; +import org.elasticsearch.index.mapper.MapperServiceTestCase; +import org.elasticsearch.index.mapper.SourceToParse; +import org.elasticsearch.index.query.NestedQueryBuilder; +import org.elasticsearch.index.query.QueryBuilder; +import org.elasticsearch.index.shard.ShardId; +import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.search.SearchHit; +import org.elasticsearch.search.builder.SearchSourceBuilder; +import org.elasticsearch.search.fetch.FetchContext; +import org.elasticsearch.search.fetch.FetchSubPhase; +import org.elasticsearch.search.fetch.subphase.highlight.FieldHighlightContext; +import org.elasticsearch.search.fetch.subphase.highlight.HighlightBuilder; +import org.elasticsearch.search.fetch.subphase.highlight.SearchHighlightContext; +import org.elasticsearch.search.internal.AliasFilter; +import org.elasticsearch.search.internal.ShardSearchRequest; +import org.elasticsearch.search.lookup.Source; +import org.elasticsearch.search.rank.RankDoc; +import org.elasticsearch.search.vectors.KnnVectorQueryBuilder; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.ml.search.SparseVectorQueryBuilder; +import org.elasticsearch.xpack.core.ml.search.WeightedToken; +import org.elasticsearch.xpack.inference.InferencePlugin; +import org.elasticsearch.xpack.inference.mapper.SemanticTextFieldMapper; +import org.junit.Before; +import org.mockito.Mockito; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.zip.GZIPInputStream; + +import static org.hamcrest.CoreMatchers.equalTo; +import static org.mockito.Mockito.mock; + +public class SemanticTextHighlighterTests extends MapperServiceTestCase { + private static final String SEMANTIC_FIELD_E5 = "body-e5"; + private static final String SEMANTIC_FIELD_ELSER = "body-elser"; + + private Map queries; + + @Override + protected Collection getPlugins() { + return List.of(new InferencePlugin(Settings.EMPTY)); + } + + @Override + @Before + public void setUp() throws Exception { + super.setUp(); + var input = Streams.readFully(SemanticTextHighlighterTests.class.getResourceAsStream("queries.json")); + this.queries = XContentHelper.convertToMap(input, false, 
XContentType.JSON).v2(); + } + + @SuppressWarnings("unchecked") + public void testDenseVector() throws Exception { + var mapperService = createDefaultMapperService(); + Map queryMap = (Map) queries.get("dense_vector_1"); + float[] vector = readDenseVector(queryMap.get("embeddings")); + var fieldType = (SemanticTextFieldMapper.SemanticTextFieldType) mapperService.mappingLookup().getFieldType(SEMANTIC_FIELD_E5); + KnnVectorQueryBuilder knnQuery = new KnnVectorQueryBuilder(fieldType.getEmbeddingsField().fullPath(), vector, 10, 10, null); + NestedQueryBuilder nestedQueryBuilder = new NestedQueryBuilder(fieldType.getChunksField().fullPath(), knnQuery, ScoreMode.Max); + var shardRequest = createShardSearchRequest(nestedQueryBuilder); + var sourceToParse = new SourceToParse("0", readSampleDoc("sample-doc.json.gz"), XContentType.JSON); + + String[] expectedScorePassages = ((List) queryMap.get("expected_by_score")).toArray(String[]::new); + for (int i = 0; i < expectedScorePassages.length; i++) { + assertHighlightOneDoc( + mapperService, + shardRequest, + sourceToParse, + SEMANTIC_FIELD_E5, + i + 1, + HighlightBuilder.Order.SCORE, + Arrays.copyOfRange(expectedScorePassages, 0, i + 1) + ); + } + + String[] expectedOffsetPassages = ((List) queryMap.get("expected_by_offset")).toArray(String[]::new); + assertHighlightOneDoc( + mapperService, + shardRequest, + sourceToParse, + SEMANTIC_FIELD_E5, + expectedOffsetPassages.length, + HighlightBuilder.Order.NONE, + expectedOffsetPassages + ); + } + + @SuppressWarnings("unchecked") + public void testSparseVector() throws Exception { + var mapperService = createDefaultMapperService(); + Map queryMap = (Map) queries.get("sparse_vector_1"); + List tokens = readSparseVector(queryMap.get("embeddings")); + var fieldType = (SemanticTextFieldMapper.SemanticTextFieldType) mapperService.mappingLookup().getFieldType(SEMANTIC_FIELD_ELSER); + SparseVectorQueryBuilder sparseQuery = new SparseVectorQueryBuilder( + fieldType.getEmbeddingsField().fullPath(), + tokens, + null, + null, + null, + null + ); + NestedQueryBuilder nestedQueryBuilder = new NestedQueryBuilder(fieldType.getChunksField().fullPath(), sparseQuery, ScoreMode.Max); + var shardRequest = createShardSearchRequest(nestedQueryBuilder); + var sourceToParse = new SourceToParse("0", readSampleDoc("sample-doc.json.gz"), XContentType.JSON); + + String[] expectedScorePassages = ((List) queryMap.get("expected_by_score")).toArray(String[]::new); + for (int i = 0; i < expectedScorePassages.length; i++) { + assertHighlightOneDoc( + mapperService, + shardRequest, + sourceToParse, + SEMANTIC_FIELD_ELSER, + i + 1, + HighlightBuilder.Order.SCORE, + Arrays.copyOfRange(expectedScorePassages, 0, i + 1) + ); + } + + String[] expectedOffsetPassages = ((List) queryMap.get("expected_by_offset")).toArray(String[]::new); + assertHighlightOneDoc( + mapperService, + shardRequest, + sourceToParse, + SEMANTIC_FIELD_ELSER, + expectedOffsetPassages.length, + HighlightBuilder.Order.NONE, + expectedOffsetPassages + ); + } + + private MapperService createDefaultMapperService() throws IOException { + var mappings = Streams.readFully(SemanticTextHighlighterTests.class.getResourceAsStream("mappings.json")); + return createMapperService(mappings.utf8ToString()); + } + + private float[] readDenseVector(Object value) { + if (value instanceof List lst) { + float[] res = new float[lst.size()]; + int pos = 0; + for (var obj : lst) { + if (obj instanceof Number number) { + res[pos++] = number.floatValue(); + } else { + throw new 
IllegalArgumentException("Expected number, got " + obj.getClass().getSimpleName()); + } + } + return res; + } + throw new IllegalArgumentException("Expected list, got " + value.getClass().getSimpleName()); + } + + private List readSparseVector(Object value) { + if (value instanceof Map map) { + List res = new ArrayList<>(); + for (var entry : map.entrySet()) { + if (entry.getValue() instanceof Number number) { + res.add(new WeightedToken((String) entry.getKey(), number.floatValue())); + } else { + throw new IllegalArgumentException("Expected number, got " + entry.getValue().getClass().getSimpleName()); + } + } + return res; + } + throw new IllegalArgumentException("Expected map, got " + value.getClass().getSimpleName()); + } + + private void assertHighlightOneDoc( + MapperService mapperService, + ShardSearchRequest request, + SourceToParse source, + String fieldName, + int numFragments, + HighlightBuilder.Order order, + String[] expectedPassages + ) throws Exception { + SemanticTextFieldMapper fieldMapper = (SemanticTextFieldMapper) mapperService.mappingLookup().getMapper(fieldName); + var doc = mapperService.documentMapper().parse(source); + assertNull(doc.dynamicMappingsUpdate()); + try (Directory dir = newDirectory()) { + IndexWriterConfig iwc = newIndexWriterConfig(new StandardAnalyzer()); + RandomIndexWriter iw = new RandomIndexWriter(random(), dir, iwc); + iw.addDocuments(doc.docs()); + try (DirectoryReader reader = wrapInMockESDirectoryReader(iw.getReader())) { + IndexSearcher searcher = newSearcher(reader); + iw.close(); + TopDocs topDocs = searcher.search(Queries.newNonNestedFilter(IndexVersion.current()), 1, Sort.INDEXORDER); + assertThat(topDocs.totalHits.value(), equalTo(1L)); + int docID = topDocs.scoreDocs[0].doc; + SemanticTextHighlighter highlighter = new SemanticTextHighlighter(); + var execContext = createSearchExecutionContext(mapperService); + var luceneQuery = execContext.toQuery(request.source().query()).query(); + FetchContext fetchContext = mock(FetchContext.class); + Mockito.when(fetchContext.highlight()).thenReturn(new SearchHighlightContext(Collections.emptyList())); + Mockito.when(fetchContext.query()).thenReturn(luceneQuery); + Mockito.when(fetchContext.getSearchExecutionContext()).thenReturn(execContext); + + FetchSubPhase.HitContext hitContext = new FetchSubPhase.HitContext( + new SearchHit(docID), + getOnlyLeafReader(reader).getContext(), + docID, + Map.of(), + Source.fromBytes(source.source()), + new RankDoc(docID, Float.NaN, 0) + ); + try { + var highlightContext = new HighlightBuilder().field(fieldName, 0, numFragments) + .order(order) + .highlighterType(SemanticTextHighlighter.NAME) + .build(execContext); + + for (var fieldContext : highlightContext.fields()) { + FieldHighlightContext context = new FieldHighlightContext( + fieldName, + fieldContext, + fieldMapper.fieldType(), + fetchContext, + hitContext, + luceneQuery, + new HashMap<>() + ); + var result = highlighter.highlight(context); + assertThat(result.fragments().length, equalTo(expectedPassages.length)); + for (int i = 0; i < result.fragments().length; i++) { + assertThat(result.fragments()[i].string(), equalTo(expectedPassages[i])); + } + } + } finally { + hitContext.hit().decRef(); + } + } + } + } + + private SearchRequest createSearchRequest(QueryBuilder queryBuilder) { + SearchRequest request = new SearchRequest(); + request.source(new SearchSourceBuilder()); + request.allowPartialSearchResults(false); + request.source().query(queryBuilder); + return request; + } + + private 
ShardSearchRequest createShardSearchRequest(QueryBuilder queryBuilder) { + SearchRequest request = createSearchRequest(queryBuilder); + return new ShardSearchRequest(OriginalIndices.NONE, request, new ShardId("index", "index", 0), 0, 1, AliasFilter.EMPTY, 1, 0, null); + } + + private BytesReference readSampleDoc(String fileName) throws IOException { + try (var in = new GZIPInputStream(SemanticTextHighlighterTests.class.getResourceAsStream(fileName))) { + return new BytesArray(new BytesRef(in.readAllBytes())); + } + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java index fd60d9687f43..c6a492dfcf4e 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/mapper/SemanticTextFieldMapperTests.java @@ -61,6 +61,7 @@ import org.elasticsearch.xcontent.XContentParser; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xcontent.json.JsonXContent; +import org.elasticsearch.xpack.core.ml.search.SparseVectorQueryWrapper; import org.elasticsearch.xpack.inference.InferencePlugin; import org.elasticsearch.xpack.inference.model.TestModel; import org.junit.AssumptionViolatedException; @@ -1110,7 +1111,12 @@ private static Query generateNestedTermSparseVectorQuery(NestedLookup nestedLook } queryBuilder.add(new BooleanClause(mapper.nestedTypeFilter(), BooleanClause.Occur.FILTER)); - return new ESToParentBlockJoinQuery(queryBuilder.build(), parentFilter, ScoreMode.Total, null); + return new ESToParentBlockJoinQuery( + new SparseVectorQueryWrapper(fieldName, queryBuilder.build()), + parentFilter, + ScoreMode.Total, + null + ); } private static void assertChildLeafNestedDocument( diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java index b8bcb766b53e..36aa2200ecea 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/queries/SemanticQueryBuilderTests.java @@ -45,12 +45,14 @@ import org.elasticsearch.xcontent.XContentBuilder; import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xcontent.json.JsonXContent; +import org.elasticsearch.xpack.core.XPackClientPlugin; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.core.inference.results.InferenceTextEmbeddingFloatResults; import org.elasticsearch.xpack.core.inference.results.SparseEmbeddingResults; import org.elasticsearch.xpack.core.ml.inference.MlInferenceNamedXContentProvider; import org.elasticsearch.xpack.core.ml.inference.results.MlTextEmbeddingResults; import org.elasticsearch.xpack.core.ml.inference.results.TextExpansionResults; +import org.elasticsearch.xpack.core.ml.search.SparseVectorQueryWrapper; import org.elasticsearch.xpack.core.ml.search.WeightedToken; import org.elasticsearch.xpack.inference.InferencePlugin; import org.elasticsearch.xpack.inference.mapper.SemanticTextField; @@ -114,7 +116,7 @@ public void setUp() throws Exception { @Override protected Collection<Class<? extends Plugin>> getPlugins() { - return 
List.of(InferencePlugin.class, FakeMlPlugin.class); + return List.of(XPackClientPlugin.class, InferencePlugin.class, FakeMlPlugin.class); } @Override @@ -194,9 +196,11 @@ protected void doAssertLuceneQuery(SemanticQueryBuilder queryBuilder, Query quer private void assertSparseEmbeddingLuceneQuery(Query query) { Query innerQuery = assertOuterBooleanQuery(query); - assertThat(innerQuery, instanceOf(BooleanQuery.class)); + assertThat(innerQuery, instanceOf(SparseVectorQueryWrapper.class)); + var sparseQuery = (SparseVectorQueryWrapper) innerQuery; + assertThat(((SparseVectorQueryWrapper) innerQuery).getTermsQuery(), instanceOf(BooleanQuery.class)); - BooleanQuery innerBooleanQuery = (BooleanQuery) innerQuery; + BooleanQuery innerBooleanQuery = (BooleanQuery) sparseQuery.getTermsQuery(); assertThat(innerBooleanQuery.clauses().size(), equalTo(queryTokenCount)); innerBooleanQuery.forEach(c -> { assertThat(c.occur(), equalTo(SHOULD)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/BaseInferenceActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/BaseInferenceActionTests.java index 05a8d52be5df..5528c80066b0 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/BaseInferenceActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/BaseInferenceActionTests.java @@ -8,11 +8,14 @@ package org.elasticsearch.xpack.inference.rest; import org.apache.lucene.util.SetOnce; +import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.common.bytes.BytesArray; import org.elasticsearch.core.TimeValue; +import org.elasticsearch.inference.TaskType; import org.elasticsearch.rest.RestChannel; import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.rest.RestRequestTests; import org.elasticsearch.rest.action.RestChunkedToXContentListener; import org.elasticsearch.test.rest.FakeRestRequest; import org.elasticsearch.test.rest.RestActionTestCase; @@ -26,6 +29,10 @@ import java.util.Map; import static org.elasticsearch.rest.RestRequest.Method.POST; +import static org.elasticsearch.xpack.inference.rest.BaseInferenceAction.parseParams; +import static org.elasticsearch.xpack.inference.rest.BaseInferenceAction.parseTimeout; +import static org.elasticsearch.xpack.inference.rest.Paths.INFERENCE_ID; +import static org.elasticsearch.xpack.inference.rest.Paths.TASK_TYPE_OR_INFERENCE_ID; import static org.hamcrest.CoreMatchers.is; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.instanceOf; @@ -56,6 +63,42 @@ private static String route(String param) { return "_route/" + param; } + public void testParseParams_ExtractsInferenceIdAndTaskType() { + var params = parseParams( + RestRequestTests.contentRestRequest("{}", Map.of(INFERENCE_ID, "id", TASK_TYPE_OR_INFERENCE_ID, TaskType.COMPLETION.toString())) + ); + assertThat(params, is(new BaseInferenceAction.Params("id", TaskType.COMPLETION))); + } + + public void testParseParams_DefaultsToTaskTypeAny_WhenInferenceId_IsMissing() { + var params = parseParams( + RestRequestTests.contentRestRequest("{}", Map.of(TASK_TYPE_OR_INFERENCE_ID, TaskType.COMPLETION.toString())) + ); + assertThat(params, is(new BaseInferenceAction.Params("completion", TaskType.ANY))); + } + + public void testParseParams_ThrowsStatusException_WhenTaskTypeIsMissing() { + var e = expectThrows( + ElasticsearchStatusException.class, + () -> 
parseParams(RestRequestTests.contentRestRequest("{}", Map.of(INFERENCE_ID, "id"))) + ); + assertThat(e.getMessage(), is("Task type must not be null")); + } + + public void testParseTimeout_ReturnsTimeout() { + var timeout = parseTimeout( + RestRequestTests.contentRestRequest("{}", Map.of(InferenceAction.Request.TIMEOUT.getPreferredName(), "4s")) + ); + + assertThat(timeout, is(TimeValue.timeValueSeconds(4))); + } + + public void testParseTimeout_ReturnsDefaultTimeout() { + var timeout = parseTimeout(RestRequestTests.contentRestRequest("{}", Map.of())); + + assertThat(timeout, is(TimeValue.timeValueSeconds(30))); + } + public void testUsesDefaultTimeout() { SetOnce executeCalled = new SetOnce<>(); verifyingClient.setExecuteVerifier(((actionType, actionRequest) -> { diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/RestUnifiedCompletionInferenceActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/RestUnifiedCompletionInferenceActionTests.java new file mode 100644 index 000000000000..5acfe67b175d --- /dev/null +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/rest/RestUnifiedCompletionInferenceActionTests.java @@ -0,0 +1,81 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. + */ + +package org.elasticsearch.xpack.inference.rest; + +import org.apache.lucene.util.SetOnce; +import org.elasticsearch.common.bytes.BytesArray; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.util.concurrent.ThreadContext; +import org.elasticsearch.rest.AbstractRestChannel; +import org.elasticsearch.rest.RestChannel; +import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.rest.RestResponse; +import org.elasticsearch.test.rest.FakeRestRequest; +import org.elasticsearch.test.rest.RestActionTestCase; +import org.elasticsearch.xcontent.XContentType; +import org.elasticsearch.xpack.core.inference.action.UnifiedCompletionAction; +import org.junit.Before; + +import static org.elasticsearch.xpack.inference.rest.BaseInferenceActionTests.createResponse; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.instanceOf; + +public class RestUnifiedCompletionInferenceActionTests extends RestActionTestCase { + + @Before + public void setUpAction() { + controller().registerHandler(new RestUnifiedCompletionInferenceAction()); + } + + public void testStreamIsTrue() { + SetOnce executeCalled = new SetOnce<>(); + verifyingClient.setExecuteVerifier(((actionType, actionRequest) -> { + assertThat(actionRequest, instanceOf(UnifiedCompletionAction.Request.class)); + + var request = (UnifiedCompletionAction.Request) actionRequest; + assertThat(request.isStreaming(), is(true)); + + executeCalled.set(true); + return createResponse(); + })); + + var requestBody = """ + { + "messages": [ + { + "content": "abc", + "role": "user" + } + ] + } + """; + + RestRequest inferenceRequest = new FakeRestRequest.Builder(xContentRegistry()).withMethod(RestRequest.Method.POST) + .withPath("_inference/completion/test/_unified") + .withContent(new BytesArray(requestBody), XContentType.JSON) + .build(); + + final SetOnce responseSetOnce = new SetOnce<>(); + dispatchRequest(inferenceRequest, new 
AbstractRestChannel(inferenceRequest, true) { + @Override + public void sendResponse(RestResponse response) { + responseSetOnce.set(response); + } + }); + + // the response content will be null when there is no error + assertNull(responseSetOnce.get().content()); + assertThat(executeCalled.get(), equalTo(true)); + } + + private void dispatchRequest(final RestRequest request, final RestChannel channel) { + ThreadContext threadContext = new ThreadContext(Settings.EMPTY); + controller().dispatchRequest(request, channel, threadContext); + } +} diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java index 47a96bf78dda..6768583598b2 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/SenderServiceTests.java @@ -26,6 +26,7 @@ import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.InferenceInputs; import org.elasticsearch.xpack.inference.external.http.sender.Sender; +import org.elasticsearch.xpack.inference.external.http.sender.UnifiedChatInput; import org.junit.After; import org.junit.Before; @@ -119,6 +120,14 @@ protected void doInfer( } + @Override + protected void doUnifiedCompletionInfer( + Model model, + UnifiedChatInput inputs, + TimeValue timeout, + ActionListener listener + ) {} + @Override protected void doChunkedInfer( Model model, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java index 76b5d6fee2c5..159b77789482 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java @@ -27,6 +27,7 @@ import org.elasticsearch.inference.Model; import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnifiedCompletionRequest; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.test.http.MockResponse; import org.elasticsearch.test.http.MockWebServer; @@ -920,6 +921,68 @@ public void testInfer_SendsRequest() throws IOException { } } + public void testUnifiedCompletionInfer() throws Exception { + // The escapes are because the streaming response must be on a single line + String responseJson = """ + data: {\ + "id":"12345",\ + "object":"chat.completion.chunk",\ + "created":123456789,\ + "model":"gpt-4o-mini",\ + "system_fingerprint": "123456789",\ + "choices":[\ + {\ + "index":0,\ + "delta":{\ + "content":"hello, world"\ + },\ + "logprobs":null,\ + "finish_reason":"stop"\ + }\ + ],\ + "usage":{\ + "prompt_tokens": 16,\ + "completion_tokens": 28,\ + "total_tokens": 44,\ + "prompt_tokens_details": {\ + "cached_tokens": 0,\ + "audio_tokens": 0\ + },\ + "completion_tokens_details": {\ + "reasoning_tokens": 0,\ + "audio_tokens": 0,\ + "accepted_prediction_tokens": 0,\ + "rejected_prediction_tokens": 0\ + }\ + }\ + } + + """; + webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); + + var senderFactory = 
HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); + try (var service = new OpenAiService(senderFactory, createWithEmptySettings(threadPool))) { + var model = OpenAiChatCompletionModelTests.createChatCompletionModel(getUrl(webServer), "org", "secret", "model", "user"); + PlainActionFuture listener = new PlainActionFuture<>(); + service.unifiedCompletionInfer( + model, + UnifiedCompletionRequest.of( + List.of( + new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "user", null, null, null) + ) + ), + InferenceAction.Request.DEFAULT_TIMEOUT, + listener + ); + + var result = listener.actionGet(TIMEOUT); + InferenceEventsAssertion.assertThat(result).hasFinishedStream().hasNoErrors().hasEvent(""" + {"id":"12345","choices":[{"delta":{"content":"hello, world"},"finish_reason":"stop","index":0}],""" + """ + "model":"gpt-4o-mini","object":"chat.completion.chunk",""" + """ + "usage":{"completion_tokens":28,"prompt_tokens":16,"total_tokens":44}}"""); + } + } + public void testInfer_StreamRequest() throws Exception { String responseJson = """ data: {\ diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionModelTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionModelTests.java index ab1786f0a584..e7ac4cf879e9 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionModelTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/completion/OpenAiChatCompletionModelTests.java @@ -10,9 +10,11 @@ import org.elasticsearch.common.settings.SecureString; import org.elasticsearch.core.Nullable; import org.elasticsearch.inference.TaskType; +import org.elasticsearch.inference.UnifiedCompletionRequest; import org.elasticsearch.test.ESTestCase; import org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettings; +import java.util.List; import java.util.Map; import static org.elasticsearch.xpack.inference.services.openai.completion.OpenAiChatCompletionRequestTaskSettingsTests.getChatCompletionRequestTaskSettingsMap; @@ -42,10 +44,48 @@ public void testOverrideWith_EmptyMap() { public void testOverrideWith_NullMap() { var model = createChatCompletionModel("url", "org", "api_key", "model_name", null); - var overriddenModel = OpenAiChatCompletionModel.of(model, null); + var overriddenModel = OpenAiChatCompletionModel.of(model, (Map) null); assertThat(overriddenModel, sameInstance(model)); } + public void testOverrideWith_UnifiedCompletionRequest_OverridesModelId() { + var model = createChatCompletionModel("url", "org", "api_key", "model_name", "user"); + var request = new UnifiedCompletionRequest( + List.of(new UnifiedCompletionRequest.Message(new UnifiedCompletionRequest.ContentString("hello"), "role", null, null, null)), + "different_model", + null, + null, + null, + null, + null, + null + ); + + assertThat( + OpenAiChatCompletionModel.of(model, request), + is(createChatCompletionModel("url", "org", "api_key", "different_model", "user")) + ); + } + + public void testOverrideWith_UnifiedCompletionRequest_UsesModelFields_WhenRequestDoesNotOverride() { + var model = createChatCompletionModel("url", "org", "api_key", "model_name", "user"); + var request = new UnifiedCompletionRequest( + List.of(new UnifiedCompletionRequest.Message(new 
UnifiedCompletionRequest.ContentString("hello"), "role", null, null, null)), + null, // not overriding model + null, + null, + null, + null, + null, + null + ); + + assertThat( + OpenAiChatCompletionModel.of(model, request), + is(createChatCompletionModel("url", "org", "api_key", "model_name", "user")) + ); + } + public static OpenAiChatCompletionModel createChatCompletionModel( String url, @Nullable String org, diff --git a/x-pack/plugin/inference/src/test/resources/org/elasticsearch/xpack/inference/highlight/mappings.json b/x-pack/plugin/inference/src/test/resources/org/elasticsearch/xpack/inference/highlight/mappings.json new file mode 100644 index 000000000000..9841ee0aed6e --- /dev/null +++ b/x-pack/plugin/inference/src/test/resources/org/elasticsearch/xpack/inference/highlight/mappings.json @@ -0,0 +1,27 @@ +{ + "_doc": { + "properties": { + "body": { + "type": "text", + "copy_to": ["body-elser", "body-e5"] + }, + "body-e5": { + "type": "semantic_text", + "inference_id": ".multilingual-e5-small-elasticsearch", + "model_settings": { + "task_type": "text_embedding", + "dimensions": 384, + "similarity": "cosine", + "element_type": "float" + } + }, + "body-elser": { + "type": "semantic_text", + "inference_id": ".elser-2-elasticsearch", + "model_settings": { + "task_type": "sparse_embedding" + } + } + } + } +} \ No newline at end of file diff --git a/x-pack/plugin/inference/src/test/resources/org/elasticsearch/xpack/inference/highlight/queries.json b/x-pack/plugin/inference/src/test/resources/org/elasticsearch/xpack/inference/highlight/queries.json new file mode 100644 index 000000000000..6227f3f49885 --- /dev/null +++ b/x-pack/plugin/inference/src/test/resources/org/elasticsearch/xpack/inference/highlight/queries.json @@ -0,0 +1,467 @@ +{ + "dense_vector_1": { + "embeddings": [ + 0.09475211, + 0.044564713, + -0.04378501, + -0.07908551, + 0.04332011, + -0.03891992, + -0.0062305215, + 0.024245035, + -0.008976331, + 0.032832284, + 0.052760173, + 0.008123907, + 0.09049037, + -0.01637332, + -0.054353267, + 0.00771307, + 0.08545496, + -0.079716265, + -0.045666866, + -0.04369993, + 0.009189822, + -0.013782891, + -0.07701858, + 0.037278354, + 0.049807206, + 0.078036495, + -0.059533164, + 0.051413406, + 0.040234447, + -0.038139492, + -0.085189626, + -0.045546446, + 0.0544375, + -0.05604156, + 0.057408098, + 0.041913517, + -0.037348013, + -0.025998272, + 0.08486864, + -0.046678443, + 0.0041820924, + 0.007514462, + 0.06424746, + 0.044233218, + 0.103267275, + 0.014130771, + -0.049954403, + 0.04226959, + -0.08346965, + -0.01639249, + -0.060537644, + 0.04546336, + 0.012866155, + 0.05375096, + 0.036775924, + -0.0762226, + -0.037304543, + -0.05692274, + -0.055807598, + 0.0040082196, + 0.059259634, + 0.012022011, + -8.0863154E-4, + 0.0070405705, + 0.050255686, + 0.06810016, + 0.017190414, + 0.051975194, + -0.051436286, + 0.023408439, + -0.029802637, + 0.034137156, + -0.004660689, + -0.0442122, + 0.019065322, + 0.030806554, + 0.0064652697, + -0.066789865, + 0.057111286, + 0.009412479, + -0.041444767, + -0.06807582, + -0.085881524, + 0.04901128, + -0.047871742, + 0.06328623, + 0.040418074, + -0.081432894, + 0.058384005, + 0.006206527, + 0.045801315, + 0.037274595, + -0.054337103, + -0.06755516, + -0.07396888, + -0.043732334, + -0.052053086, + 0.03210978, + 0.048101492, + -0.083828256, + 0.05205026, + -0.048474856, + 0.029116616, + -0.10924888, + 0.003796487, + 0.030567763, + 0.026949523, + -0.052353345, + 0.043198872, + -0.09456988, + -0.05711594, + -2.2292069E-4, + 0.032972734, + 0.054394923, + 
-0.0767535, + -0.02710579, + -0.032135617, + -0.01732382, + 0.059442326, + -0.07686165, + 0.07104082, + -0.03090021, + -0.05450075, + -0.038997203, + -0.07045443, + 0.00483161, + 0.010933604, + 0.020874644, + 0.037941266, + 0.019729063, + 0.06178368, + 0.013503478, + -0.008584046, + 0.045592044, + 0.05528768, + 0.11568184, + 0.0041300594, + 0.015404516, + -3.8067883E-4, + -0.06365399, + -0.07826643, + 0.061575573, + -0.060548335, + 0.05706082, + 0.042301804, + 0.052173313, + 0.07193179, + -0.03839231, + 0.0734415, + -0.045380164, + 0.02832276, + 0.003745178, + 0.058844633, + 0.04307504, + 0.037800383, + -0.031050054, + -0.06856359, + -0.059114788, + -0.02148857, + 0.07854358, + -0.03253363, + -0.04566468, + -0.019933948, + -0.057993464, + -0.08677458, + -0.06626883, + 0.031657256, + 0.101128764, + -0.08050056, + -0.050226066, + -0.014335166, + 0.050344367, + -0.06851419, + 0.008698909, + -0.011893435, + 0.07741272, + -0.059579294, + 0.03250109, + 0.058700256, + 0.046834726, + -0.035081457, + -0.0043140925, + -0.09764087, + -0.0034994273, + -0.034056358, + -0.019066337, + -0.034376107, + 0.012964423, + 0.029291175, + -0.012090671, + 0.021585712, + 0.028859599, + -0.04391145, + -0.071166754, + -0.031040335, + 0.02808108, + -0.05621317, + 0.06543945, + 0.10094665, + 0.041057374, + -0.03222324, + -0.063366964, + 0.064944476, + 0.023641933, + 0.06806713, + 0.06806097, + -0.08220105, + 0.04148528, + -0.09254079, + 0.044620737, + 0.05526614, + -0.03849534, + -0.04722273, + 0.0670776, + -0.024274077, + -0.016903497, + 0.07584147, + 0.04760533, + -0.038843267, + -0.028365409, + 0.08022705, + -0.039916333, + 0.049067073, + -0.030701574, + -0.057169467, + 0.043025102, + 0.07109674, + -0.047296863, + -0.047463104, + 0.040868305, + -0.04409507, + -0.034977127, + -0.057109762, + -0.08616165, + -0.03486079, + -0.046201482, + 0.025963873, + 0.023392359, + 0.09594902, + -0.007847159, + -0.021231368, + 0.009007263, + 0.0032713825, + -0.06876065, + 0.03169641, + -7.2582875E-4, + -0.07049708, + 0.03900843, + -0.0075472407, + 0.05184822, + 0.06452079, + -0.09832754, + -0.012775799, + -0.03925948, + -0.029761659, + 0.0065437574, + 0.0815465, + 0.0411695, + -0.0702844, + -0.009533786, + 0.07024532, + 0.0098710675, + 0.09915362, + 0.0415453, + 0.050641853, + 0.047463298, + -0.058609713, + -0.029499197, + -0.05100956, + -0.03441709, + -0.06348122, + 0.014784361, + 0.056317374, + -0.10280704, + -0.04008354, + -0.018926824, + 0.08832836, + 0.124804, + -0.047645308, + -0.07122146, + -9.886527E-4, + 0.03850324, + 0.048501793, + 0.07072816, + 0.06566776, + -0.013678872, + 0.010010848, + 0.06483413, + -0.030036367, + -0.029748922, + -0.007482364, + -0.05180385, + 0.03698522, + -0.045453787, + 0.056604166, + 0.029394176, + 0.028589265, + -0.012185886, + -0.06919616, + 0.0711641, + -0.034055933, + -0.053101335, + 0.062319, + 0.021600349, + -0.038718067, + 0.060814686, + 0.05087301, + -0.020297311, + 0.016493896, + 0.032162152, + 0.046740912, + 0.05461355, + -0.07024665, + 0.025609337, + -0.02504801, + 0.06765588, + -0.032994855, + -0.037897404, + -0.045783922, + -0.05689299, + -0.040437017, + -0.07904339, + -0.031415287, + -0.029216278, + 0.017395392, + 0.03449264, + -0.025653394, + -0.06283088, + 0.049027324, + 0.016229525, + -0.00985347, + -0.053974394, + -0.030257035, + 0.04325515, + -0.012293731, + -0.002446129, + -0.05567076, + 0.06374684, + -0.03153897, + -0.04475149, + 0.018582936, + 0.025716115, + -0.061778374, + 0.04196277, + -0.04134671, + -0.07396272, + 0.05846184, + 0.006558759, + -0.09745666, + 0.07587805, + 
0.0137483915, + -0.100933895, + 0.032008193, + 0.04293283, + 0.017870268, + 0.032806385, + -0.0635923, + -0.019672254, + 0.022225974, + 0.04304554, + -0.06043949, + -0.0285274, + 0.050868835, + 0.057003833, + 0.05740866, + 0.020068677, + -0.034312245, + -0.021671802, + 0.014769731, + -0.07328285, + -0.009586734, + 0.036420938, + -0.022188472, + -0.008200541, + -0.010765854, + -0.06949713, + -0.07555878, + 0.045306854, + -0.05424466, + -0.03647476, + 0.06266633, + 0.08346125, + 0.060288202, + 0.0548457 + ], + "expected_by_score": [ + "The ancient oppidum that corresponds to the modern city of Paris was first mentioned in the mid-1st century BC by Julius Caesar as Luteciam Parisiorum ('Lutetia of the Parisii') and is later attested as Parision in the 5th century AD, then as Paris in 1265. During the Roman period, it was commonly known as Lutetia or Lutecia in Latin, and as Leukotekía in Greek, which is interpreted as either stemming from the Celtic root *lukot- ('mouse'), or from *luto- ('marsh, swamp').\n\n\nThe name Paris is derived from its early inhabitants, the Parisii, a Gallic tribe from the Iron Age and the Roman period. The meaning of the Gaulish ethnonym remains debated. According to Xavier Delamarre, it may derive from the Celtic root pario- ('cauldron'). Alfred Holder interpreted the name as 'the makers' or 'the commanders', by comparing it to the Welsh peryff ('lord, commander'), both possibly descending from a Proto-Celtic form reconstructed as *kwar-is-io-. Alternatively, Pierre-Yves Lambert proposed to translate Parisii as the 'spear people', by connecting the first element to the Old Irish carr ('spear'), derived from an earlier *kwar-sā. In any case, the city's name is not related to the Paris of Greek mythology.\n\n\nResidents of the city are known in English as Parisians and in French as Parisiens ( ⓘ). They are also pejoratively called Parigots ( ⓘ).\n\n\nHistory\n\nOrigins\n\n", + "After the marshland between the river Seine and its slower 'dead arm' to its north was filled in from around the 10th century, Paris's cultural centre began to move to the Right Bank. In 1137, a new city marketplace (today's Les Halles) replaced the two smaller ones on the Île de la Cité and Place de Grève (Place de l'Hôtel de Ville). The latter location housed the headquarters of Paris's river trade corporation, an organisation that later became, unofficially (although formally in later years), Paris's first municipal government.\n\n\nIn the late 12th century, Philip Augustus extended the Louvre fortress to defend the city against river invasions from the west, gave the city its first walls between 1190 and 1215, rebuilt its bridges to either side of its central island, and paved its main thoroughfares. In 1190, he transformed Paris's former cathedral school into a student-teacher corporation that would become the University of Paris and would draw students from all of Europe.\n\n\nWith 200,000 inhabitants in 1328, Paris, then already the capital of France, was the most populous city of Europe. By comparison, London in 1300 had 80,000 inhabitants. By the early fourteenth century, so much filth had collected inside urban Europe that French and Italian cities were naming streets after human waste. In medieval Paris, several street names were inspired by merde, the French word for \"shit\".\n\n\n", + "In March 2001, Bertrand Delanoë became the first socialist mayor. He was re-elected in March 2008. 
In 2007, in an effort to reduce car traffic, he introduced the Vélib', a system which rents bicycles. Bertrand Delanoë also transformed a section of the highway along the Left Bank of the Seine into an urban promenade and park, the Promenade des Berges de la Seine, which he inaugurated in June 2013.\n\n\nIn 2007, President Nicolas Sarkozy launched the Grand Paris project, to integrate Paris more closely with the towns in the region around it. After many modifications, the new area, named the Metropolis of Grand Paris, with a population of 6.7 million, was created on 1 January 2016. In 2011, the City of Paris and the national government approved the plans for the Grand Paris Express, totalling 205 km (127 mi) of automated metro lines to connect Paris, the innermost three departments around Paris, airports and high-speed rail (TGV) stations, at an estimated cost of €35 billion. The system is scheduled to be completed by 2030.\n\n\nIn January 2015, Al-Qaeda in the Arabian Peninsula claimed attacks across the Paris region. 1.5 million people marched in Paris in a show of solidarity against terrorism and in support of freedom of speech. In November of the same year, terrorist attacks, claimed by ISIL, killed 130 people and injured more than 350.\n\n\n", + "\nParis (.mw-parser-output .IPA-label-small{font-size:85%}.mw-parser-output .references .IPA-label-small,.mw-parser-output .infobox .IPA-label-small,.mw-parser-output .navbox .IPA-label-small{font-size:100%}French pronunciation: ⓘ) is the capital and largest city of France. With an estimated population of 2,102,650 residents in January 2023 in an area of more than 105 km2 (41 sq mi), Paris is the fourth-largest city in the European Union and the 30th most densely populated city in the world in 2022. Since the 17th century, Paris has been one of the world's major centres of finance, diplomacy, commerce, culture, fashion, and gastronomy. Because of its leading role in the arts and sciences and its early adaptation of extensive street lighting, it became known as the City of Light in the 19th century.\n\n\nThe City of Paris is the centre of the Île-de-France region, or Paris Region, with an official estimated population of 12,271,794 inhabitants in January 2023, or about 19% of the population of France. The Paris Region had a nominal GDP of €765 billion (US$1.064 trillion when adjusted for PPP) in 2021, the highest in the European Union. According to the Economist Intelligence Unit Worldwide Cost of Living Survey, in 2022, Paris was the city with the ninth-highest cost of living in the world.\n\n\n", + "Bal-musette is a style of French music and dance that first became popular in Paris in the 1870s and 1880s; by 1880 Paris had some 150 dance halls. Patrons danced the bourrée to the accompaniment of the cabrette (a bellows-blown bagpipe locally called a \"musette\") and often the vielle à roue (hurdy-gurdy) in the cafés and bars of the city. Parisian and Italian musicians who played the accordion adopted the style and established themselves in Auvergnat bars, and Paris became a major centre for jazz and still attracts jazz musicians from all around the world to its clubs and cafés.\n\n\nParis is the spiritual home of gypsy jazz in particular, and many of the Parisian jazzmen who developed in the first half of the 20th century began by playing Bal-musette in the city. 
Django Reinhardt rose to fame in Paris, having moved to the 18th arrondissement in a caravan as a young boy, and performed with violinist Stéphane Grappelli and their Quintette du Hot Club de France in the 1930s and 1940s.\n\n\nImmediately after the War the Saint-Germain-des-Pres quarter and the nearby Saint-Michel quarter became home to many small jazz clubs, including the Caveau des Lorientais, the Club Saint-Germain, the Rose Rouge, the Vieux-Colombier, and the most famous, Le Tabou. They introduced Parisians to the music of Claude Luter, Boris Vian, Sydney Bechet, Mezz Mezzrow, and Henri Salvador. " + ], + "expected_by_offset": [ + "\nParis (.mw-parser-output .IPA-label-small{font-size:85%}.mw-parser-output .references .IPA-label-small,.mw-parser-output .infobox .IPA-label-small,.mw-parser-output .navbox .IPA-label-small{font-size:100%}French pronunciation: ⓘ) is the capital and largest city of France. With an estimated population of 2,102,650 residents in January 2023 in an area of more than 105 km2 (41 sq mi), Paris is the fourth-largest city in the European Union and the 30th most densely populated city in the world in 2022. Since the 17th century, Paris has been one of the world's major centres of finance, diplomacy, commerce, culture, fashion, and gastronomy. Because of its leading role in the arts and sciences and its early adaptation of extensive street lighting, it became known as the City of Light in the 19th century.\n\n\nThe City of Paris is the centre of the Île-de-France region, or Paris Region, with an official estimated population of 12,271,794 inhabitants in January 2023, or about 19% of the population of France. The Paris Region had a nominal GDP of €765 billion (US$1.064 trillion when adjusted for PPP) in 2021, the highest in the European Union. According to the Economist Intelligence Unit Worldwide Cost of Living Survey, in 2022, Paris was the city with the ninth-highest cost of living in the world.\n\n\n", + "The ancient oppidum that corresponds to the modern city of Paris was first mentioned in the mid-1st century BC by Julius Caesar as Luteciam Parisiorum ('Lutetia of the Parisii') and is later attested as Parision in the 5th century AD, then as Paris in 1265. During the Roman period, it was commonly known as Lutetia or Lutecia in Latin, and as Leukotekía in Greek, which is interpreted as either stemming from the Celtic root *lukot- ('mouse'), or from *luto- ('marsh, swamp').\n\n\nThe name Paris is derived from its early inhabitants, the Parisii, a Gallic tribe from the Iron Age and the Roman period. The meaning of the Gaulish ethnonym remains debated. According to Xavier Delamarre, it may derive from the Celtic root pario- ('cauldron'). Alfred Holder interpreted the name as 'the makers' or 'the commanders', by comparing it to the Welsh peryff ('lord, commander'), both possibly descending from a Proto-Celtic form reconstructed as *kwar-is-io-. Alternatively, Pierre-Yves Lambert proposed to translate Parisii as the 'spear people', by connecting the first element to the Old Irish carr ('spear'), derived from an earlier *kwar-sā. In any case, the city's name is not related to the Paris of Greek mythology.\n\n\nResidents of the city are known in English as Parisians and in French as Parisiens ( ⓘ). They are also pejoratively called Parigots ( ⓘ).\n\n\nHistory\n\nOrigins\n\n", + "After the marshland between the river Seine and its slower 'dead arm' to its north was filled in from around the 10th century, Paris's cultural centre began to move to the Right Bank. 
In 1137, a new city marketplace (today's Les Halles) replaced the two smaller ones on the Île de la Cité and Place de Grève (Place de l'Hôtel de Ville). The latter location housed the headquarters of Paris's river trade corporation, an organisation that later became, unofficially (although formally in later years), Paris's first municipal government.\n\n\nIn the late 12th century, Philip Augustus extended the Louvre fortress to defend the city against river invasions from the west, gave the city its first walls between 1190 and 1215, rebuilt its bridges to either side of its central island, and paved its main thoroughfares. In 1190, he transformed Paris's former cathedral school into a student-teacher corporation that would become the University of Paris and would draw students from all of Europe.\n\n\nWith 200,000 inhabitants in 1328, Paris, then already the capital of France, was the most populous city of Europe. By comparison, London in 1300 had 80,000 inhabitants. By the early fourteenth century, so much filth had collected inside urban Europe that French and Italian cities were naming streets after human waste. In medieval Paris, several street names were inspired by merde, the French word for \"shit\".\n\n\n", + "In March 2001, Bertrand Delanoë became the first socialist mayor. He was re-elected in March 2008. In 2007, in an effort to reduce car traffic, he introduced the Vélib', a system which rents bicycles. Bertrand Delanoë also transformed a section of the highway along the Left Bank of the Seine into an urban promenade and park, the Promenade des Berges de la Seine, which he inaugurated in June 2013.\n\n\nIn 2007, President Nicolas Sarkozy launched the Grand Paris project, to integrate Paris more closely with the towns in the region around it. After many modifications, the new area, named the Metropolis of Grand Paris, with a population of 6.7 million, was created on 1 January 2016. In 2011, the City of Paris and the national government approved the plans for the Grand Paris Express, totalling 205 km (127 mi) of automated metro lines to connect Paris, the innermost three departments around Paris, airports and high-speed rail (TGV) stations, at an estimated cost of €35 billion. The system is scheduled to be completed by 2030.\n\n\nIn January 2015, Al-Qaeda in the Arabian Peninsula claimed attacks across the Paris region. 1.5 million people marched in Paris in a show of solidarity against terrorism and in support of freedom of speech. In November of the same year, terrorist attacks, claimed by ISIL, killed 130 people and injured more than 350.\n\n\n", + "Bal-musette is a style of French music and dance that first became popular in Paris in the 1870s and 1880s; by 1880 Paris had some 150 dance halls. Patrons danced the bourrée to the accompaniment of the cabrette (a bellows-blown bagpipe locally called a \"musette\") and often the vielle à roue (hurdy-gurdy) in the cafés and bars of the city. Parisian and Italian musicians who played the accordion adopted the style and established themselves in Auvergnat bars, and Paris became a major centre for jazz and still attracts jazz musicians from all around the world to its clubs and cafés.\n\n\nParis is the spiritual home of gypsy jazz in particular, and many of the Parisian jazzmen who developed in the first half of the 20th century began by playing Bal-musette in the city. 
Django Reinhardt rose to fame in Paris, having moved to the 18th arrondissement in a caravan as a young boy, and performed with violinist Stéphane Grappelli and their Quintette du Hot Club de France in the 1930s and 1940s.\n\n\nImmediately after the War the Saint-Germain-des-Pres quarter and the nearby Saint-Michel quarter became home to many small jazz clubs, including the Caveau des Lorientais, the Club Saint-Germain, the Rose Rouge, the Vieux-Colombier, and the most famous, Le Tabou. They introduced Parisians to the music of Claude Luter, Boris Vian, Sydney Bechet, Mezz Mezzrow, and Henri Salvador. " + ] + }, + "sparse_vector_1": { + "embeddings": { + "paris": 2.9709616, + "date": 2.1960778, + "founded": 2.0555024, + "foundation": 1.412623, + "early": 1.2162757, + "founder": 1.1271698, + "french": 0.9213378, + "france": 0.86253893, + "city": 0.82978916, + "founding": 0.79722786, + "established": 0.7967043, + "ancient": 0.7392465, + "when": 0.71705, + "built": 0.6977878, + "treaty": 0.6846069, + "created": 0.68127465, + "century": 0.58926934, + "for": 0.55019474, + "was": 0.52475905, + "origin": 0.48785052, + "expedition": 0.48757303, + "history": 0.47960007, + "mint": 0.47878903, + "historical": 0.4714338, + "capital": 0.42984143, + "timeline": 0.4222377, + "colony": 0.3876187, + "tower": 0.3474891, + "medieval": 0.3272666, + "geography": 0.32456368, + "colonial": 0.30613664, + "location": 0.29013386, + "francisco": 0.22840048, + "orleans": 0.21971667, + "earlier": 0.20318772, + "jackson": 0.18424438, + "exact": 0.17109296, + "rome": 0.16320735, + "civilization": 0.15931238, + "spanish": 0.12759624, + "museum": 0.113024555, + "latin": 0.11201205, + "european": 0.10277243, + "architect": 0.0796932, + "united": 0.031233707 + }, + "expected_by_score": [ + "Clovis the Frank, the first king of the Merovingian dynasty, made the city his capital from 508. As the Frankish domination of Gaul began, there was a gradual immigration by the Franks to Paris and the Parisian Francien dialects were born. Fortification of the Île de la Cité failed to avert sacking by Vikings in 845, but Paris's strategic importance—with its bridges preventing ships from passing—was established by successful defence in the Siege of Paris (885–886), for which the then Count of Paris (comte de Paris), Odo of France, was elected king of West Francia. From the Capetian dynasty that began with the 987 election of Hugh Capet, Count of Paris and Duke of the Franks (duc des Francs), as king of a unified West Francia, Paris gradually became the largest and most prosperous city in France.\n\n\nHigh and Late Middle Ages to Louis XIV\n\nBy the end of the 12th century, Paris had become the political, economic, religious, and cultural capital of France. The Palais de la Cité, the royal residence, was located at the western end of the Île de la Cité. In 1163, during the reign of Louis VII, Maurice de Sully, bishop of Paris, undertook the construction of the Notre Dame Cathedral at its eastern extremity.\n\n\nAfter the marshland between the river Seine and its slower 'dead arm' to its north was filled in from around the 10th century, Paris's cultural centre began to move to the Right Bank. ", + "\nThe Parisii, a sub-tribe of the Celtic Senones, inhabited the Paris area from around the middle of the 3rd century BC. One of the area's major north–south trade routes crossed the Seine on the Île de la Cité, which gradually became an important trading centre. 
The Parisii traded with many river towns (some as far away as the Iberian Peninsula) and minted their own coins.\n\n\nThe Romans conquered the Paris Basin in 52 BC and began their settlement on Paris's Left Bank. The Roman town was originally called Lutetia (more fully, Lutetia Parisiorum, \"Lutetia of the Parisii\", modern French Lutèce). It became a prosperous city with a forum, baths, temples, theatres, and an amphitheatre.\n\n\nBy the end of the Western Roman Empire, the town was known as Parisius, a Latin name that would later become Paris in French. Christianity was introduced in the middle of the 3rd century AD by Saint Denis, the first Bishop of Paris: according to legend, when he refused to renounce his faith before the Roman occupiers, he was beheaded on the hill which became known as Mons Martyrum (Latin \"Hill of Martyrs\"), later \"Montmartre\", from where he walked headless to the north of the city; the place where he fell and was buried became an important religious shrine, the Basilica of Saint-Denis, and many French kings are buried there.\n\n\nClovis the Frank, the first king of the Merovingian dynasty, made the city his capital from 508. ", + "\nDuring the Hundred Years' War, Paris was occupied by England-friendly Burgundian forces from 1418, before being occupied outright by the English when Henry V of England entered the French capital in 1420; in spite of a 1429 effort by Joan of Arc to liberate the city, it would remain under English occupation until 1436.\n\n\nIn the late 16th-century French Wars of Religion, Paris was a stronghold of the Catholic League, the organisers of 24 August 1572 St. Bartholomew's Day massacre in which thousands of French Protestants were killed. The conflicts ended when pretender to the throne Henry IV, after converting to Catholicism to gain entry to the capital, entered the city in 1594 to claim the crown of France. This king made several improvements to the capital during his reign: he completed the construction of Paris's first uncovered, sidewalk-lined bridge, the Pont Neuf, built a Louvre extension connecting it to the Tuileries Palace, and created the first Paris residential square, the Place Royale, now Place des Vosges. In spite of Henry IV's efforts to improve city circulation, the narrowness of Paris's streets was a contributing factor in his assassination near Les Halles marketplace in 1610.\n\n\nDuring the 17th century, Cardinal Richelieu, chief minister of Louis XIII, was determined to make Paris the most beautiful city in Europe. He built five new bridges, a new chapel for the College of Sorbonne, and a palace for himself, the Palais-Cardinal. ", + "Diderot and D'Alembert published their Encyclopédie in 1751, before the Montgolfier Brothers launched the first manned flight in a hot air balloon on 21 November 1783. Paris was the financial capital of continental Europe, as well the primary European centre for book publishing, fashion and the manufacture of fine furniture and luxury goods. On 22 October 1797, Paris was also the site of the first parachute jump in history, by Garnerin.\n\n\nIn the summer of 1789, Paris became the centre stage of the French Revolution. On 14 July, a mob seized the arsenal at the Invalides, acquiring thousands of guns, with which it stormed the Bastille, a principal symbol of royal authority. 
The first independent Paris Commune, or city council, met in the Hôtel de Ville and elected a Mayor, the astronomer Jean Sylvain Bailly, on 15 July.\n\n\nLouis XVI and the royal family were brought to Paris and incarcerated in the Tuileries Palace. In 1793, as the revolution turned increasingly radical, the king, queen and mayor were beheaded by guillotine in the Reign of Terror, along with more than 16,000 others throughout France. The property of the aristocracy and the church was nationalised, and the city's churches were closed, sold or demolished. A succession of revolutionary factions ruled Paris until 9 November 1799 (coup d'état du 18 brumaire), when Napoleon Bonaparte seized power as First Consul.\n\n\n", + "After the marshland between the river Seine and its slower 'dead arm' to its north was filled in from around the 10th century, Paris's cultural centre began to move to the Right Bank. In 1137, a new city marketplace (today's Les Halles) replaced the two smaller ones on the Île de la Cité and Place de Grève (Place de l'Hôtel de Ville). The latter location housed the headquarters of Paris's river trade corporation, an organisation that later became, unofficially (although formally in later years), Paris's first municipal government.\n\n\nIn the late 12th century, Philip Augustus extended the Louvre fortress to defend the city against river invasions from the west, gave the city its first walls between 1190 and 1215, rebuilt its bridges to either side of its central island, and paved its main thoroughfares. In 1190, he transformed Paris's former cathedral school into a student-teacher corporation that would become the University of Paris and would draw students from all of Europe.\n\n\nWith 200,000 inhabitants in 1328, Paris, then already the capital of France, was the most populous city of Europe. By comparison, London in 1300 had 80,000 inhabitants. By the early fourteenth century, so much filth had collected inside urban Europe that French and Italian cities were naming streets after human waste. In medieval Paris, several street names were inspired by merde, the French word for \"shit\".\n\n\n" + ], + "expected_by_offset": [ + "\nThe Parisii, a sub-tribe of the Celtic Senones, inhabited the Paris area from around the middle of the 3rd century BC. One of the area's major north–south trade routes crossed the Seine on the Île de la Cité, which gradually became an important trading centre. The Parisii traded with many river towns (some as far away as the Iberian Peninsula) and minted their own coins.\n\n\nThe Romans conquered the Paris Basin in 52 BC and began their settlement on Paris's Left Bank. The Roman town was originally called Lutetia (more fully, Lutetia Parisiorum, \"Lutetia of the Parisii\", modern French Lutèce). It became a prosperous city with a forum, baths, temples, theatres, and an amphitheatre.\n\n\nBy the end of the Western Roman Empire, the town was known as Parisius, a Latin name that would later become Paris in French. 
Christianity was introduced in the middle of the 3rd century AD by Saint Denis, the first Bishop of Paris: according to legend, when he refused to renounce his faith before the Roman occupiers, he was beheaded on the hill which became known as Mons Martyrum (Latin \"Hill of Martyrs\"), later \"Montmartre\", from where he walked headless to the north of the city; the place where he fell and was buried became an important religious shrine, the Basilica of Saint-Denis, and many French kings are buried there.\n\n\nClovis the Frank, the first king of the Merovingian dynasty, made the city his capital from 508. ", + "Clovis the Frank, the first king of the Merovingian dynasty, made the city his capital from 508. As the Frankish domination of Gaul began, there was a gradual immigration by the Franks to Paris and the Parisian Francien dialects were born. Fortification of the Île de la Cité failed to avert sacking by Vikings in 845, but Paris's strategic importance—with its bridges preventing ships from passing—was established by successful defence in the Siege of Paris (885–886), for which the then Count of Paris (comte de Paris), Odo of France, was elected king of West Francia. From the Capetian dynasty that began with the 987 election of Hugh Capet, Count of Paris and Duke of the Franks (duc des Francs), as king of a unified West Francia, Paris gradually became the largest and most prosperous city in France.\n\n\nHigh and Late Middle Ages to Louis XIV\n\nBy the end of the 12th century, Paris had become the political, economic, religious, and cultural capital of France. The Palais de la Cité, the royal residence, was located at the western end of the Île de la Cité. In 1163, during the reign of Louis VII, Maurice de Sully, bishop of Paris, undertook the construction of the Notre Dame Cathedral at its eastern extremity.\n\n\nAfter the marshland between the river Seine and its slower 'dead arm' to its north was filled in from around the 10th century, Paris's cultural centre began to move to the Right Bank. ", + "After the marshland between the river Seine and its slower 'dead arm' to its north was filled in from around the 10th century, Paris's cultural centre began to move to the Right Bank. In 1137, a new city marketplace (today's Les Halles) replaced the two smaller ones on the Île de la Cité and Place de Grève (Place de l'Hôtel de Ville). The latter location housed the headquarters of Paris's river trade corporation, an organisation that later became, unofficially (although formally in later years), Paris's first municipal government.\n\n\nIn the late 12th century, Philip Augustus extended the Louvre fortress to defend the city against river invasions from the west, gave the city its first walls between 1190 and 1215, rebuilt its bridges to either side of its central island, and paved its main thoroughfares. In 1190, he transformed Paris's former cathedral school into a student-teacher corporation that would become the University of Paris and would draw students from all of Europe.\n\n\nWith 200,000 inhabitants in 1328, Paris, then already the capital of France, was the most populous city of Europe. By comparison, London in 1300 had 80,000 inhabitants. By the early fourteenth century, so much filth had collected inside urban Europe that French and Italian cities were naming streets after human waste. 
In medieval Paris, several street names were inspired by merde, the French word for \"shit\".\n\n\n", + "\nDuring the Hundred Years' War, Paris was occupied by England-friendly Burgundian forces from 1418, before being occupied outright by the English when Henry V of England entered the French capital in 1420; in spite of a 1429 effort by Joan of Arc to liberate the city, it would remain under English occupation until 1436.\n\n\nIn the late 16th-century French Wars of Religion, Paris was a stronghold of the Catholic League, the organisers of 24 August 1572 St. Bartholomew's Day massacre in which thousands of French Protestants were killed. The conflicts ended when pretender to the throne Henry IV, after converting to Catholicism to gain entry to the capital, entered the city in 1594 to claim the crown of France. This king made several improvements to the capital during his reign: he completed the construction of Paris's first uncovered, sidewalk-lined bridge, the Pont Neuf, built a Louvre extension connecting it to the Tuileries Palace, and created the first Paris residential square, the Place Royale, now Place des Vosges. In spite of Henry IV's efforts to improve city circulation, the narrowness of Paris's streets was a contributing factor in his assassination near Les Halles marketplace in 1610.\n\n\nDuring the 17th century, Cardinal Richelieu, chief minister of Louis XIII, was determined to make Paris the most beautiful city in Europe. He built five new bridges, a new chapel for the College of Sorbonne, and a palace for himself, the Palais-Cardinal. ", + "Diderot and D'Alembert published their Encyclopédie in 1751, before the Montgolfier Brothers launched the first manned flight in a hot air balloon on 21 November 1783. Paris was the financial capital of continental Europe, as well the primary European centre for book publishing, fashion and the manufacture of fine furniture and luxury goods. On 22 October 1797, Paris was also the site of the first parachute jump in history, by Garnerin.\n\n\nIn the summer of 1789, Paris became the centre stage of the French Revolution. On 14 July, a mob seized the arsenal at the Invalides, acquiring thousands of guns, with which it stormed the Bastille, a principal symbol of royal authority. The first independent Paris Commune, or city council, met in the Hôtel de Ville and elected a Mayor, the astronomer Jean Sylvain Bailly, on 15 July.\n\n\nLouis XVI and the royal family were brought to Paris and incarcerated in the Tuileries Palace. In 1793, as the revolution turned increasingly radical, the king, queen and mayor were beheaded by guillotine in the Reign of Terror, along with more than 16,000 others throughout France. The property of the aristocracy and the church was nationalised, and the city's churches were closed, sold or demolished. 
A succession of revolutionary factions ruled Paris until 9 November 1799 (coup d'état du 18 brumaire), when Napoleon Bonaparte seized power as First Consul.\n\n\n" + ] + } +} \ No newline at end of file diff --git a/x-pack/plugin/inference/src/test/resources/org/elasticsearch/xpack/inference/highlight/sample-doc.json.gz b/x-pack/plugin/inference/src/test/resources/org/elasticsearch/xpack/inference/highlight/sample-doc.json.gz new file mode 100644 index 000000000000..881524e46e18 Binary files /dev/null and b/x-pack/plugin/inference/src/test/resources/org/elasticsearch/xpack/inference/highlight/sample-doc.json.gz differ diff --git a/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/90_semantic_text_highlighter.yml b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/90_semantic_text_highlighter.yml new file mode 100644 index 000000000000..25cd1b5aec48 --- /dev/null +++ b/x-pack/plugin/inference/src/yamlRestTest/resources/rest-api-spec/test/inference/90_semantic_text_highlighter.yml @@ -0,0 +1,242 @@ +setup: + - requires: + cluster_features: "semantic_text.highlighter" + reason: a new highlighter for semantic text field + + - do: + inference.put: + task_type: sparse_embedding + inference_id: sparse-inference-id + body: > + { + "service": "test_service", + "service_settings": { + "model": "my_model", + "api_key": "abc64" + }, + "task_settings": { + } + } + + - do: + inference.put: + task_type: text_embedding + inference_id: dense-inference-id + body: > + { + "service": "text_embedding_test_service", + "service_settings": { + "model": "my_model", + "dimensions": 10, + "api_key": "abc64", + "similarity": "COSINE" + }, + "task_settings": { + } + } + + - do: + indices.create: + index: test-sparse-index + body: + mappings: + properties: + body: + type: semantic_text + inference_id: sparse-inference-id + + - do: + indices.create: + index: test-dense-index + body: + mappings: + properties: + body: + type: semantic_text + inference_id: dense-inference-id + +--- +"Highlighting using a sparse embedding model": + - do: + index: + index: test-sparse-index + id: doc_1 + body: + body: ["ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides.", "You Know, for Search!"] + refresh: true + + - match: { result: created } + + - do: + search: + index: test-sparse-index + body: + query: + semantic: + field: "body" + query: "What is Elasticsearch?" + highlight: + fields: + body: + type: "semantic" + number_of_fragments: 1 + + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } + - length: { hits.hits.0.highlight.body: 1 } + - match: { hits.hits.0.highlight.body.0: "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } + + - do: + search: + index: test-sparse-index + body: + query: + semantic: + field: "body" + query: "What is Elasticsearch?" + highlight: + fields: + body: + type: "semantic" + number_of_fragments: 2 + + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } + - length: { hits.hits.0.highlight.body: 2 } + - match: { hits.hits.0.highlight.body.0: "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } + - match: { hits.hits.0.highlight.body.1: "You Know, for Search!" 
} + + - do: + search: + index: test-sparse-index + body: + query: + semantic: + field: "body" + query: "What is Elasticsearch?" + highlight: + fields: + body: + type: "semantic" + order: "score" + number_of_fragments: 1 + + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } + - length: { hits.hits.0.highlight.body: 1 } + - match: { hits.hits.0.highlight.body.0: "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } + + - do: + search: + index: test-sparse-index + body: + query: + semantic: + field: "body" + query: "What is Elasticsearch?" + highlight: + fields: + body: + type: "semantic" + order: "score" + number_of_fragments: 2 + + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } + - length: { hits.hits.0.highlight.body: 2 } + - match: { hits.hits.0.highlight.body.0: "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } + - match: { hits.hits.0.highlight.body.1: "You Know, for Search!" } + +--- +"Highlighting using a dense embedding model": + - do: + index: + index: test-dense-index + id: doc_1 + body: + body: ["ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides.", "You Know, for Search!"] + refresh: true + + - match: { result: created } + + - do: + search: + index: test-dense-index + body: + query: + semantic: + field: "body" + query: "What is Elasticsearch?" + highlight: + fields: + body: + type: "semantic" + number_of_fragments: 1 + + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } + - length: { hits.hits.0.highlight.body: 1 } + - match: { hits.hits.0.highlight.body.0: "You Know, for Search!" } + + - do: + search: + index: test-dense-index + body: + query: + semantic: + field: "body" + query: "What is Elasticsearch?" + highlight: + fields: + body: + type: "semantic" + number_of_fragments: 2 + + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } + - length: { hits.hits.0.highlight.body: 2 } + - match: { hits.hits.0.highlight.body.0: "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." } + - match: { hits.hits.0.highlight.body.1: "You Know, for Search!" } + + - do: + search: + index: test-dense-index + body: + query: + semantic: + field: "body" + query: "What is Elasticsearch?" + highlight: + fields: + body: + type: "semantic" + order: "score" + number_of_fragments: 1 + + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } + - length: { hits.hits.0.highlight.body: 1 } + - match: { hits.hits.0.highlight.body.0: "You Know, for Search!" } + + - do: + search: + index: test-dense-index + body: + query: + semantic: + field: "body" + query: "What is Elasticsearch?" + highlight: + fields: + body: + type: "semantic" + order: "score" + number_of_fragments: 2 + + - match: { hits.total.value: 1 } + - match: { hits.hits.0._id: "doc_1" } + - length: { hits.hits.0.highlight.body: 2 } + - match: { hits.hits.0.highlight.body.0: "You Know, for Search!" } + - match: { hits.hits.0.highlight.body.1: "ElasticSearch is an open source, distributed, RESTful, search engine which is built on top of Lucene internally and enjoys all the features it provides." 
} + + diff --git a/x-pack/plugin/logsdb/src/test/java/org/elasticsearch/xpack/logsdb/SyntheticSourceLicenseServiceTests.java b/x-pack/plugin/logsdb/src/test/java/org/elasticsearch/xpack/logsdb/SyntheticSourceLicenseServiceTests.java index 90a13b16c028..0eb0d21ff2e7 100644 --- a/x-pack/plugin/logsdb/src/test/java/org/elasticsearch/xpack/logsdb/SyntheticSourceLicenseServiceTests.java +++ b/x-pack/plugin/logsdb/src/test/java/org/elasticsearch/xpack/logsdb/SyntheticSourceLicenseServiceTests.java @@ -41,6 +41,7 @@ public void setup() throws Exception { public void testLicenseAllowsSyntheticSource() { MockLicenseState licenseState = MockLicenseState.createMock(); + when(licenseState.copyCurrentLicenseState()).thenReturn(licenseState); when(licenseState.isAllowed(same(SyntheticSourceLicenseService.SYNTHETIC_SOURCE_FEATURE))).thenReturn(true); licenseService.setLicenseState(licenseState); licenseService.setLicenseService(mockLicenseService); @@ -53,6 +54,7 @@ public void testLicenseAllowsSyntheticSource() { public void testLicenseAllowsSyntheticSourceTemplateValidation() { MockLicenseState licenseState = MockLicenseState.createMock(); + when(licenseState.copyCurrentLicenseState()).thenReturn(licenseState); when(licenseState.isAllowed(same(SyntheticSourceLicenseService.SYNTHETIC_SOURCE_FEATURE))).thenReturn(true); licenseService.setLicenseState(licenseState); licenseService.setLicenseService(mockLicenseService); @@ -65,6 +67,7 @@ public void testLicenseAllowsSyntheticSourceTemplateValidation() { public void testDefaultDisallow() { MockLicenseState licenseState = MockLicenseState.createMock(); + when(licenseState.copyCurrentLicenseState()).thenReturn(licenseState); when(licenseState.isAllowed(same(SyntheticSourceLicenseService.SYNTHETIC_SOURCE_FEATURE))).thenReturn(false); licenseService.setLicenseState(licenseState); licenseService.setLicenseService(mockLicenseService); @@ -77,6 +80,7 @@ public void testDefaultDisallow() { public void testFallback() { MockLicenseState licenseState = MockLicenseState.createMock(); + when(licenseState.copyCurrentLicenseState()).thenReturn(licenseState); when(licenseState.isAllowed(same(SyntheticSourceLicenseService.SYNTHETIC_SOURCE_FEATURE))).thenReturn(true); licenseService.setLicenseState(licenseState); licenseService.setLicenseService(mockLicenseService); @@ -95,6 +99,7 @@ public void testGoldOrPlatinumLicense() throws Exception { when(mockLicenseService.getLicense()).thenReturn(license); MockLicenseState licenseState = MockLicenseState.createMock(); + when(licenseState.copyCurrentLicenseState()).thenReturn(licenseState); when(licenseState.getOperationMode()).thenReturn(license.operationMode()); when(licenseState.isAllowed(same(SyntheticSourceLicenseService.SYNTHETIC_SOURCE_FEATURE_LEGACY))).thenReturn(true); licenseService.setLicenseState(licenseState); @@ -103,6 +108,8 @@ public void testGoldOrPlatinumLicense() throws Exception { "legacy licensed usage is allowed, so not fallback to stored source", licenseService.fallbackToStoredSource(false, true) ); + Mockito.verify(licenseState, Mockito.times(1)).isAllowed(same(SyntheticSourceLicenseService.SYNTHETIC_SOURCE_FEATURE)); + Mockito.verify(licenseState, Mockito.times(1)).isAllowed(same(SyntheticSourceLicenseService.SYNTHETIC_SOURCE_FEATURE_LEGACY)); Mockito.verify(licenseState, Mockito.times(1)).featureUsed(any()); } @@ -112,6 +119,7 @@ public void testGoldOrPlatinumLicenseLegacyLicenseNotAllowed() throws Exception when(mockLicenseService.getLicense()).thenReturn(license); MockLicenseState licenseState = 
MockLicenseState.createMock(); + when(licenseState.copyCurrentLicenseState()).thenReturn(licenseState); when(licenseState.getOperationMode()).thenReturn(license.operationMode()); when(licenseState.isAllowed(same(SyntheticSourceLicenseService.SYNTHETIC_SOURCE_FEATURE))).thenReturn(false); licenseService.setLicenseState(licenseState); @@ -125,14 +133,16 @@ public void testGoldOrPlatinumLicenseLegacyLicenseNotAllowed() throws Exception } public void testGoldOrPlatinumLicenseBeyondCutoffDate() throws Exception { - long start = LocalDateTime.of(2025, 1, 1, 0, 0).toInstant(ZoneOffset.UTC).toEpochMilli(); + long start = LocalDateTime.of(2025, 2, 5, 0, 0).toInstant(ZoneOffset.UTC).toEpochMilli(); License license = createGoldOrPlatinumLicense(start); mockLicenseService = mock(LicenseService.class); when(mockLicenseService.getLicense()).thenReturn(license); MockLicenseState licenseState = MockLicenseState.createMock(); + when(licenseState.copyCurrentLicenseState()).thenReturn(licenseState); when(licenseState.getOperationMode()).thenReturn(license.operationMode()); when(licenseState.isAllowed(same(SyntheticSourceLicenseService.SYNTHETIC_SOURCE_FEATURE))).thenReturn(false); + when(licenseState.isAllowed(same(SyntheticSourceLicenseService.SYNTHETIC_SOURCE_FEATURE_LEGACY))).thenReturn(true); licenseService.setLicenseState(licenseState); licenseService.setLicenseService(mockLicenseService); assertTrue("beyond cutoff date, so fallback to stored source", licenseService.fallbackToStoredSource(false, true)); @@ -143,19 +153,21 @@ public void testGoldOrPlatinumLicenseBeyondCutoffDate() throws Exception { public void testGoldOrPlatinumLicenseCustomCutoffDate() throws Exception { licenseService = new SyntheticSourceLicenseService(Settings.EMPTY, "2025-01-02T00:00"); - long start = LocalDateTime.of(2025, 1, 1, 0, 0).toInstant(ZoneOffset.UTC).toEpochMilli(); + long start = LocalDateTime.of(2025, 1, 3, 0, 0).toInstant(ZoneOffset.UTC).toEpochMilli(); License license = createGoldOrPlatinumLicense(start); mockLicenseService = mock(LicenseService.class); when(mockLicenseService.getLicense()).thenReturn(license); MockLicenseState licenseState = MockLicenseState.createMock(); + when(licenseState.copyCurrentLicenseState()).thenReturn(licenseState); when(licenseState.getOperationMode()).thenReturn(license.operationMode()); + when(licenseState.isAllowed(same(SyntheticSourceLicenseService.SYNTHETIC_SOURCE_FEATURE))).thenReturn(false); when(licenseState.isAllowed(same(SyntheticSourceLicenseService.SYNTHETIC_SOURCE_FEATURE_LEGACY))).thenReturn(true); licenseService.setLicenseState(licenseState); licenseService.setLicenseService(mockLicenseService); - assertFalse("custom cutoff date, so fallback to stored source", licenseService.fallbackToStoredSource(false, true)); - Mockito.verify(licenseState, Mockito.times(1)).featureUsed(any()); - Mockito.verify(licenseState, Mockito.times(1)).isAllowed(same(SyntheticSourceLicenseService.SYNTHETIC_SOURCE_FEATURE_LEGACY)); + assertTrue("custom cutoff date, so fallback to stored source", licenseService.fallbackToStoredSource(false, true)); + Mockito.verify(licenseState, Mockito.times(1)).isAllowed(same(SyntheticSourceLicenseService.SYNTHETIC_SOURCE_FEATURE)); + Mockito.verify(licenseState, Mockito.never()).featureUsed(any()); } static License createEnterpriseLicense() throws Exception { diff --git a/x-pack/plugin/migrate/src/internalClusterTest/java/org/elasticsearch/xpack/migrate/action/ReindexDataStreamTransportActionIT.java 
b/x-pack/plugin/migrate/src/internalClusterTest/java/org/elasticsearch/xpack/migrate/action/ReindexDataStreamTransportActionIT.java index 3b68fc9995b5..62716e11f172 100644 --- a/x-pack/plugin/migrate/src/internalClusterTest/java/org/elasticsearch/xpack/migrate/action/ReindexDataStreamTransportActionIT.java +++ b/x-pack/plugin/migrate/src/internalClusterTest/java/org/elasticsearch/xpack/migrate/action/ReindexDataStreamTransportActionIT.java @@ -51,7 +51,10 @@ protected Collection> nodePlugins() { public void testNonExistentDataStream() { String nonExistentDataStreamName = randomAlphaOfLength(50); - ReindexDataStreamRequest reindexDataStreamRequest = new ReindexDataStreamRequest(nonExistentDataStreamName); + ReindexDataStreamRequest reindexDataStreamRequest = new ReindexDataStreamRequest( + ReindexDataStreamAction.Mode.UPGRADE, + nonExistentDataStreamName + ); assertThrows( ResourceNotFoundException.class, () -> client().execute(new ActionType(ReindexDataStreamAction.NAME), reindexDataStreamRequest) @@ -61,7 +64,10 @@ public void testNonExistentDataStream() { public void testAlreadyUpToDateDataStream() throws Exception { String dataStreamName = randomAlphaOfLength(50).toLowerCase(Locale.ROOT); - ReindexDataStreamRequest reindexDataStreamRequest = new ReindexDataStreamRequest(dataStreamName); + ReindexDataStreamRequest reindexDataStreamRequest = new ReindexDataStreamRequest( + ReindexDataStreamAction.Mode.UPGRADE, + dataStreamName + ); createDataStream(dataStreamName); ReindexDataStreamResponse response = client().execute( new ActionType(ReindexDataStreamAction.NAME), diff --git a/x-pack/plugin/migrate/src/main/java/org/elasticsearch/xpack/migrate/MigratePlugin.java b/x-pack/plugin/migrate/src/main/java/org/elasticsearch/xpack/migrate/MigratePlugin.java index 118cd69ece4d..ac9e38da0742 100644 --- a/x-pack/plugin/migrate/src/main/java/org/elasticsearch/xpack/migrate/MigratePlugin.java +++ b/x-pack/plugin/migrate/src/main/java/org/elasticsearch/xpack/migrate/MigratePlugin.java @@ -11,21 +11,30 @@ import org.elasticsearch.action.ActionResponse; import org.elasticsearch.client.internal.Client; import org.elasticsearch.cluster.metadata.IndexNameExpressionResolver; +import org.elasticsearch.cluster.node.DiscoveryNodes; import org.elasticsearch.cluster.service.ClusterService; import org.elasticsearch.common.io.stream.NamedWriteableRegistry; +import org.elasticsearch.common.settings.ClusterSettings; +import org.elasticsearch.common.settings.IndexScopedSettings; +import org.elasticsearch.common.settings.Settings; +import org.elasticsearch.common.settings.SettingsFilter; import org.elasticsearch.common.settings.SettingsModule; +import org.elasticsearch.features.NodeFeature; import org.elasticsearch.persistent.PersistentTaskParams; import org.elasticsearch.persistent.PersistentTaskState; import org.elasticsearch.persistent.PersistentTasksExecutor; import org.elasticsearch.plugins.ActionPlugin; import org.elasticsearch.plugins.PersistentTaskPlugin; import org.elasticsearch.plugins.Plugin; +import org.elasticsearch.rest.RestController; +import org.elasticsearch.rest.RestHandler; import org.elasticsearch.tasks.Task; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.xcontent.NamedXContentRegistry; import org.elasticsearch.xcontent.ParseField; import org.elasticsearch.xpack.migrate.action.ReindexDataStreamAction; import org.elasticsearch.xpack.migrate.action.ReindexDataStreamTransportAction; +import org.elasticsearch.xpack.migrate.rest.RestMigrationReindexAction; import 
org.elasticsearch.xpack.migrate.task.ReindexDataStreamPersistentTaskExecutor; import org.elasticsearch.xpack.migrate.task.ReindexDataStreamPersistentTaskState; import org.elasticsearch.xpack.migrate.task.ReindexDataStreamStatus; @@ -34,47 +43,80 @@ import java.util.ArrayList; import java.util.List; +import java.util.function.Predicate; +import java.util.function.Supplier; + +import static org.elasticsearch.xpack.migrate.action.ReindexDataStreamAction.REINDEX_DATA_STREAM_FEATURE_FLAG; public class MigratePlugin extends Plugin implements ActionPlugin, PersistentTaskPlugin { + @Override + public List getRestHandlers( + Settings unused, + NamedWriteableRegistry namedWriteableRegistry, + RestController restController, + ClusterSettings clusterSettings, + IndexScopedSettings indexScopedSettings, + SettingsFilter settingsFilter, + IndexNameExpressionResolver indexNameExpressionResolver, + Supplier nodesInCluster, + Predicate clusterSupportsFeature + ) { + List handlers = new ArrayList<>(); + if (REINDEX_DATA_STREAM_FEATURE_FLAG.isEnabled()) { + handlers.add(new RestMigrationReindexAction()); + } + return handlers; + } + @Override public List> getActions() { List> actions = new ArrayList<>(); - actions.add(new ActionHandler<>(ReindexDataStreamAction.INSTANCE, ReindexDataStreamTransportAction.class)); + if (REINDEX_DATA_STREAM_FEATURE_FLAG.isEnabled()) { + actions.add(new ActionHandler<>(ReindexDataStreamAction.INSTANCE, ReindexDataStreamTransportAction.class)); + } return actions; } @Override public List getNamedXContent() { - return List.of( - new NamedXContentRegistry.Entry( - PersistentTaskState.class, - new ParseField(ReindexDataStreamPersistentTaskState.NAME), - ReindexDataStreamPersistentTaskState::fromXContent - ), - new NamedXContentRegistry.Entry( - PersistentTaskParams.class, - new ParseField(ReindexDataStreamTaskParams.NAME), - ReindexDataStreamTaskParams::fromXContent - ) - ); + if (REINDEX_DATA_STREAM_FEATURE_FLAG.isEnabled()) { + return List.of( + new NamedXContentRegistry.Entry( + PersistentTaskState.class, + new ParseField(ReindexDataStreamPersistentTaskState.NAME), + ReindexDataStreamPersistentTaskState::fromXContent + ), + new NamedXContentRegistry.Entry( + PersistentTaskParams.class, + new ParseField(ReindexDataStreamTaskParams.NAME), + ReindexDataStreamTaskParams::fromXContent + ) + ); + } else { + return List.of(); + } } @Override public List getNamedWriteables() { - return List.of( - new NamedWriteableRegistry.Entry( - PersistentTaskState.class, - ReindexDataStreamPersistentTaskState.NAME, - ReindexDataStreamPersistentTaskState::new - ), - new NamedWriteableRegistry.Entry( - PersistentTaskParams.class, - ReindexDataStreamTaskParams.NAME, - ReindexDataStreamTaskParams::new - ), - new NamedWriteableRegistry.Entry(Task.Status.class, ReindexDataStreamStatus.NAME, ReindexDataStreamStatus::new) - ); + if (REINDEX_DATA_STREAM_FEATURE_FLAG.isEnabled()) { + return List.of( + new NamedWriteableRegistry.Entry( + PersistentTaskState.class, + ReindexDataStreamPersistentTaskState.NAME, + ReindexDataStreamPersistentTaskState::new + ), + new NamedWriteableRegistry.Entry( + PersistentTaskParams.class, + ReindexDataStreamTaskParams.NAME, + ReindexDataStreamTaskParams::new + ), + new NamedWriteableRegistry.Entry(Task.Status.class, ReindexDataStreamStatus.NAME, ReindexDataStreamStatus::new) + ); + } else { + return List.of(); + } } @Override @@ -85,6 +127,12 @@ public List> getPersistentTasksExecutor( SettingsModule settingsModule, IndexNameExpressionResolver expressionResolver ) { - return 
List.of(new ReindexDataStreamPersistentTaskExecutor(client, clusterService, ReindexDataStreamTask.TASK_NAME, threadPool)); + if (REINDEX_DATA_STREAM_FEATURE_FLAG.isEnabled()) { + return List.of( + new ReindexDataStreamPersistentTaskExecutor(client, clusterService, ReindexDataStreamTask.TASK_NAME, threadPool) + ); + } else { + return List.of(); + } } } diff --git a/x-pack/plugin/migrate/src/main/java/org/elasticsearch/xpack/migrate/action/ReindexDataStreamAction.java b/x-pack/plugin/migrate/src/main/java/org/elasticsearch/xpack/migrate/action/ReindexDataStreamAction.java index 1785e6971f82..eb7a910df8c0 100644 --- a/x-pack/plugin/migrate/src/main/java/org/elasticsearch/xpack/migrate/action/ReindexDataStreamAction.java +++ b/x-pack/plugin/migrate/src/main/java/org/elasticsearch/xpack/migrate/action/ReindexDataStreamAction.java @@ -11,23 +11,41 @@ import org.elasticsearch.action.ActionRequestValidationException; import org.elasticsearch.action.ActionResponse; import org.elasticsearch.action.ActionType; +import org.elasticsearch.action.IndicesRequest; +import org.elasticsearch.action.support.IndicesOptions; import org.elasticsearch.common.io.stream.StreamInput; import org.elasticsearch.common.io.stream.StreamOutput; +import org.elasticsearch.common.util.FeatureFlag; +import org.elasticsearch.features.NodeFeature; +import org.elasticsearch.xcontent.ConstructingObjectParser; +import org.elasticsearch.xcontent.ParseField; +import org.elasticsearch.xcontent.ToXContent; import org.elasticsearch.xcontent.ToXContentObject; import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentParser; import java.io.IOException; +import java.util.Locale; import java.util.Objects; +import java.util.function.Predicate; public class ReindexDataStreamAction extends ActionType { + public static final FeatureFlag REINDEX_DATA_STREAM_FEATURE_FLAG = new FeatureFlag("reindex_data_stream"); public static final ReindexDataStreamAction INSTANCE = new ReindexDataStreamAction(); public static final String NAME = "indices:admin/data_stream/reindex"; + public static final ParseField MODE_FIELD = new ParseField("mode"); + public static final ParseField SOURCE_FIELD = new ParseField("source"); + public static final ParseField INDEX_FIELD = new ParseField("index"); public ReindexDataStreamAction() { super(NAME); } + public enum Mode { + UPGRADE + } + public static class ReindexDataStreamResponse extends ActionResponse implements ToXContentObject { private final String taskId; @@ -49,7 +67,7 @@ public void writeTo(StreamOutput out) throws IOException { @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject(); - builder.field("task", getTaskId()); + builder.field("acknowledged", true); builder.endObject(); return builder; } @@ -70,22 +88,52 @@ public boolean equals(Object other) { } - public static class ReindexDataStreamRequest extends ActionRequest { + public static class ReindexDataStreamRequest extends ActionRequest implements IndicesRequest, ToXContent { + private final Mode mode; private final String sourceDataStream; - public ReindexDataStreamRequest(String sourceDataStream) { - super(); + public ReindexDataStreamRequest(Mode mode, String sourceDataStream) { + this.mode = mode; this.sourceDataStream = sourceDataStream; } public ReindexDataStreamRequest(StreamInput in) throws IOException { super(in); + this.mode = Mode.valueOf(in.readString()); this.sourceDataStream = in.readString(); } + private static final 
ConstructingObjectParser> PARSER = + new ConstructingObjectParser<>("migration_reindex", objects -> { + Mode mode = Mode.valueOf(((String) objects[0]).toUpperCase(Locale.ROOT)); + String source = (String) objects[1]; + return new ReindexDataStreamRequest(mode, source); + }); + + private static final ConstructingObjectParser SOURCE_PARSER = new ConstructingObjectParser<>( + SOURCE_FIELD.getPreferredName(), + false, + (a, id) -> (String) a[0] + ); + + static { + SOURCE_PARSER.declareString(ConstructingObjectParser.constructorArg(), INDEX_FIELD); + PARSER.declareString(ConstructingObjectParser.constructorArg(), MODE_FIELD); + PARSER.declareObject( + ConstructingObjectParser.constructorArg(), + (parser, id) -> SOURCE_PARSER.apply(parser, null), + SOURCE_FIELD + ); + } + + public static ReindexDataStreamRequest fromXContent(XContentParser parser) { + return PARSER.apply(parser, null); + } + @Override public void writeTo(StreamOutput out) throws IOException { super.writeTo(out); + out.writeString(mode.name()); out.writeString(sourceDataStream); } @@ -103,15 +151,42 @@ public String getSourceDataStream() { return sourceDataStream; } + public Mode getMode() { + return mode; + } + @Override public int hashCode() { - return Objects.hashCode(sourceDataStream); + return Objects.hash(mode, sourceDataStream); } @Override public boolean equals(Object other) { - return other instanceof ReindexDataStreamRequest - && sourceDataStream.equals(((ReindexDataStreamRequest) other).sourceDataStream); + return other instanceof ReindexDataStreamRequest otherRequest + && mode.equals(otherRequest.mode) + && sourceDataStream.equals(otherRequest.sourceDataStream); + } + + @Override + public String[] indices() { + return new String[] { sourceDataStream }; + } + + @Override + public IndicesOptions indicesOptions() { + return IndicesOptions.strictSingleIndexNoExpandForbidClosed(); + } + + /* + * This only exists for the sake of testing the xcontent parser + */ + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.field(MODE_FIELD.getPreferredName(), mode); + builder.startObject(SOURCE_FIELD.getPreferredName()); + builder.field(INDEX_FIELD.getPreferredName(), sourceDataStream); + builder.endObject(); + return builder; } } } diff --git a/x-pack/plugin/migrate/src/main/java/org/elasticsearch/xpack/migrate/rest/RestMigrationReindexAction.java b/x-pack/plugin/migrate/src/main/java/org/elasticsearch/xpack/migrate/rest/RestMigrationReindexAction.java new file mode 100644 index 000000000000..a7f630d68234 --- /dev/null +++ b/x-pack/plugin/migrate/src/main/java/org/elasticsearch/xpack/migrate/rest/RestMigrationReindexAction.java @@ -0,0 +1,64 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0. 
+ */ + +package org.elasticsearch.xpack.migrate.rest; + +import org.elasticsearch.client.internal.node.NodeClient; +import org.elasticsearch.rest.BaseRestHandler; +import org.elasticsearch.rest.RestChannel; +import org.elasticsearch.rest.RestRequest; +import org.elasticsearch.rest.RestResponse; +import org.elasticsearch.rest.RestStatus; +import org.elasticsearch.rest.action.RestBuilderListener; +import org.elasticsearch.xcontent.XContentBuilder; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xpack.migrate.action.ReindexDataStreamAction; +import org.elasticsearch.xpack.migrate.action.ReindexDataStreamAction.ReindexDataStreamResponse; + +import java.io.IOException; +import java.util.List; + +import static org.elasticsearch.rest.RestRequest.Method.POST; + +public class RestMigrationReindexAction extends BaseRestHandler { + + @Override + public String getName() { + return "migration_reindex"; + } + + @Override + public List<Route> routes() { + return List.of(new Route(POST, "/_migration/reindex")); + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) throws IOException { + ReindexDataStreamAction.ReindexDataStreamRequest reindexRequest; + try (XContentParser parser = request.contentParser()) { + reindexRequest = ReindexDataStreamAction.ReindexDataStreamRequest.fromXContent(parser); + } + return channel -> client.execute( + ReindexDataStreamAction.INSTANCE, + reindexRequest, + new ReindexDataStreamRestToXContentListener(channel) + ); + } + + static class ReindexDataStreamRestToXContentListener extends RestBuilderListener<ReindexDataStreamResponse> { + + ReindexDataStreamRestToXContentListener(RestChannel channel) { + super(channel); + } + + @Override + public RestResponse buildResponse(ReindexDataStreamResponse response, XContentBuilder builder) throws Exception { + response.toXContent(builder, channel.request()); + return new RestResponse(RestStatus.OK, builder); + } + } +} diff --git a/x-pack/plugin/migrate/src/test/java/org/elasticsearch/xpack/migrate/action/ReindexDataStreamRequestTests.java b/x-pack/plugin/migrate/src/test/java/org/elasticsearch/xpack/migrate/action/ReindexDataStreamRequestTests.java new file mode 100644 index 000000000000..9c7bf87b6cff --- /dev/null +++ b/x-pack/plugin/migrate/src/test/java/org/elasticsearch/xpack/migrate/action/ReindexDataStreamRequestTests.java @@ -0,0 +1,39 @@ +/* + * Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one + * or more contributor license agreements. Licensed under the Elastic License + * 2.0; you may not use this file except in compliance with the Elastic License + * 2.0.
+ */ + +package org.elasticsearch.xpack.migrate.action; + +import org.elasticsearch.common.io.stream.Writeable; +import org.elasticsearch.test.AbstractXContentSerializingTestCase; +import org.elasticsearch.xcontent.XContentParser; +import org.elasticsearch.xpack.migrate.action.ReindexDataStreamAction.ReindexDataStreamRequest; + +import java.io.IOException; + +public class ReindexDataStreamRequestTests extends AbstractXContentSerializingTestCase<ReindexDataStreamRequest> { + + @Override + protected ReindexDataStreamRequest createTestInstance() { + return new ReindexDataStreamRequest(ReindexDataStreamAction.Mode.UPGRADE, randomAlphaOfLength(40)); + } + + @Override + protected ReindexDataStreamRequest mutateInstance(ReindexDataStreamRequest instance) { + // There is currently only one possible value for mode, so we can't change it + return new ReindexDataStreamRequest(instance.getMode(), randomAlphaOfLength(50)); + } + + @Override + protected ReindexDataStreamRequest doParseInstance(XContentParser parser) throws IOException { + return ReindexDataStreamRequest.fromXContent(parser); + } + + @Override + protected Writeable.Reader<ReindexDataStreamRequest> instanceReader() { + return ReindexDataStreamRequest::new; + } +} diff --git a/x-pack/plugin/migrate/src/test/java/org/elasticsearch/xpack/migrate/action/ReindexDataStreamResponseTests.java b/x-pack/plugin/migrate/src/test/java/org/elasticsearch/xpack/migrate/action/ReindexDataStreamResponseTests.java index 06844577c4e3..d886fc660d7a 100644 --- a/x-pack/plugin/migrate/src/test/java/org/elasticsearch/xpack/migrate/action/ReindexDataStreamResponseTests.java +++ b/x-pack/plugin/migrate/src/test/java/org/elasticsearch/xpack/migrate/action/ReindexDataStreamResponseTests.java @@ -43,7 +43,7 @@ public void testToXContent() throws IOException { builder.humanReadable(true); response.toXContent(builder, EMPTY_PARAMS); try (XContentParser parser = createParser(JsonXContent.jsonXContent, BytesReference.bytes(builder))) { - assertThat(parser.map(), equalTo(Map.of("task", response.getTaskId()))); + assertThat(parser.map(), equalTo(Map.of("acknowledged", true))); } } } diff --git a/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java b/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java index 8df10037affd..c91314716cf9 100644 --- a/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java +++ b/x-pack/plugin/security/qa/operator-privileges-tests/src/javaRestTest/java/org/elasticsearch/xpack/security/operator/Constants.java @@ -386,6 +386,7 @@ public class Constants { "cluster:monitor/xpack/esql/stats/dist", "cluster:monitor/xpack/inference", "cluster:monitor/xpack/inference/get", + "cluster:monitor/xpack/inference/unified", "cluster:monitor/xpack/inference/diagnostics/get", "cluster:monitor/xpack/inference/services/get", "cluster:monitor/xpack/info", diff --git a/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/migrate/10_reindex.yml b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/migrate/10_reindex.yml new file mode 100644 index 000000000000..01a41b3aa8c9 --- /dev/null +++ b/x-pack/plugin/src/yamlRestTest/resources/rest-api-spec/test/migrate/10_reindex.yml @@ -0,0 +1,89 @@ +--- +setup: + - do: + cluster.health: + wait_for_status: yellow + +--- +"Test Reindex With Unsupported Mode": + - do: + catch: /illegal_argument_exception/ + migrate.reindex: + body: | + {
+ "mode": "unsupported_mode", + "source": { + "index": "my-data-stream" + } + } + +--- +"Test Reindex With Nonexistent Data Stream": + - do: + catch: /resource_not_found_exception/ + migrate.reindex: + body: | + { + "mode": "upgrade", + "source": { + "index": "my-data-stream" + } + } + + - do: + catch: /resource_not_found_exception/ + migrate.reindex: + body: | + { + "mode": "upgrade", + "source": { + "index": "my-data-stream1,my-data-stream2" + } + } + + +--- +"Test Reindex With Bad Data Stream Name": + - do: + catch: /illegal_argument_exception/ + migrate.reindex: + body: | + { + "mode": "upgrade", + "source": { + "index": "my-data-stream*" + } + } + +--- +"Test Reindex With Existing Data Stream": + - do: + indices.put_index_template: + name: my-template1 + body: + index_patterns: [my-data-stream*] + template: + mappings: + properties: + '@timestamp': + type: date + 'foo': + type: keyword + data_stream: {} + + - do: + indices.create_data_stream: + name: my-data-stream + - is_true: acknowledged + +# Uncomment once the cancel API is in place +# - do: +# migrate.reindex: +# body: | +# { +# "mode": "upgrade", +# "source": { +# "index": "my-data-stream" +# } +# } +# - match: { task: "reindex-data-stream-my-data-stream" }