diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java index da52bc8b07a09..eaa62a3aa743a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java @@ -329,6 +329,7 @@ private void onFinish() { inferenceProvider.service() .chunkedInfer( inferenceProvider.model(), + null, inputs, Map.of(), InputType.INGEST, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java index 2ddbc585a1a3b..7cfaeaae4c3a4 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java @@ -183,7 +183,7 @@ public void testManyRandomDocs() throws Exception { IndexRequest actualRequest = getIndexRequestOrNull(items[id].request()); IndexRequest expectedRequest = getIndexRequestOrNull(modifiedRequests[id].request()); try { - assertToXContentEquivalent(expectedRequest.source(), actualRequest.source(), actualRequest.getContentType()); + assertToXContentEquivalent(expectedRequest.source(), actualRequest.source(), expectedRequest.getContentType()); } catch (Exception exc) { throw new IllegalStateException(exc); } @@ -228,9 +228,9 @@ private static ShardBulkInferenceActionFilter createFilter(ThreadPool threadPool InferenceService inferenceService = mock(InferenceService.class); Answer chunkedInferAnswer = invocationOnMock -> { StaticModel model = (StaticModel) invocationOnMock.getArguments()[0]; - List inputs = (List) invocationOnMock.getArguments()[1]; + List inputs = (List) invocationOnMock.getArguments()[2]; ActionListener> listener = (ActionListener< - List>) invocationOnMock.getArguments()[5]; + List>) invocationOnMock.getArguments()[6]; Runnable runnable = () -> { List results = new ArrayList<>(); for (String input : inputs) { @@ -249,7 +249,7 @@ private static ShardBulkInferenceActionFilter createFilter(ThreadPool threadPool } return null; }; - doAnswer(chunkedInferAnswer).when(inferenceService).chunkedInfer(any(), any(), any(), any(), any(), any()); + doAnswer(chunkedInferAnswer).when(inferenceService).chunkedInfer(any(), any(), any(), any(), any(), any(), any()); Answer modelAnswer = invocationOnMock -> { String inferenceId = (String) invocationOnMock.getArguments()[0]; @@ -270,6 +270,7 @@ private static BulkItemRequest[] randomBulkItemRequest( ) { Map docMap = new LinkedHashMap<>(); Map expectedDocMap = new LinkedHashMap<>(); + XContentType requestContentType = randomFrom(XContentType.values()); for (var entry : fieldInferenceMap.values()) { String field = entry.getName(); var model = modelMap.get(entry.getInferenceId()); @@ -280,11 +281,10 @@ private static BulkItemRequest[] randomBulkItemRequest( // ignore results, the doc should fail with a resource not found exception continue; } - var result = randomSemanticText(field, model, List.of(text), randomFrom(XContentType.values())); + var result = randomSemanticText(field, model, List.of(text), requestContentType); model.putResult(text, result); expectedDocMap.put(field, result); } - XContentType requestContentType = randomFrom(XContentType.values()); return new BulkItemRequest[] { new BulkItemRequest(id, new IndexRequest("index").source(docMap, requestContentType)), new BulkItemRequest(id, new IndexRequest("index").source(expectedDocMap, requestContentType)) };