From e4a6a678717cbabc8d79850871b460c70e1049b7 Mon Sep 17 00:00:00 2001 From: carlosdelest Date: Tue, 9 Jan 2024 12:02:21 +0100 Subject: [PATCH] Added test cases --- .../action/bulk/TransportBulkAction.java | 4 +- .../TransportBulkActionInferenceTests.java | 173 ++++++++++++------ 2 files changed, 119 insertions(+), 58 deletions(-) diff --git a/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java b/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java index 39cc586118eb5..39b45f92dc6da 100644 --- a/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java +++ b/server/src/main/java/org/elasticsearch/action/bulk/TransportBulkAction.java @@ -723,6 +723,7 @@ private void performInferenceAndExecute(BulkShardRequest bulkShardRequest, Clust // No inference fields? Just execute the request if (fieldsForModels.isEmpty()) { executeBulkShardRequest(bulkShardRequest, releaseOnFinish); + return; } Runnable onInferenceComplete = () -> { @@ -892,9 +893,6 @@ private static String findMapValue(Map map, String... path) { private void executeBulkShardRequest(BulkShardRequest bulkShardRequest, Releasable releaseOnFinish) { if (bulkShardRequest.items().length == 0) { // No requests to execute due to previous errors, terminate early - listener.onResponse( - new BulkResponse(responses.toArray(new BulkItemResponse[responses.length()]), buildTookInMillis(startTimeNanos)) - ); releaseOnFinish.close(); return; } diff --git a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionInferenceTests.java b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionInferenceTests.java index 37889ea6a6739..71fd62df1620d 100644 --- a/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionInferenceTests.java +++ b/server/src/test/java/org/elasticsearch/action/bulk/TransportBulkActionInferenceTests.java @@ -38,7 +38,6 @@ import org.elasticsearch.ingest.IngestService; import org.elasticsearch.test.ClusterServiceUtils; import org.elasticsearch.test.ESTestCase; -import org.elasticsearch.test.transport.CapturingTransport; import org.elasticsearch.threadpool.TestThreadPool; import org.elasticsearch.threadpool.ThreadPool; import org.elasticsearch.transport.TransportService; @@ -46,14 +45,14 @@ import org.junit.Before; import org.mockito.verification.VerificationMode; +import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Set; import java.util.concurrent.TimeUnit; -import static org.hamcrest.Matchers.containsInAnyOrder; -import static org.hamcrest.Matchers.notNullValue; +import static org.hamcrest.Matchers.equalTo; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.ArgumentMatchers.eq; @@ -72,8 +71,6 @@ public class TransportBulkActionInferenceTests extends ESTestCase { private static final String INFERENCE_FIELD_2_MODEL_A = "inference_field_2_model_a"; public static final String MODEL_B_ID = "model_b_id"; private static final String INFERENCE_FIELD_MODEL_B = "inference_field_model_b"; - private TransportService transportService; - private CapturingTransport capturingTransport; private ClusterService clusterService; private ThreadPool threadPool; private NodeClient nodeClient; @@ -82,13 +79,9 @@ public class TransportBulkActionInferenceTests extends ESTestCase { @Before public void setup() { threadPool = new TestThreadPool(getClass().getName()); - nodeClient = mock(NodeClient.class); - DiscoveryNodes nodes = mock(DiscoveryNodes.class); - DiscoveryNode remoteNode = mock(DiscoveryNode.class); - Map ingestNodes = Map.of("node", remoteNode); - when(nodes.getIngestNodes()).thenReturn(ingestNodes); + // Contains the fields for models for the index Metadata metadata = Metadata.builder() .indices( Map.of( @@ -118,30 +111,31 @@ public void setup() { clusterService = ClusterServiceUtils.createClusterService(state, threadPool); - capturingTransport = new CapturingTransport(); - transportService = capturingTransport.createTransportService( - clusterService.getSettings(), - threadPool, - TransportService.NOOP_TRANSPORT_INTERCEPTOR, - boundAddress -> clusterService.localNode(), - null, - Collections.emptySet() - ); - transportService.start(); - transportService.acceptIncomingRequests(); - - IngestService ingestService = mock(IngestService.class); transportBulkAction = new TransportBulkAction( threadPool, - transportService, + mock(TransportService.class), clusterService, - ingestService, + mock(IngestService.class), nodeClient, new ActionFilters(Collections.emptySet()), TestIndexNameExpressionResolver.newInstance(), new IndexingPressure(Settings.builder().put(AutoCreateIndex.AUTO_CREATE_INDEX_SETTING.getKey(), true).build()), EmptySystemIndices.INSTANCE ); + + // Default answers to avoid hanging tests due to unexpected invocations + doAnswer(invocation -> { + @SuppressWarnings("unchecked") + var listener = (ActionListener) invocation.getArguments()[2]; + listener.onFailure(new Exception("Unexpected invocation")); + return Void.TYPE; + }).when(nodeClient).execute(eq(InferenceAction.INSTANCE), any(), any()); + when(nodeClient.executeLocally(eq(TransportShardBulkAction.TYPE), any(), any())).thenAnswer(invocation -> { + @SuppressWarnings("unchecked") + var listener = (ActionListener) invocation.getArguments()[2]; + listener.onFailure(new Exception("Unexpected invocation")); + return null; + }); } @After @@ -158,13 +152,14 @@ public void testBulkRequestWithoutInference() { indexRequest.source("non_inference_field", "text", "another_non_inference_field", "other text"); bulkRequest.add(indexRequest); - expectTransportShardBulkActionRequest(); + expectTransportShardBulkActionRequest(1); PlainActionFuture future = new PlainActionFuture<>(); ActionTestUtils.execute(transportBulkAction, null, bulkRequest, future); BulkResponse response = future.actionGet(); - assertEquals(1, response.getItems().length); + assertThat(response.getItems().length, equalTo(1)); + assertTrue(Arrays.stream(response.getItems()).allMatch(r -> r.isFailed() == false)); verifyInferenceExecuted(never()); } @@ -175,15 +170,16 @@ public void testBulkRequestWithInference() { indexRequest.source(INFERENCE_FIELD_1_MODEL_A, inferenceFieldText, "non_inference_field", "other text"); bulkRequest.add(indexRequest); - expectInferenceRequest(Map.of(MODEL_A_ID, Set.of(inferenceFieldText))); + expectInferenceRequest(MODEL_A_ID, inferenceFieldText); - expectTransportShardBulkActionRequest(); + expectTransportShardBulkActionRequest(1); PlainActionFuture future = new PlainActionFuture<>(); ActionTestUtils.execute(transportBulkAction, null, bulkRequest, future); BulkResponse response = future.actionGet(); - assertEquals(1, response.getItems().length); + assertThat(response.getItems().length, equalTo(1)); + assertTrue(Arrays.stream(response.getItems()).allMatch(r -> r.isFailed() == false)); verifyInferenceExecuted(times(1)); } @@ -205,47 +201,102 @@ public void testBulkRequestWithMultipleFieldsInference() { ); bulkRequest.add(indexRequest); - expectInferenceRequest( - Map.of(MODEL_A_ID, Set.of(inferenceField1Text, inferenceField2Text), MODEL_B_ID, Set.of(inferenceField3Text)) - ); + expectInferenceRequest(MODEL_A_ID, inferenceField1Text, inferenceField2Text); + expectInferenceRequest(MODEL_B_ID, inferenceField3Text); - expectTransportShardBulkActionRequest(); + expectTransportShardBulkActionRequest(1); PlainActionFuture future = new PlainActionFuture<>(); ActionTestUtils.execute(transportBulkAction, null, bulkRequest, future); BulkResponse response = future.actionGet(); - assertEquals(1, response.getItems().length); + assertThat(response.getItems().length, equalTo(1)); + assertTrue(Arrays.stream(response.getItems()).allMatch(r -> r.isFailed() == false)); verifyInferenceExecuted(times(2)); } - private void verifyInferenceExecuted(VerificationMode times) { - verify(nodeClient, times).execute(eq(InferenceAction.INSTANCE), any(InferenceAction.Request.class), any()); + public void testBulkRequestWithMultipleDocs() { + BulkRequest bulkRequest = new BulkRequest(); + IndexRequest indexRequest = new IndexRequest(INDEX_NAME).id("id1"); + String inferenceFieldTextDoc1 = "some text"; + bulkRequest.add(indexRequest); + indexRequest.source(INFERENCE_FIELD_1_MODEL_A, inferenceFieldTextDoc1, "non_inference_field", "other text"); + indexRequest = new IndexRequest(INDEX_NAME).id("id2"); + String inferenceFieldTextDoc2 = "some other text"; + indexRequest.source(INFERENCE_FIELD_1_MODEL_A, inferenceFieldTextDoc2, "non_inference_field", "more text"); + bulkRequest.add(indexRequest); + + expectInferenceRequest(MODEL_A_ID, inferenceFieldTextDoc1); + expectInferenceRequest(MODEL_A_ID, inferenceFieldTextDoc2); + + expectTransportShardBulkActionRequest(2); + + PlainActionFuture future = new PlainActionFuture<>(); + ActionTestUtils.execute(transportBulkAction, null, bulkRequest, future); + BulkResponse response = future.actionGet(); + + assertThat(response.getItems().length, equalTo(2)); + assertTrue(Arrays.stream(response.getItems()).allMatch(r -> r.isFailed() == false)); + verifyInferenceExecuted(times(2)); } - private void expectTransportShardBulkActionRequest() { - doAnswer(invocation -> { + public void testFailingInference() { + BulkRequest bulkRequest = new BulkRequest(); + IndexRequest indexRequest = new IndexRequest(INDEX_NAME).id("id1"); + String inferenceFieldTextDoc1 = "some text"; + indexRequest.source(INFERENCE_FIELD_1_MODEL_A, inferenceFieldTextDoc1, "non_inference_field", "more text"); + bulkRequest.add(indexRequest); + indexRequest = new IndexRequest(INDEX_NAME).id("id1"); + String inferenceFieldTextDoc2 = "some text"; + indexRequest.source(INFERENCE_FIELD_MODEL_B, inferenceFieldTextDoc2, "non_inference_field", "more text"); + bulkRequest.add(indexRequest); + + expectInferenceRequestFails(MODEL_A_ID, inferenceFieldTextDoc1); + expectInferenceRequest(MODEL_B_ID, inferenceFieldTextDoc2); + + // Only non-failing inference requests will be executed + expectTransportShardBulkActionRequest(1); + + PlainActionFuture future = new PlainActionFuture<>(); + ActionTestUtils.execute(transportBulkAction, null, bulkRequest, future); + BulkResponse response = future.actionGet(); + + assertThat(response.getItems().length, equalTo(2)); + assertTrue(response.getItems()[0].isFailed()); + assertFalse(response.getItems()[1].isFailed()); + verifyInferenceExecuted(times(2)); + } + + private void verifyInferenceExecuted(VerificationMode verificationMode) { + verify(nodeClient, verificationMode).execute(eq(InferenceAction.INSTANCE), any(InferenceAction.Request.class), any()); + } + + private void expectTransportShardBulkActionRequest(int requestSize) { + when(nodeClient.executeLocally(eq(TransportShardBulkAction.TYPE), argThat(r -> matchBulkShardRequest(r, requestSize)), any())) + .thenAnswer(invocation -> { @SuppressWarnings("unchecked") var listener = (ActionListener) invocation.getArguments()[2]; + var bulkShardRequest = (BulkShardRequest) invocation.getArguments()[1]; ShardId shardId = new ShardId(INDEX_NAME, "UUID", 0); - BulkItemResponse successResponse = BulkItemResponse.success( - 0, - DocWriteRequest.OpType.INDEX, - new IndexResponse(shardId, "id", 0, 0, 0, true) - ); - listener.onResponse(new BulkShardResponse(shardId, new BulkItemResponse[] { successResponse })); + BulkItemResponse[] bulkItemResponses = Arrays.stream(bulkShardRequest.items()).map(item -> BulkItemResponse.success(item.id(), DocWriteRequest.OpType.INDEX, new IndexResponse( + shardId, + "id", + 0, 0, 0, true + ))).toArray(BulkItemResponse[]::new); + + listener.onResponse(new BulkShardResponse(shardId, bulkItemResponses)); return null; - }).when(nodeClient).executeLocally(eq(TransportShardBulkAction.TYPE), any(BulkShardRequest.class), any()); + }); } - private void expectInferenceRequest(Map> modelsAndInferenceTextMap) { + private boolean matchBulkShardRequest(ActionRequest request, int requestSize) { + return (request instanceof BulkShardRequest) && ((BulkShardRequest) request).items().length == requestSize; + } + + @SuppressWarnings("unchecked") + private void expectInferenceRequest(String modelId, String... inferenceTexts) { doAnswer(invocation -> { InferenceAction.Request request = (InferenceAction.Request) invocation.getArguments()[1]; - Set textsForModel = modelsAndInferenceTextMap.get(request.getModelId()); - assertThat("model is not expected", textsForModel, notNullValue()); - assertThat("unexpected inference field values", request.getInput(), containsInAnyOrder(textsForModel.toArray())); - - @SuppressWarnings("unchecked") var listener = (ActionListener) invocation.getArguments()[2]; listener.onResponse( new InferenceAction.Response( @@ -262,10 +313,22 @@ private void expectInferenceRequest(Map> modelsAndInferenceT ) ); return Void.TYPE; - }).when(nodeClient).execute(eq(InferenceAction.INSTANCE), argThat(r -> inferenceRequestMatches(r, modelsAndInferenceTextMap.keySet())), any()); + }).when(nodeClient).execute(eq(InferenceAction.INSTANCE), argThat(r -> inferenceRequestMatches(r, modelId, inferenceTexts)), any()); + } + + private void expectInferenceRequestFails(String modelId, String... inferenceTexts) { + doAnswer(invocation -> { + @SuppressWarnings("unchecked") + var listener = (ActionListener) invocation.getArguments()[2]; + listener.onFailure(new Exception("Inference failed")); + return Void.TYPE; + }).when(nodeClient).execute(eq(InferenceAction.INSTANCE), argThat(r -> inferenceRequestMatches(r, modelId, inferenceTexts)), any()); } - private boolean inferenceRequestMatches(ActionRequest request, Set models) { - return request instanceof InferenceAction.Request && models.contains(((InferenceAction.Request) request).getModelId()); + private boolean inferenceRequestMatches(ActionRequest request, String modelId, String[] inferenceTexts) { + if (request instanceof InferenceAction.Request inferenceRequest) { + return inferenceRequest.getModelId().equals(modelId) && inferenceRequest.getInput().containsAll(List.of(inferenceTexts)); + } + return false; } }