From b5fa56ae67cdb58e49814392590f434ae1331f7f Mon Sep 17 00:00:00 2001 From: Saikat Sarkar Date: Tue, 17 Sep 2024 14:56:52 -0600 Subject: [PATCH] Remove infer tests since they will be covered in the notebook --- .../IbmWatsonxEmbeddingsActionTests.java | 46 --- .../ibmwatsonx/IbmWatsonxServiceTests.java | 334 ------------------ 2 files changed, 380 deletions(-) diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/ibmwatsonx/IbmWatsonxEmbeddingsActionTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/ibmwatsonx/IbmWatsonxEmbeddingsActionTests.java index e6b6ce22e2951..d8739ea08075e 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/ibmwatsonx/IbmWatsonxEmbeddingsActionTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/external/action/ibmwatsonx/IbmWatsonxEmbeddingsActionTests.java @@ -76,52 +76,6 @@ public void shutdown() throws IOException { webServer.close(); } - public void testExecute_ReturnsSuccessfulResponse() throws IOException { - var apiKey = "apiKey"; - var model = "model"; - var input = "input"; - var projectId = "projectId"; - URI uri = URI.create("https://abc.com"); - var apiVersion = "apiVersion"; - - var senderFactory = new HttpRequestSender.Factory(createWithEmptySettings(threadPool), clientManager, mockClusterServiceEmpty()); - - try (var sender = senderFactory.createSender()) { - sender.start(); - - String responseJson = """ - { - "results": [ - { - "embedding": [ - 0.0123, - -0.0123 - ], - "input": "abc" - } - ] - } - """; - - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - var action = createAction(getUrl(webServer), apiKey, model, projectId, uri, apiVersion, sender); - - PlainActionFuture listener = new PlainActionFuture<>(); - action.execute(new DocumentsOnlyInput(List.of(input)), InferenceAction.Request.DEFAULT_TIMEOUT, listener); - - var result = listener.actionGet(TIMEOUT); - - assertThat(result.asMap(), is(buildExpectationFloat(List.of(new float[] { 0.0123F, -0.0123F })))); - assertThat(webServer.requests(), hasSize(1)); - assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), equalTo(XContentType.JSON.mediaType())); - - var requestMap = entityAsMap(webServer.requests().get(0).getBody()); - assertThat(requestMap, aMapWithSize(3)); - assertThat(requestMap, is(Map.of("project_id", "projectId", "inputs", List.of(input), "model_id", "model"))); - } - } - public void testExecute_ThrowsElasticsearchException() { var sender = mock(Sender.class); var apiKey = "apiKey"; diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java index 8735dd7d65fcc..9d10c6bc44847 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/ibmwatsonx/IbmWatsonxServiceTests.java @@ -423,340 +423,6 @@ public void testInfer_ThrowsErrorWhenModelIsNotIbmWatsonxModel() throws IOExcept verifyNoMoreInteractions(sender); } - public void testInfer_SendsEmbeddingsRequest() throws IOException { - var input = "input"; - - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - - try (var service = new IbmWatsonxService(senderFactory, createWithEmptySettings(threadPool))) { - String responseJson = """ - { - "results": [ - { - "embedding": [ - 0.0123, - -0.0123 - ], - "input": "input" - } - ] - } - """; - - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - var model = IbmWatsonxEmbeddingsModelTests.createModel( - modelId, - projectId, - URI.create(url), - apiVersion, - apiKey, - getUrl(webServer) - ); - PlainActionFuture listener = new PlainActionFuture<>(); - service.infer( - model, - null, - List.of(input), - new HashMap<>(), - InputType.INGEST, - InferenceAction.Request.DEFAULT_TIMEOUT, - listener - ); - var result = listener.actionGet(TIMEOUT); - - assertThat(result.asMap(), is(buildExpectationFloat(List.of(new float[] { 0.0123F, -0.0123F })))); - assertThat(webServer.requests(), hasSize(1)); - assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), Matchers.equalTo(XContentType.JSON.mediaType())); - - var requestMap = entityAsMap(webServer.requests().get(0).getBody()); - assertThat(requestMap, aMapWithSize(3)); - assertThat(requestMap, Matchers.is(Map.of("project_id", projectId, "inputs", List.of(input), "model_id", modelId))); - } - } - - public void testChunkedInfer_Batches() throws IOException { - var input = List.of("foo", "bar"); - - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - - try (var service = new IbmWatsonxService(senderFactory, createWithEmptySettings(threadPool))) { - String responseJson = """ - { - "results": [ - { - "embedding": [ - 0.0123, - -0.0123 - ], - "input": "foo" - }, - { - "embedding": [ - 0.0456, - -0.0456 - ], - "input": "bar" - } - ] - } - """; - - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - var model = IbmWatsonxEmbeddingsModelTests.createModel( - modelId, - projectId, - URI.create(url), - apiVersion, - apiKey, - getUrl(webServer) - ); - PlainActionFuture> listener = new PlainActionFuture<>(); - service.chunkedInfer( - model, - null, - input, - new HashMap<>(), - InputType.INGEST, - new ChunkingOptions(null, null), - InferenceAction.Request.DEFAULT_TIMEOUT, - listener - ); - - var results = listener.actionGet(TIMEOUT); - assertThat(results, hasSize(2)); - - // first result - { - assertThat(results.get(0), instanceOf(InferenceChunkedTextEmbeddingFloatResults.class)); - var floatResult = (InferenceChunkedTextEmbeddingFloatResults) results.get(0); - assertThat(floatResult.chunks(), hasSize(1)); - assertEquals(input.get(0), floatResult.chunks().get(0).matchedText()); - assertTrue(Arrays.equals(new float[] { 0.0123f, -0.0123f }, floatResult.chunks().get(0).embedding())); - } - - // second result - { - assertThat(results.get(1), instanceOf(InferenceChunkedTextEmbeddingFloatResults.class)); - var floatResult = (InferenceChunkedTextEmbeddingFloatResults) results.get(1); - assertThat(floatResult.chunks(), hasSize(1)); - assertEquals(input.get(1), floatResult.chunks().get(0).matchedText()); - assertTrue(Arrays.equals(new float[] { 0.0456f, -0.0456f }, floatResult.chunks().get(0).embedding())); - } - - assertThat(webServer.requests(), hasSize(1)); - assertThat(webServer.requests().get(0).getHeader(HttpHeaders.CONTENT_TYPE), Matchers.equalTo(XContentType.JSON.mediaType())); - - var requestMap = entityAsMap(webServer.requests().get(0).getBody()); - assertThat(requestMap, aMapWithSize(3)); - assertThat(requestMap, is(Map.of("project_id", projectId, "inputs", List.of("foo", "bar"), "model_id", modelId))); - } - } - - public void testInfer_ResourceNotFound() throws IOException { - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - - try (var service = new IbmWatsonxService(senderFactory, createWithEmptySettings(threadPool))) { - - String responseJson = """ - { - "error": { - "message": "error" - } - } - """; - webServer.enqueue(new MockResponse().setResponseCode(404).setBody(responseJson)); - - var model = IbmWatsonxEmbeddingsModelTests.createModel( - modelId, - projectId, - URI.create(url), - apiVersion, - apiKey, - getUrl(webServer) - ); - PlainActionFuture listener = new PlainActionFuture<>(); - service.infer( - model, - null, - List.of("abc"), - new HashMap<>(), - InputType.INGEST, - InferenceAction.Request.DEFAULT_TIMEOUT, - listener - ); - - var error = expectThrows(ElasticsearchException.class, () -> listener.actionGet(TIMEOUT)); - assertThat(error.getMessage(), containsString("Resource not found at ")); - assertThat(error.getMessage(), containsString("Error message: [error]")); - assertThat(webServer.requests(), hasSize(1)); - } - } - - public void testCheckModelConfig_UpdatesDimensions() throws IOException { - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - var similarityMeasure = SimilarityMeasure.DOT_PRODUCT; - - try (var service = new IbmWatsonxService(senderFactory, createWithEmptySettings(threadPool))) { - String responseJson = """ - { - "results": [ - { - "embedding": [ - 0.0123, - -0.0123 - ], - "input": "foo" - } - ] - } - """; - - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - var model = IbmWatsonxEmbeddingsModelTests.createModel( - getUrl(webServer), - modelId, - projectId, - URI.create(url), - apiVersion, - apiKey, - 1, - similarityMeasure - ); - - PlainActionFuture listener = new PlainActionFuture<>(); - service.checkModelConfig(model, listener); - var result = listener.actionGet(TIMEOUT); - - // Updates dimensions to two as two embeddings were returned instead of one as specified before - assertThat( - result, - is( - IbmWatsonxEmbeddingsModelTests.createModel( - getUrl(webServer), - modelId, - projectId, - URI.create(url), - apiVersion, - apiKey, - 2, - similarityMeasure - ) - ) - ); - } - } - - public void testCheckModelConfig_UpdatesSimilarityToDotProduct_WhenItIsNull() throws IOException { - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - var twoDimension = 2; - - try (var service = new IbmWatsonxService(senderFactory, createWithEmptySettings(threadPool))) { - String responseJson = """ - { - "results": [ - { - "embedding": [ - 0.0123, - -0.0123 - ], - "input": "foo" - } - ] - } - """; - - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - var model = IbmWatsonxEmbeddingsModelTests.createModel( - getUrl(webServer), - modelId, - projectId, - URI.create(url), - apiVersion, - apiKey, - twoDimension, - null - ); - - PlainActionFuture listener = new PlainActionFuture<>(); - service.checkModelConfig(model, listener); - var result = listener.actionGet(TIMEOUT); - - assertThat( - result, - is( - IbmWatsonxEmbeddingsModelTests.createModel( - getUrl(webServer), - modelId, - projectId, - URI.create(url), - apiVersion, - apiKey, - twoDimension, - SimilarityMeasure.DOT_PRODUCT - ) - ) - ); - } - } - - public void testCheckModelConfig_DoesNotUpdateSimilarity_WhenItIsSpecifiedAsCosine() throws IOException { - var senderFactory = HttpRequestSenderTests.createSenderFactory(threadPool, clientManager); - var twoDimension = 2; - - try (var service = new IbmWatsonxService(senderFactory, createWithEmptySettings(threadPool))) { - String responseJson = """ - { - "results": [ - { - "embedding": [ - 0.0123, - -0.0123 - ], - "input": "foo" - } - ] - } - """; - - webServer.enqueue(new MockResponse().setResponseCode(200).setBody(responseJson)); - - var model = IbmWatsonxEmbeddingsModelTests.createModel( - getUrl(webServer), - modelId, - projectId, - URI.create(url), - apiVersion, - apiKey, - twoDimension, - SimilarityMeasure.COSINE - ); - - PlainActionFuture listener = new PlainActionFuture<>(); - service.checkModelConfig(model, listener); - var result = listener.actionGet(TIMEOUT); - - assertThat( - result, - is( - IbmWatsonxEmbeddingsModelTests.createModel( - getUrl(webServer), - modelId, - projectId, - URI.create(url), - apiVersion, - apiKey, - twoDimension, - SimilarityMeasure.COSINE - ) - ) - ); - } - } - private static ActionListener getModelListenerForException(Class exceptionClass, String expectedMessage) { return ActionListener.wrap((model) -> fail("Model parsing should have failed"), e -> { assertThat(e, Matchers.instanceOf(exceptionClass));