From 9992edc588814ff0d577b30303fac73546fa5118 Mon Sep 17 00:00:00 2001
From: Pat Whelan
Date: Mon, 28 Oct 2024 14:54:03 -0400
Subject: [PATCH] [ML] Fix stream support for TaskType.ANY (#115656)

If we support one, then we support any.
---
 docs/changelog/115656.yaml                                | 5 +++++
 .../xpack/inference/services/SenderService.java           | 3 ++-
 .../services/amazonbedrock/AmazonBedrockServiceTests.java | 7 +++++++
 .../services/anthropic/AnthropicServiceTests.java         | 7 +++++++
 .../services/azureaistudio/AzureAiStudioServiceTests.java | 7 +++++++
 .../services/azureopenai/AzureOpenAiServiceTests.java     | 7 +++++++
 .../inference/services/cohere/CohereServiceTests.java     | 7 +++++++
 .../googleaistudio/GoogleAiStudioServiceTests.java        | 7 +++++++
 .../inference/services/openai/OpenAiServiceTests.java     | 7 +++++++
 9 files changed, 56 insertions(+), 1 deletion(-)
 create mode 100644 docs/changelog/115656.yaml

diff --git a/docs/changelog/115656.yaml b/docs/changelog/115656.yaml
new file mode 100644
index 0000000000000..13b612b052fc1
--- /dev/null
+++ b/docs/changelog/115656.yaml
@@ -0,0 +1,5 @@
+pr: 115656
+summary: Fix stream support for `TaskType.ANY`
+area: Machine Learning
+type: bug
+issues: []
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java
index 71b38d7a0785a..953cf4cf6ad77 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/SenderService.java
@@ -25,13 +25,14 @@
 import org.elasticsearch.xpack.inference.external.http.sender.Sender;
 
 import java.io.IOException;
+import java.util.EnumSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Objects;
 import java.util.Set;
 
 public abstract class SenderService implements InferenceService {
-    protected static final Set<TaskType> COMPLETION_ONLY = Set.of(TaskType.COMPLETION);
+    protected static final Set<TaskType> COMPLETION_ONLY = EnumSet.of(TaskType.COMPLETION, TaskType.ANY);
     private final Sender sender;
     private final ServiceComponents serviceComponents;
 
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java
index 06c5a68987a9e..931d418a3664b 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/amazonbedrock/AmazonBedrockServiceTests.java
@@ -1304,6 +1304,13 @@ public void testInfer_UnauthorizedResponse() throws IOException {
         }
     }
 
+    public void testSupportsStreaming() throws IOException {
+        try (var service = new AmazonBedrockService(mock(), mock(), createWithEmptySettings(mock()))) {
+            assertTrue(service.canStream(TaskType.COMPLETION));
+            assertTrue(service.canStream(TaskType.ANY));
+        }
+    }
+
     public void testChunkedInfer_CallsInfer_ConvertsFloatResponse_ForEmbeddings() throws IOException {
         var model = AmazonBedrockEmbeddingsModelTests.createModel(
             "id",
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicServiceTests.java
index 48277112d9306..c4f7fbfb14437 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/anthropic/AnthropicServiceTests.java
@@ -593,6 +593,13 @@ public void testInfer_StreamRequest_ErrorResponse() throws Exception {
             .hasErrorContaining("blah");
     }
 
+    public void testSupportsStreaming() throws IOException {
+        try (var service = new AnthropicService(mock(), createWithEmptySettings(mock()))) {
+            assertTrue(service.canStream(TaskType.COMPLETION));
+            assertTrue(service.canStream(TaskType.ANY));
+        }
+    }
+
     private AnthropicService createServiceWithMockSender() {
         return new AnthropicService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool));
     }
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java
index e85edf573ba96..4d2eb60767f44 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureaistudio/AzureAiStudioServiceTests.java
@@ -1384,6 +1384,13 @@ public void testInfer_StreamRequest_ErrorResponse() throws Exception {
             .hasErrorContaining("You didn't provide an API key...");
     }
 
+    public void testSupportsStreaming() throws IOException {
+        try (var service = new AzureAiStudioService(mock(), createWithEmptySettings(mock()))) {
+            assertTrue(service.canStream(TaskType.COMPLETION));
+            assertTrue(service.canStream(TaskType.ANY));
+        }
+    }
+
     // ----------------------------------------------------------------
 
     private AzureAiStudioService createService() {
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java
index 3408fc358cac0..1bae6ce66d6aa 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/azureopenai/AzureOpenAiServiceTests.java
@@ -1504,6 +1504,13 @@ public void testInfer_StreamRequest_ErrorResponse() throws Exception {
             .hasErrorContaining("You didn't provide an API key...");
     }
 
+    public void testSupportsStreaming() throws IOException {
+        try (var service = new AzureOpenAiService(mock(), createWithEmptySettings(mock()))) {
+            assertTrue(service.canStream(TaskType.COMPLETION));
+            assertTrue(service.canStream(TaskType.ANY));
+        }
+    }
+
     private AzureOpenAiService createAzureOpenAiService() {
         return new AzureOpenAiService(mock(HttpRequestSender.Factory.class), createWithEmptySettings(threadPool));
     }
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java
index 758c38166778b..d44be4246f844 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/cohere/CohereServiceTests.java
@@ -1683,6 +1683,13 @@ public void testInfer_StreamRequest_ErrorResponse() throws Exception {
             .hasErrorContaining("how dare you");
     }
 
+    public void testSupportsStreaming() throws IOException {
+        try (var service = new CohereService(mock(), createWithEmptySettings(mock()))) {
+            assertTrue(service.canStream(TaskType.COMPLETION));
+            assertTrue(service.canStream(TaskType.ANY));
+        }
+    }
+
     private Map<String, Object> getRequestConfigMap(
         Map<String, Object> serviceSettings,
         Map<String, Object> taskSettings,
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java
index e8382876868c5..27a53177658c6 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/googleaistudio/GoogleAiStudioServiceTests.java
@@ -1219,6 +1219,13 @@ private void testUpdateModelWithEmbeddingDetails_Successful(SimilarityMeasure si
         }
     }
 
+    public void testSupportsStreaming() throws IOException {
+        try (var service = new GoogleAiStudioService(mock(), createWithEmptySettings(mock()))) {
+            assertTrue(service.canStream(TaskType.COMPLETION));
+            assertTrue(service.canStream(TaskType.ANY));
+        }
+    }
+
     public static Map<String, Object> buildExpectationCompletions(List<String> completions) {
         return Map.of(
             ChatCompletionResults.COMPLETION,
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java
index cf1438b334478..0698b9652b767 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/openai/OpenAiServiceTests.java
@@ -1077,6 +1077,13 @@ public void testInfer_StreamRequest_ErrorResponse() throws Exception {
             .hasErrorContaining("You didn't provide an API key...");
     }
 
+    public void testSupportsStreaming() throws IOException {
+        try (var service = new OpenAiService(mock(), createWithEmptySettings(mock()))) {
+            assertTrue(service.canStream(TaskType.COMPLETION));
+            assertTrue(service.canStream(TaskType.ANY));
+        }
+    }
+
     public void testCheckModelConfig_IncludesMaxTokens() throws IOException {
         var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager);
 