From 1e2c19fb0be8f73caec9b7bb693b89d120020127 Mon Sep 17 00:00:00 2001 From: Pat Whelan Date: Thu, 26 Sep 2024 12:18:42 -0400 Subject: [PATCH] [ML] Add stream flag to inference providers (#113424) Pass the stream flag from the REST request through to the inference providers via the InferenceInputs. Co-authored-by: Elastic Machine --- .../inference/InferenceService.java | 2 ++ .../TestDenseInferenceServiceExtension.java | 1 + .../mock/TestRerankingServiceExtension.java | 1 + .../TestSparseInferenceServiceExtension.java | 1 + ...stStreamingCompletionServiceExtension.java | 1 + .../action/TransportInferenceAction.java | 1 + .../inference/external/http/HttpUtils.java | 2 +- .../http/sender/DocumentsOnlyInput.java | 14 +++++++++++-- .../http/sender/QueryAndDocsInputs.java | 21 +++++++++++++------ .../inference/services/SenderService.java | 5 +++-- .../inference/services/ServiceUtils.java | 1 + .../AlibabaCloudSearchService.java | 1 + .../ElasticsearchInternalService.java | 1 + .../services/elser/ElserInternalService.java | 1 + .../SimpleServiceIntegrationValidator.java | 1 + .../inference/services/ServiceUtilsTests.java | 21 ++++++++----------- .../AmazonBedrockServiceTests.java | 4 ++++ .../anthropic/AnthropicServiceTests.java | 2 ++ .../AzureAiStudioServiceTests.java | 3 +++ .../azureopenai/AzureOpenAiServiceTests.java | 3 +++ .../services/cohere/CohereServiceTests.java | 6 ++++++ .../elastic/ElasticInferenceServiceTests.java | 2 ++ .../GoogleAiStudioServiceTests.java | 4 ++++ .../HuggingFaceBaseServiceTests.java | 1 + .../huggingface/HuggingFaceServiceTests.java | 2 ++ .../ibmwatsonx/IbmWatsonxServiceTests.java | 3 +++ .../services/mistral/MistralServiceTests.java | 2 ++ .../services/openai/OpenAiServiceTests.java | 3 +++ ...impleServiceIntegrationValidatorTests.java | 5 ++++- 29 files changed, 91 insertions(+), 24 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/inference/InferenceService.java b/server/src/main/java/org/elasticsearch/inference/InferenceService.java index f677f75dfb5a..854c58b4f57a 100644 --- a/server/src/main/java/org/elasticsearch/inference/InferenceService.java +++ b/server/src/main/java/org/elasticsearch/inference/InferenceService.java @@ -85,6 +85,7 @@ void parseRequestConfig( * @param model The model * @param query Inference query, mainly for re-ranking * @param input Inference input + * @param stream Stream inference results * @param taskSettings Settings in the request to override the model's defaults * @param inputType For search, ingest etc * @param timeout The timeout for the request @@ -94,6 +95,7 @@ void infer( Model model, @Nullable String query, List input, + boolean stream, Map taskSettings, InputType inputType, TimeValue timeout, diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java index 10d8f90efef5..daa29d33699e 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestDenseInferenceServiceExtension.java @@ -94,6 +94,7 @@ public void infer( Model model, @Nullable String query, List input, + boolean stream, Map taskSettings, InputType inputType, TimeValue timeout, diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java index fae11d5b53ca..1894db6db8df 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestRerankingServiceExtension.java @@ -85,6 +85,7 @@ public void infer( Model model, @Nullable String query, List input, + boolean stream, Map taskSettings, InputType inputType, TimeValue timeout, diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java index fee9855b188c..1a5df146a3aa 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestSparseInferenceServiceExtension.java @@ -88,6 +88,7 @@ public void infer( Model model, @Nullable String query, List input, + boolean stream, Map taskSettings, InputType inputType, TimeValue timeout, diff --git a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java index 3d72b1f2729b..4313026e9252 100644 --- a/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java +++ b/x-pack/plugin/inference/qa/test-service-plugin/src/main/java/org/elasticsearch/xpack/inference/mock/TestStreamingCompletionServiceExtension.java @@ -85,6 +85,7 @@ public void infer( Model model, String query, List input, + boolean stream, Map taskSettings, InputType inputType, TimeValue timeout, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java index 803e8f1e0761..4186b281a35b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/TransportInferenceAction.java @@ -114,6 +114,7 @@ private void inferOnService( model, request.getQuery(), request.getInput(), + request.isStreaming(), request.getTaskSettings(), request.getInputType(), request.getInferenceTimeout(), diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/HttpUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/HttpUtils.java index 9f2ceddc92a2..4282e5d1e7cb 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/HttpUtils.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/HttpUtils.java @@ -46,7 +46,7 @@ private static String getStatusCodeErrorMessage(Request request, HttpResult resu } public static void checkForEmptyBody(ThrottlerManager throttlerManager, Logger logger, Request request, HttpResult result) { - if (result.isBodyEmpty()) { + if (result.isBodyEmpty() && (request.isStreaming() == false)) { String message = format("Response body was empty for request from inference entity id [%s]", request.getInferenceEntityId()); throttlerManager.warn(logger, message); throw new IllegalStateException(message); diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/DocumentsOnlyInput.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/DocumentsOnlyInput.java index a32e2018117f..8cf411d84c93 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/DocumentsOnlyInput.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/DocumentsOnlyInput.java @@ -21,13 +21,23 @@ public static DocumentsOnlyInput of(InferenceInputs inferenceInputs) { } private final List input; + private final boolean stream; - public DocumentsOnlyInput(List chunks) { + public DocumentsOnlyInput(List input) { + this(input, false); + } + + public DocumentsOnlyInput(List input, boolean stream) { super(); - this.input = Objects.requireNonNull(chunks); + this.input = Objects.requireNonNull(input); + this.stream = stream; } public List getInputs() { return this.input; } + + public boolean stream() { + return stream; + } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/QueryAndDocsInputs.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/QueryAndDocsInputs.java index 0d5f98c180ba..50bb77b307db 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/QueryAndDocsInputs.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/external/http/sender/QueryAndDocsInputs.java @@ -21,6 +21,19 @@ public static QueryAndDocsInputs of(InferenceInputs inferenceInputs) { } private final String query; + private final List chunks; + private final boolean stream; + + public QueryAndDocsInputs(String query, List chunks) { + this(query, chunks, false); + } + + public QueryAndDocsInputs(String query, List chunks, boolean stream) { + super(); + this.query = Objects.requireNonNull(query); + this.chunks = Objects.requireNonNull(chunks); + this.stream = stream; + } public String getQuery() { return query; @@ -30,12 +43,8 @@ public List getChunks() { return chunks; } - List chunks; - - public QueryAndDocsInputs(String query, List chunks) { - super(); - this.query = Objects.requireNonNull(query); - this.chunks = Objects.requireNonNull(chunks); + public boolean stream() { + return stream; } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java index 864aebcef124..21b2df6af1ab 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java @@ -51,6 +51,7 @@ public void infer( Model model, @Nullable String query, List input, + boolean stream, Map taskSettings, InputType inputType, TimeValue timeout, @@ -58,9 +59,9 @@ public void infer( ) { init(); if (query != null) { - doInfer(model, new QueryAndDocsInputs(query, input), taskSettings, inputType, timeout, listener); + doInfer(model, new QueryAndDocsInputs(query, input, stream), taskSettings, inputType, timeout, listener); } else { - doInfer(model, new DocumentsOnlyInput(input), taskSettings, inputType, timeout, listener); + doInfer(model, new DocumentsOnlyInput(input, stream), taskSettings, inputType, timeout, listener); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java index 6eb033191300..32c1d17373e5 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/ServiceUtils.java @@ -666,6 +666,7 @@ public static void getEmbeddingSize(Model model, InferenceService service, Actio model, null, List.of(TEST_EMBEDDING_INPUT), + false, Map.of(), InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java index 8f0c9896c664..994bad194aef 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/alibabacloudsearch/AlibabaCloudSearchService.java @@ -309,6 +309,7 @@ private void checkAlibabaCloudSearchServiceConfig(Model model, InferenceService model, query, List.of(input), + false, Map.of(), InputType.INGEST, DEFAULT_TIMEOUT, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java index cca8ae63e974..93408c067098 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elasticsearch/ElasticsearchInternalService.java @@ -323,6 +323,7 @@ public void infer( Model model, @Nullable String query, List input, + boolean stream, Map taskSettings, InputType inputType, TimeValue timeout, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalService.java index 948117954a63..746cb6e89fad 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/elser/ElserInternalService.java @@ -149,6 +149,7 @@ public void infer( Model model, @Nullable String query, List inputs, + boolean stream, Map taskSettings, InputType inputType, TimeValue timeout, diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/SimpleServiceIntegrationValidator.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/SimpleServiceIntegrationValidator.java index 6233a7e0b6b2..70f01e77b936 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/SimpleServiceIntegrationValidator.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/validation/SimpleServiceIntegrationValidator.java @@ -31,6 +31,7 @@ public void validate(InferenceService service, Model model, ActionListener { - @SuppressWarnings("unchecked") - ActionListener listener = (ActionListener) invocation.getArguments()[6]; + ActionListener listener = invocation.getArgument(7); listener.onResponse(new InferenceTextEmbeddingFloatResults(List.of())); return Void.TYPE; - }).when(service).infer(any(), any(), any(), any(), any(), any(), any()); + }).when(service).infer(any(), any(), any(), anyBoolean(), any(), any(), any(), any()); PlainActionFuture listener = new PlainActionFuture<>(); getEmbeddingSize(model, service, listener); @@ -878,12 +878,11 @@ public void testGetEmbeddingSize_ReturnsError_WhenTextEmbeddingByteResults_IsEmp when(model.getTaskType()).thenReturn(TaskType.TEXT_EMBEDDING); doAnswer(invocation -> { - @SuppressWarnings("unchecked") - ActionListener listener = (ActionListener) invocation.getArguments()[6]; + ActionListener listener = invocation.getArgument(7); listener.onResponse(new InferenceTextEmbeddingByteResults(List.of())); return Void.TYPE; - }).when(service).infer(any(), any(), any(), any(), any(), any(), any()); + }).when(service).infer(any(), any(), any(), anyBoolean(), any(), any(), any(), any()); PlainActionFuture listener = new PlainActionFuture<>(); getEmbeddingSize(model, service, listener); @@ -903,12 +902,11 @@ public void testGetEmbeddingSize_ReturnsSize_ForTextEmbeddingResults() { var textEmbedding = TextEmbeddingResultsTests.createRandomResults(); doAnswer(invocation -> { - @SuppressWarnings("unchecked") - ActionListener listener = (ActionListener) invocation.getArguments()[6]; + ActionListener listener = invocation.getArgument(7); listener.onResponse(textEmbedding); return Void.TYPE; - }).when(service).infer(any(), any(), any(), any(), any(), any(), any()); + }).when(service).infer(any(), any(), any(), anyBoolean(), any(), any(), any(), any()); PlainActionFuture listener = new PlainActionFuture<>(); getEmbeddingSize(model, service, listener); @@ -927,12 +925,11 @@ public void testGetEmbeddingSize_ReturnsSize_ForTextEmbeddingByteResults() { var textEmbedding = InferenceTextEmbeddingByteResultsTests.createRandomResults(); doAnswer(invocation -> { - @SuppressWarnings("unchecked") - ActionListener listener = (ActionListener) invocation.getArguments()[6]; + ActionListener listener = invocation.getArgument(7); listener.onResponse(textEmbedding); return Void.TYPE; - }).when(service).infer(any(), any(), any(), any(), any(), any(), any()); + }).when(service).infer(any(), any(), any(), anyBoolean(), any(), any(), any(), any()); PlainActionFuture listener = new PlainActionFuture<>(); getEmbeddingSize(model, service, listener); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java index bbf34354e181..297a42f9d1fa 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java @@ -671,6 +671,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotAmazonBedrockModel() throws IOExc mockModel, null, List.of(""), + false, new HashMap<>(), InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, @@ -721,6 +722,7 @@ public void testInfer_SendsRequest_ForEmbeddingsModel() throws IOException { model, null, List.of("abc"), + false, new HashMap<>(), InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, @@ -762,6 +764,7 @@ public void testInfer_SendsRequest_ForChatCompletionModel() throws IOException { model, null, List.of("abc"), + false, new HashMap<>(), InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, @@ -1025,6 +1028,7 @@ public void testInfer_UnauthorizedResponse() throws IOException { model, null, List.of("abc"), + false, new HashMap<>(), InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicServiceTests.java index 5e32344ab384..c3693c227c43 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicServiceTests.java @@ -452,6 +452,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotAValidModel() throws IOException mockModel, null, List.of(""), + false, new HashMap<>(), InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, @@ -506,6 +507,7 @@ public void testInfer_SendsCompletionRequest() throws IOException { model, null, List.of("input"), + false, new HashMap<>(), InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java index 6f33c36f42db..bb736f592fbd 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java @@ -825,6 +825,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotAzureAiStudioModel() throws IOExc mockModel, null, List.of(""), + false, new HashMap<>(), InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, @@ -954,6 +955,7 @@ public void testInfer_WithChatCompletionModel() throws IOException { model, null, List.of("abc"), + false, new HashMap<>(), InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, @@ -1004,6 +1006,7 @@ public void testInfer_UnauthorisedResponse() throws IOException { model, null, List.of("abc"), + false, new HashMap<>(), InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java index b3fbd6fc9b42..142877c09180 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java @@ -601,6 +601,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotAzureOpenAiModel() throws IOExcep mockModel, null, List.of(""), + false, new HashMap<>(), InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, @@ -656,6 +657,7 @@ public void testInfer_SendsRequest() throws IOException, URISyntaxException { model, null, List.of("abc"), + false, new HashMap<>(), InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, @@ -1051,6 +1053,7 @@ public void testInfer_UnauthorisedResponse() throws IOException, URISyntaxExcept model, null, List.of("abc"), + false, new HashMap<>(), InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java index aebc3e3776c4..a577a6664d39 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java @@ -622,6 +622,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotCohereModel() throws IOException mockModel, null, List.of(""), + false, new HashMap<>(), InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, @@ -689,6 +690,7 @@ public void testInfer_SendsRequest() throws IOException { model, null, List.of("abc"), + false, new HashMap<>(), InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, @@ -932,6 +934,7 @@ public void testInfer_UnauthorisedResponse() throws IOException { model, null, List.of("abc"), + false, new HashMap<>(), InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, @@ -991,6 +994,7 @@ public void testInfer_SetsInputTypeToIngest_FromInferParameter_WhenTaskSettingsA model, null, List.of("abc"), + false, new HashMap<>(), InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, @@ -1064,6 +1068,7 @@ public void testInfer_SetsInputTypeToIngestFromInferParameter_WhenModelSettingIs model, null, List.of("abc"), + false, CohereEmbeddingsTaskSettingsTests.getTaskSettingsMap(InputType.SEARCH, null), InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, @@ -1135,6 +1140,7 @@ public void testInfer_DoesNotSetInputType_WhenNotPresentInTaskSettings_AndUnspec model, null, List.of("abc"), + false, new HashMap<>(), InputType.UNSPECIFIED, InferenceAction.Request.DEFAULT_TIMEOUT, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java index 38124b3401aa..0bbf2be7301d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/elastic/ElasticInferenceServiceTests.java @@ -346,6 +346,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotAValidModel() throws IOException mockModel, null, List.of(""), + false, new HashMap<>(), InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, @@ -397,6 +398,7 @@ public void testInfer_SendsEmbeddingsRequest() throws IOException { model, null, List.of("input text"), + false, new HashMap<>(), InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java index 89d6a010bbc0..5d79d0e01f40 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java @@ -503,6 +503,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotGoogleAiStudioModel() throws IOEx mockModel, null, List.of(""), + false, new HashMap<>(), InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, @@ -578,6 +579,7 @@ public void testInfer_SendsCompletionRequest() throws IOException { model, null, List.of("input"), + false, new HashMap<>(), InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, @@ -634,6 +636,7 @@ public void testInfer_SendsEmbeddingsRequest() throws IOException { model, null, List.of(input), + false, new HashMap<>(), InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, @@ -775,6 +778,7 @@ public void testInfer_ResourceNotFound() throws IOException { model, null, List.of("abc"), + false, new HashMap<>(), InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseServiceTests.java index 22c3b7895460..168110ae8f7c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceBaseServiceTests.java @@ -69,6 +69,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotHuggingFaceModel() throws IOExcep mockModel, null, List.of(""), + false, new HashMap<>(), InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java index 5ea9f82e5b60..d13dea2ab6b4 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/huggingface/HuggingFaceServiceTests.java @@ -438,6 +438,7 @@ public void testInfer_SendsEmbeddingsRequest() throws IOException { model, null, List.of("abc"), + false, new HashMap<>(), InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, @@ -481,6 +482,7 @@ public void testInfer_SendsElserRequest() throws IOException { model, null, List.of("abc"), + false, new HashMap<>(), InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java index e0936c778c7a..a2de7c15d54d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java @@ -409,6 +409,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotIbmWatsonxModel() throws IOExcept mockModel, null, List.of(""), + false, new HashMap<>(), InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, @@ -465,6 +466,7 @@ public void testInfer_SendsEmbeddingsRequest() throws IOException { model, null, List.of(input), + false, new HashMap<>(), InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, @@ -588,6 +590,7 @@ public void testInfer_ResourceNotFound() throws IOException { model, null, List.of("abc"), + false, new HashMap<>(), InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java index 9d0fd954c44f..33a2b43caf17 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/mistral/MistralServiceTests.java @@ -446,6 +446,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotMistralEmbeddingsModel() throws I mockModel, null, List.of(""), + false, new HashMap<>(), InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, @@ -571,6 +572,7 @@ public void testInfer_UnauthorisedResponse() throws IOException { model, null, List.of("abc"), + false, new HashMap<>(), InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java index a5e8c1d7eb26..32099c4bd0be 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java @@ -936,6 +936,7 @@ public void testInfer_ThrowsErrorWhenModelIsNotOpenAiModel() throws IOException mockModel, null, List.of(""), + false, new HashMap<>(), InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, @@ -990,6 +991,7 @@ public void testInfer_SendsRequest() throws IOException { model, null, List.of("abc"), + false, new HashMap<>(), InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, @@ -1470,6 +1472,7 @@ public void testInfer_UnauthorisedResponse() throws IOException { model, null, List.of("abc"), + false, new HashMap<>(), InputType.INGEST, InferenceAction.Request.DEFAULT_TIMEOUT, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/SimpleServiceIntegrationValidatorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/SimpleServiceIntegrationValidatorTests.java index ef295e4070cc..767dd4d64a7d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/SimpleServiceIntegrationValidatorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/validation/SimpleServiceIntegrationValidatorTests.java @@ -64,6 +64,7 @@ public void testValidate_ServiceThrowsException() { eq(mockModel), eq(null), eq(TEST_INPUT), + eq(false), eq(Map.of()), eq(InputType.INGEST), eq(InferenceAction.Request.DEFAULT_TIMEOUT), @@ -94,7 +95,7 @@ public void testValidate_SuccessfulCallToServiceForReRankTaskType() { private void mockSuccessfulCallToService(String query, InferenceServiceResults result) { doAnswer(ans -> { - ActionListener responseListener = ans.getArgument(6); + ActionListener responseListener = ans.getArgument(7); responseListener.onResponse(result); return null; }).when(mockInferenceService) @@ -102,6 +103,7 @@ private void mockSuccessfulCallToService(String query, InferenceServiceResults r eq(mockModel), eq(query), eq(TEST_INPUT), + eq(false), eq(Map.of()), eq(InputType.INGEST), eq(InferenceAction.Request.DEFAULT_TIMEOUT), @@ -117,6 +119,7 @@ private void verifyCallToService(boolean withQuery) { eq(mockModel), eq(withQuery ? TEST_QUERY : null), eq(TEST_INPUT), + eq(false), eq(Map.of()), eq(InputType.INGEST), eq(InferenceAction.Request.DEFAULT_TIMEOUT),