From de10a5eb95f3aff55545dd79e44b8209b9821d0b Mon Sep 17 00:00:00 2001 From: Saikat Sarkar Date: Tue, 17 Sep 2024 14:56:52 -0600 Subject: [PATCH] Remove infer tests since they will be covered in the notebook --- .../IbmWatsonxEmbeddingsActionTests.java | 57 --- .../ibmwatsonx/IbmWatsonxServiceTests.java | 351 ------------------ 2 files changed, 408 deletions(-) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/ibmwatsonx/IbmWatsonxEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/ibmwatsonx/IbmWatsonxEmbeddingsActionTests.java index e6b6ce22e2951..81128b48cd41c 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/ibmwatsonx/IbmWatsonxEmbeddingsActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/ibmwatsonx/IbmWatsonxEmbeddingsActionTests.java @@ -7,7 +7,6 @@ package org.elasticsearch.xpack.inference.external.action.ibmwatsonx; -import org.apache.http.HttpHeaders; import org.elasticsearch.ElasticsearchException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.PlainActionFuture; @@ -15,17 +14,14 @@ import org.elasticsearch.core.TimeValue; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.test.http.MockResponse; import org.elasticsearch.test.http.MockWebServer; import org.elasticsearch.threadpool.ThreadPool; -import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; import org.elasticsearch.xpack.inference.common.TruncatorTests; import org.elasticsearch.xpack.inference.external.action.ExecutableAction; import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.DocumentsOnlyInput; -import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; import org.elasticsearch.xpack.inference.external.http.sender.IbmWatsonxEmbeddingsRequestManager; import org.elasticsearch.xpack.inference.external.http.sender.Sender; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; @@ -35,21 +31,14 @@ import java.io.IOException; import java.net.URI; import java.util.List; -import java.util.Map; import java.util.concurrent.TimeUnit; import static org.elasticsearch.core.Strings.format; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage; -import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; -import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat; -import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsModelTests.createModel; -import static org.hamcrest.Matchers.aMapWithSize; -import static org.hamcrest.Matchers.equalTo; -import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.is; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.doAnswer; @@ -76,52 +65,6 @@ public void shutdown() throws IOException { webServer.close(); } - public void testExecute_ReturnsSuccessfulResponse() throws IOException { - var apiKey = "apiKey"; - var model = "model"; - var input = "input"; - var projectId = "projectId"; - URI uri = URI.create("https://abc.com"); - var apiVersion = "apiVersion"; - - var senderFactory = new HttpRequestSender.Factory(createWithEmptySettings(threadPool), clientManager, mockClusterServiceEmpty()); - - try (var sender = senderFactory.createSender()) { - sender.start(); - - String responseJson = """ - { - "results": [ - { - "embedding": [ - 0.0123, - -0.0123 - ], - "input": "abc" - } - ] - } - """; - - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - var action = createAction(getUrl(webServer), apiKey, model, projectId, uri, apiVersion, sender); - - PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of(input)), InferenceAction.Request.DEFAULT_TIMEOUT, listener); - - var result = listener.actionGet(TIMEOUT); - - assertThat(result.asMap(), is(buildExpectationFloat(List.of(new float[] { 0.0123F, -0.0123F })))); - assertThat(webServer.requests(), hasSize(1)); - assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType())); - - var requestMap = entityAsMap(webServer.requests().get(0).getBody()); - assertThat(requestMap, aMapWithSize(3)); - assertThat(requestMap, is(Map.of("project_id", "projectId", "inputs", List.of(input), "model_id", "model"))); - } - } - public void testExecute_ThrowsElasticsearchException() { var sender = mock(Sender.class); var apiKey = "apiKey"; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java index 8735dd7d65fcc..a00d7ffe3c9d6 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java @@ -7,37 +7,27 @@ package org.elasticsearch.xpack.inference.services.ibmwatsonx; -import org.apache.http.HttpHeaders; -import org.elasticsearch.ElasticsearchException; import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.ActionListener; import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.core.TimeValue; -import org.elasticsearch.inference.ChunkedInferenceServiceResults; -import org.elasticsearch.inference.ChunkingOptions; import org.elasticsearch.inference.EmptyTaskSettings; import org.elasticsearch.inference.InferenceServiceResults; import org.elasticsearch.inference.InputType; import org.elasticsearch.inference.Model; import org.elasticsearch.inference.ModelConfigurations; -import org.elasticsearch.inference.SimilarityMeasure; import org.elasticsearch.inference.TaskType; import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.test.http.MockResponse; import org.elasticsearch.test.http.MockWebServer; import org.elasticsearch.threadpool.ThreadPool; -import org.elasticsearch.xcontent.XContentType; import org.elasticsearch.xpack.core.inference.action.InferenceAction; -import org.elasticsearch.xpack.core.inference.results.InferenceChunkedTextEmbeddingFloatResults; import org.elasticsearch.xpack.inference.external.http.HttpClientManager; import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender; -import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSenderTests; import org.elasticsearch.xpack.inference.external.http.sender.Sender; import org.elasticsearch.xpack.inference.logging.ThrottlerManager; import org.elasticsearch.xpack.inference.services.ServiceFields; import org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsModel; -import org.elasticsearch.xpack.inference.services.ibmwatsonx.embeddings.IbmWatsonxEmbeddingsModelTests; import org.hamcrest.MatcherAssert; import org.hamcrest.Matchers; import org.junit.After; @@ -45,7 +35,6 @@ import java.io.IOException; import java.net.URI; -import java.util.Arrays; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -56,17 +45,11 @@ import static org.elasticsearch.xpack.inference.Utils.getPersistedConfigMap; import static org.elasticsearch.xpack.inference.Utils.inferenceUtilityPool; import static org.elasticsearch.xpack.inference.Utils.mockClusterServiceEmpty; -import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; -import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; -import static org.elasticsearch.xpack.inference.results.TextEmbeddingResultsTests.buildExpectationFloat; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; import static org.elasticsearch.xpack.inference.services.cohere.embeddings.CohereEmbeddingsTaskSettingsTests.getTaskSettingsMapEmpty; import static org.elasticsearch.xpack.inference.services.settings.DefaultSecretSettingsTests.getSecretSettingsMap; import static org.hamcrest.CoreMatchers.instanceOf; import static org.hamcrest.CoreMatchers.is; -import static org.hamcrest.Matchers.aMapWithSize; -import static org.hamcrest.Matchers.containsString; -import static org.hamcrest.Matchers.hasSize; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; @@ -423,340 +406,6 @@ public void testInfer_ThrowsErrorWhenModelIsNotIbmWatsonxModel() throws IOExcept verifyNoMoreInteractions(sender); } - public void testInfer_SendsEmbeddingsRequest() throws IOException { - var input = "input"; - - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - - try (var service = new IbmWatsonxService(senderFactory, createWithEmptySettings(threadPool))) { - String responseJson = """ - { - "results": [ - { - "embedding": [ - 0.0123, - -0.0123 - ], - "input": "input" - } - ] - } - """; - - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - var model = IbmWatsonxEmbeddingsModelTests.createModel( - modelId, - projectId, - URI.create(url), - apiVersion, - apiKey, - getUrl(webServer) - ); - PlainActionFuture listener = new PlainActionFuture<>(); - service.infer( - model, - null, - List.of(input), - new HashMap<>(), - InputType.INGEST, - InferenceAction.Request.DEFAULT_TIMEOUT, - listener - ); - var result = listener.actionGet(TIMEOUT); - - assertThat(result.asMap(), is(buildExpectationFloat(List.of(new float[] { 0.0123F, -0.0123F })))); - assertThat(webServer.requests(), hasSize(1)); - assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), Matchers.equalTo(XContentType.JSON.mediaType())); - - var requestMap = entityAsMap(webServer.requests().get(0).getBody()); - assertThat(requestMap, aMapWithSize(3)); - assertThat(requestMap, Matchers.is(Map.of("project_id", projectId, "inputs", List.of(input), "model_id", modelId))); - } - } - - public void testChunkedInfer_Batches() throws IOException { - var input = List.of("foo", "bar"); - - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - - try (var service = new IbmWatsonxService(senderFactory, createWithEmptySettings(threadPool))) { - String responseJson = """ - { - "results": [ - { - "embedding": [ - 0.0123, - -0.0123 - ], - "input": "foo" - }, - { - "embedding": [ - 0.0456, - -0.0456 - ], - "input": "bar" - } - ] - } - """; - - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - var model = IbmWatsonxEmbeddingsModelTests.createModel( - modelId, - projectId, - URI.create(url), - apiVersion, - apiKey, - getUrl(webServer) - ); - PlainActionFuture> listener = new PlainActionFuture<>(); - service.chunkedInfer( - model, - null, - input, - new HashMap<>(), - InputType.INGEST, - new ChunkingOptions(null, null), - InferenceAction.Request.DEFAULT_TIMEOUT, - listener - ); - - var results = listener.actionGet(TIMEOUT); - assertThat(results, hasSize(2)); - - // first result - { - assertThat(results.get(0), instanceOf(InferenceChunkedTextEmbeddingFloatResults.class)); - var floatResult = (InferenceChunkedTextEmbeddingFloatResults) results.get(0); - assertThat(floatResult.chunks(), hasSize(1)); - assertEquals(input.get(0), floatResult.chunks().get(0).matchedText()); - assertTrue(Arrays.equals(new float[] { 0.0123f, -0.0123f }, floatResult.chunks().get(0).embedding())); - } - - // second result - { - assertThat(results.get(1), instanceOf(InferenceChunkedTextEmbeddingFloatResults.class)); - var floatResult = (InferenceChunkedTextEmbeddingFloatResults) results.get(1); - assertThat(floatResult.chunks(), hasSize(1)); - assertEquals(input.get(1), floatResult.chunks().get(0).matchedText()); - assertTrue(Arrays.equals(new float[] { 0.0456f, -0.0456f }, floatResult.chunks().get(0).embedding())); - } - - assertThat(webServer.requests(), hasSize(1)); - assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), Matchers.equalTo(XContentType.JSON.mediaType())); - - var requestMap = entityAsMap(webServer.requests().get(0).getBody()); - assertThat(requestMap, aMapWithSize(3)); - assertThat(requestMap, is(Map.of("project_id", projectId, "inputs", List.of("foo", "bar"), "model_id", modelId))); - } - } - - public void testInfer_ResourceNotFound() throws IOException { - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - - try (var service = new IbmWatsonxService(senderFactory, createWithEmptySettings(threadPool))) { - - String responseJson = """ - { - "error": { - "message": "error" - } - } - """; - webServer.enqueue(new MockResponse().setResponseCode(404).setBody(responseJson)); - - var model = IbmWatsonxEmbeddingsModelTests.createModel( - modelId, - projectId, - URI.create(url), - apiVersion, - apiKey, - getUrl(webServer) - ); - PlainActionFuture listener = new PlainActionFuture<>(); - service.infer( - model, - null, - List.of("abc"), - new HashMap<>(), - InputType.INGEST, - InferenceAction.Request.DEFAULT_TIMEOUT, - listener - ); - - var error = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); - assertThat(error.getMessage(), containsString("Resource not found at ")); - assertThat(error.getMessage(), containsString("Error message: [error]")); - assertThat(webServer.requests(), hasSize(1)); - } - } - - public void testCheckModelConfig_UpdatesDimensions() throws IOException { - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - var similarityMeasure = SimilarityMeasure.DOT_PRODUCT; - - try (var service = new IbmWatsonxService(senderFactory, createWithEmptySettings(threadPool))) { - String responseJson = """ - { - "results": [ - { - "embedding": [ - 0.0123, - -0.0123 - ], - "input": "foo" - } - ] - } - """; - - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - var model = IbmWatsonxEmbeddingsModelTests.createModel( - getUrl(webServer), - modelId, - projectId, - URI.create(url), - apiVersion, - apiKey, - 1, - similarityMeasure - ); - - PlainActionFuture listener = new PlainActionFuture<>(); - service.checkModelConfig(model, listener); - var result = listener.actionGet(TIMEOUT); - - // Updates dimensions to two as two embeddings were returned instead of one as specified before - assertThat( - result, - is( - IbmWatsonxEmbeddingsModelTests.createModel( - getUrl(webServer), - modelId, - projectId, - URI.create(url), - apiVersion, - apiKey, - 2, - similarityMeasure - ) - ) - ); - } - } - - public void testCheckModelConfig_UpdatesSimilarityToDotProduct_WhenItIsNull() throws IOException { - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - var twoDimension = 2; - - try (var service = new IbmWatsonxService(senderFactory, createWithEmptySettings(threadPool))) { - String responseJson = """ - { - "results": [ - { - "embedding": [ - 0.0123, - -0.0123 - ], - "input": "foo" - } - ] - } - """; - - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - var model = IbmWatsonxEmbeddingsModelTests.createModel( - getUrl(webServer), - modelId, - projectId, - URI.create(url), - apiVersion, - apiKey, - twoDimension, - null - ); - - PlainActionFuture listener = new PlainActionFuture<>(); - service.checkModelConfig(model, listener); - var result = listener.actionGet(TIMEOUT); - - assertThat( - result, - is( - IbmWatsonxEmbeddingsModelTests.createModel( - getUrl(webServer), - modelId, - projectId, - URI.create(url), - apiVersion, - apiKey, - twoDimension, - SimilarityMeasure.DOT_PRODUCT - ) - ) - ); - } - } - - public void testCheckModelConfig_DoesNotUpdateSimilarity_WhenItIsSpecifiedAsCosine() throws IOException { - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - var twoDimension = 2; - - try (var service = new IbmWatsonxService(senderFactory, createWithEmptySettings(threadPool))) { - String responseJson = """ - { - "results": [ - { - "embedding": [ - 0.0123, - -0.0123 - ], - "input": "foo" - } - ] - } - """; - - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - var model = IbmWatsonxEmbeddingsModelTests.createModel( - getUrl(webServer), - modelId, - projectId, - URI.create(url), - apiVersion, - apiKey, - twoDimension, - SimilarityMeasure.COSINE - ); - - PlainActionFuture listener = new PlainActionFuture<>(); - service.checkModelConfig(model, listener); - var result = listener.actionGet(TIMEOUT); - - assertThat( - result, - is( - IbmWatsonxEmbeddingsModelTests.createModel( - getUrl(webServer), - modelId, - projectId, - URI.create(url), - apiVersion, - apiKey, - twoDimension, - SimilarityMeasure.COSINE - ) - ) - ); - } - } - private static ActionListener getModelListenerForException(Class exceptionClass, String expectedMessage) { return ActionListener.wrap((model) -> fail("Model parsing should have failed"), e -> { assertThat(e, Matchers.instanceOf(exceptionClass));