diff --git a/server/src/main/java/org/elasticsearch/action/bulk/BulkOperation.java b/server/src/main/java/org/elasticsearch/action/bulk/BulkOperation.java
index b7a6387045e3d..452a9ec90443a 100644
--- a/server/src/main/java/org/elasticsearch/action/bulk/BulkOperation.java
+++ b/server/src/main/java/org/elasticsearch/action/bulk/BulkOperation.java
@@ -209,8 +209,8 @@ private void executeBulkRequestsByShard(Map<ShardId, List<BulkItemRequest>> requestsByShard) {
                 requests.toArray(new BulkItemRequest[0])
             );
             var indexMetadata = clusterState.getMetadata().index(shardId.getIndexName());
-            if (indexMetadata != null && indexMetadata.getFieldsForModels().isEmpty() == false) {
-                bulkShardRequest.setFieldInferenceMetadata(indexMetadata.getFieldsForModels());
+            if (indexMetadata != null && indexMetadata.getFieldInferenceMetadata().isEmpty() == false) {
+                bulkShardRequest.setFieldInferenceMetadata(indexMetadata.getFieldInferenceMetadata());
             }
             bulkShardRequest.waitForActiveShards(bulkRequest.waitForActiveShards());
             bulkShardRequest.timeout(bulkRequest.timeout());
diff --git a/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequest.java b/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequest.java
index 1b5494c6a68f5..39fa791a3e27d 100644
--- a/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequest.java
+++ b/server/src/main/java/org/elasticsearch/action/bulk/BulkShardRequest.java
@@ -15,6 +15,7 @@
 import org.elasticsearch.action.support.replication.ReplicatedWriteRequest;
 import org.elasticsearch.action.support.replication.ReplicationRequest;
 import org.elasticsearch.action.update.UpdateRequest;
+import org.elasticsearch.cluster.metadata.FieldInferenceMetadata;
 import org.elasticsearch.common.io.stream.StreamInput;
 import org.elasticsearch.common.io.stream.StreamOutput;
 import org.elasticsearch.common.util.set.Sets;
@@ -22,7 +23,6 @@
 import org.elasticsearch.transport.RawIndexingDataTransportRequest;
 
 import java.io.IOException;
-import java.util.Map;
 import java.util.Set;
 
 public final class BulkShardRequest extends ReplicatedWriteRequest<BulkShardRequest>
@@ -34,7 +34,7 @@ public final class BulkShardRequest extends ReplicatedWriteRequest<BulkShardRequest>
-    private transient Map<String, Set<String>> fieldsInferenceMetadata = null;
+    private transient FieldInferenceMetadata fieldsInferenceMetadataMap = null;
 
     public BulkShardRequest(StreamInput in) throws IOException {
         super(in);
@@ -51,24 +51,24 @@ public BulkShardRequest(ShardId shardId, RefreshPolicy refreshPolicy, BulkItemRequest[] items) {
      * Public for test
      * Set the transient metadata indicating that this request requires running inference before proceeding.
      */
-    public void setFieldInferenceMetadata(Map<String, Set<String>> fieldsInferenceMetadata) {
-        this.fieldsInferenceMetadata = fieldsInferenceMetadata;
+    public void setFieldInferenceMetadata(FieldInferenceMetadata fieldsInferenceMetadata) {
+        this.fieldsInferenceMetadataMap = fieldsInferenceMetadata;
     }
 
     /**
      * Consumes the inference metadata to execute inference on the bulk items just once.
      */
-    public Map<String, Set<String>> consumeFieldInferenceMetadata() {
-        var ret = fieldsInferenceMetadata;
-        fieldsInferenceMetadata = null;
+    public FieldInferenceMetadata consumeFieldInferenceMetadata() {
+        FieldInferenceMetadata ret = fieldsInferenceMetadataMap;
+        fieldsInferenceMetadataMap = null;
         return ret;
     }
 
     /**
      * Public for test
      */
-    public Map<String, Set<String>> getFieldsInferenceMetadata() {
-        return fieldsInferenceMetadata;
+    public FieldInferenceMetadata getFieldsInferenceMetadataMap() {
+        return fieldsInferenceMetadataMap;
     }
 
     public long totalSizeInBytes() {
diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java
index e679d3c970abf..984a20419b2c8 100644
--- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java
+++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilter.java
@@ -24,6 +24,7 @@
 import org.elasticsearch.action.support.ActionFilterChain;
 import org.elasticsearch.action.support.RefCountingRunnable;
 import org.elasticsearch.action.update.UpdateRequest;
+import org.elasticsearch.cluster.metadata.FieldInferenceMetadata;
 import org.elasticsearch.common.util.concurrent.AtomicArray;
 import org.elasticsearch.common.xcontent.support.XContentMapValues;
 import org.elasticsearch.core.Nullable;
@@ -44,7 +45,6 @@
 import java.util.LinkedHashMap;
 import java.util.List;
 import java.util.Map;
-import java.util.Set;
 import java.util.stream.Collectors;
 
 /**
@@ -81,7 +81,7 @@ public void apply(
             case TransportShardBulkAction.ACTION_NAME:
                 BulkShardRequest bulkShardRequest = (BulkShardRequest) request;
                 var fieldInferenceMetadata = bulkShardRequest.consumeFieldInferenceMetadata();
-                if (fieldInferenceMetadata != null && fieldInferenceMetadata.size() > 0) {
+                if (fieldInferenceMetadata != null && fieldInferenceMetadata.isEmpty() == false) {
                     Runnable onInferenceCompletion = () -> chain.proceed(task, action, request, listener);
                     processBulkShardRequest(fieldInferenceMetadata, bulkShardRequest, onInferenceCompletion);
                 } else {
@@ -96,7 +96,7 @@
     }
 
     private void processBulkShardRequest(
-        Map<String, Set<String>> fieldInferenceMetadata,
+        FieldInferenceMetadata fieldInferenceMetadata,
         BulkShardRequest bulkShardRequest,
         Runnable onCompletion
     ) {
@@ -112,13 +112,13 @@ private record FieldInferenceResponse(String field, Model model, ChunkedInferenceServiceResults
     private record FieldInferenceResponseAccumulator(int id, List<FieldInferenceResponse> responses, List<Exception> failures) {}
 
     private class AsyncBulkShardInferenceAction implements Runnable {
-        private final Map<String, Set<String>> fieldInferenceMetadata;
+        private final FieldInferenceMetadata fieldInferenceMetadata;
         private final BulkShardRequest bulkShardRequest;
         private final Runnable onCompletion;
         private final AtomicArray<FieldInferenceResponseAccumulator> inferenceResults;
 
         private AsyncBulkShardInferenceAction(
-            Map<String, Set<String>> fieldInferenceMetadata,
+            FieldInferenceMetadata fieldInferenceMetadata,
             BulkShardRequest bulkShardRequest,
             Runnable onCompletion
         ) {
@@ -289,39 +289,35 @@ private Map<String, List<FieldInferenceRequest>> createFieldInferenceRequests(BulkShardRequest bulkShardRequest) {
                 continue;
             }
             final Map<String, Object> docMap = indexRequest.sourceAsMap();
-            for (var entry : fieldInferenceMetadata.entrySet()) {
-                String inferenceId = entry.getKey();
-                for (var field : entry.getValue()) {
-                    var value = XContentMapValues.extractValue(field, docMap);
-                    if (value == null) {
-                        continue;
-                    }
-                    if (inferenceResults.get(item.id()) == null) {
-                        inferenceResults.set(
+            for (var entry : fieldInferenceMetadata.getFieldInferenceOptions().entrySet()) {
+                String field = entry.getKey();
+                String inferenceId = entry.getValue().inferenceId();
+                var value = XContentMapValues.extractValue(field, docMap);
+                if (value == null) {
+                    continue;
+                }
+                if (inferenceResults.get(item.id()) == null) {
+                    inferenceResults.set(
+                        item.id(),
+                        new FieldInferenceResponseAccumulator(
                             item.id(),
-                            new FieldInferenceResponseAccumulator(
-                                item.id(),
-                                Collections.synchronizedList(new ArrayList<>()),
-                                Collections.synchronizedList(new ArrayList<>())
-                            )
-                        );
-                    }
-                    if (value instanceof String valueStr) {
-                        List<FieldInferenceRequest> fieldRequests = fieldRequestsMap.computeIfAbsent(
-                            inferenceId,
-                            k -> new ArrayList<>()
-                        );
-                        fieldRequests.add(new FieldInferenceRequest(item.id(), field, valueStr));
-                    } else {
-                        inferenceResults.get(item.id()).failures.add(
-                            new ElasticsearchStatusException(
-                                "Invalid format for field [{}], expected [String] got [{}]",
-                                RestStatus.BAD_REQUEST,
-                                field,
-                                value.getClass().getSimpleName()
-                            )
-                        );
-                    }
+                            Collections.synchronizedList(new ArrayList<>()),
+                            Collections.synchronizedList(new ArrayList<>())
+                        )
+                    );
+                }
+                if (value instanceof String valueStr) {
+                    List<FieldInferenceRequest> fieldRequests = fieldRequestsMap.computeIfAbsent(inferenceId, k -> new ArrayList<>());
+                    fieldRequests.add(new FieldInferenceRequest(item.id(), field, valueStr));
+                } else {
+                    inferenceResults.get(item.id()).failures.add(
+                        new ElasticsearchStatusException(
+                            "Invalid format for field [{}], expected [String] got [{}]",
+                            RestStatus.BAD_REQUEST,
+                            field,
+                            value.getClass().getSimpleName()
+                        )
+                    );
                 }
             }
         }
diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java
index 7f3ffbe596543..4a1825303b5a7 100644
--- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java
+++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/action/filter/ShardBulkInferenceActionFilterTests.java
@@ -16,6 +16,7 @@
 import org.elasticsearch.action.index.IndexRequest;
 import org.elasticsearch.action.support.ActionFilterChain;
 import org.elasticsearch.action.support.WriteRequest;
+import org.elasticsearch.cluster.metadata.FieldInferenceMetadata;
 import org.elasticsearch.common.Strings;
 import org.elasticsearch.common.xcontent.XContentHelper;
 import org.elasticsearch.index.shard.ShardId;
@@ -40,7 +41,6 @@
 
 import java.util.ArrayList;
 import java.util.HashMap;
-import java.util.HashSet;
 import java.util.LinkedHashMap;
 import java.util.List;
 import java.util.Map;
@@ -79,7 +79,7 @@ public void testFilterNoop() throws Exception {
         CountDownLatch chainExecuted = new CountDownLatch(1);
         ActionFilterChain actionFilterChain = (task, action, request, listener) -> {
             try {
-                assertNull(((BulkShardRequest) request).getFieldsInferenceMetadata());
+                assertNull(((BulkShardRequest) request).getFieldsInferenceMetadataMap());
             } finally {
                 chainExecuted.countDown();
             }
@@ -91,7 +91,9 @@ public void testFilterNoop() throws Exception {
             WriteRequest.RefreshPolicy.NONE,
             new BulkItemRequest[0]
         );
-        request.setFieldInferenceMetadata(Map.of("foo", Set.of("bar")));
+        request.setFieldInferenceMetadata(
+            new FieldInferenceMetadata(Map.of("foo", new FieldInferenceMetadata.FieldInferenceOptions("bar", Set.of())))
+        );
         filter.apply(task, TransportShardBulkAction.ACTION_NAME, request, actionListener, actionFilterChain);
         awaitLatch(chainExecuted, 10, TimeUnit.SECONDS);
     }
@@ -104,7 +106,7 @@ public void testInferenceNotFound() throws Exception {
         ActionFilterChain actionFilterChain = (task, action, request, listener) -> {
             try {
                 BulkShardRequest bulkShardRequest = (BulkShardRequest) request;
-                assertNull(bulkShardRequest.getFieldsInferenceMetadata());
+                assertNull(bulkShardRequest.getFieldsInferenceMetadataMap());
                 for (BulkItemRequest item : bulkShardRequest.items()) {
                     assertNotNull(item.getPrimaryResponse());
                     assertTrue(item.getPrimaryResponse().isFailed());
@@ -118,11 +120,15 @@
         ActionListener actionListener = mock(ActionListener.class);
         Task task = mock(Task.class);
 
-        Map<String, Set<String>> inferenceFields = Map.of(
-            model.getInferenceEntityId(),
-            Set.of("field1"),
-            "inference_0",
-            Set.of("field2", "field3")
+        FieldInferenceMetadata inferenceFields = new FieldInferenceMetadata(
+            Map.of(
+                "field1",
+                new FieldInferenceMetadata.FieldInferenceOptions(model.getInferenceEntityId(), Set.of()),
+                "field2",
+                new FieldInferenceMetadata.FieldInferenceOptions("inference_0", Set.of()),
+                "field3",
+                new FieldInferenceMetadata.FieldInferenceOptions("inference_0", Set.of())
+            )
         );
         BulkItemRequest[] items = new BulkItemRequest[10];
         for (int i = 0; i < items.length; i++) {
@@ -144,19 +150,19 @@ public void testManyRandomDocs() throws Exception {
         }
 
         int numInferenceFields = randomIntBetween(1, 5);
-        Map<String, Set<String>> inferenceFields = new HashMap<>();
+        Map<String, FieldInferenceMetadata.FieldInferenceOptions> inferenceFieldsMap = new HashMap<>();
         for (int i = 0; i < numInferenceFields; i++) {
-            String inferenceId = randomFrom(inferenceModelMap.keySet());
             String field = randomAlphaOfLengthBetween(5, 10);
-            var res = inferenceFields.computeIfAbsent(inferenceId, k -> new HashSet<>());
-            res.add(field);
+            String inferenceId = randomFrom(inferenceModelMap.keySet());
+            inferenceFieldsMap.put(field, new FieldInferenceMetadata.FieldInferenceOptions(inferenceId, Set.of()));
         }
+        FieldInferenceMetadata fieldInferenceMetadata = new FieldInferenceMetadata(inferenceFieldsMap);
 
         int numRequests = randomIntBetween(100, 1000);
         BulkItemRequest[] originalRequests = new BulkItemRequest[numRequests];
         BulkItemRequest[] modifiedRequests = new BulkItemRequest[numRequests];
         for (int id = 0; id < numRequests; id++) {
-            BulkItemRequest[] res = randomBulkItemRequest(id, inferenceModelMap, inferenceFields);
+            BulkItemRequest[] res = randomBulkItemRequest(id, inferenceModelMap, fieldInferenceMetadata);
             originalRequests[id] = res[0];
             modifiedRequests[id] = res[1];
         }
@@ -167,7 +173,7 @@ public void testManyRandomDocs() throws Exception {
             try {
                 assertThat(request, instanceOf(BulkShardRequest.class));
                 BulkShardRequest bulkShardRequest = (BulkShardRequest) request;
-                assertNull(bulkShardRequest.getFieldsInferenceMetadata());
+                assertNull(bulkShardRequest.getFieldsInferenceMetadataMap());
                 BulkItemRequest[] items = bulkShardRequest.items();
                 assertThat(items.length, equalTo(originalRequests.length));
                 for (int id = 0; id < items.length; id++) {
@@ -186,7 +192,7 @@ public void testManyRandomDocs() throws Exception {
         ActionListener actionListener = mock(ActionListener.class);
         Task task = mock(Task.class);
         BulkShardRequest original = new BulkShardRequest(new ShardId("test", "test", 0), WriteRequest.RefreshPolicy.NONE, originalRequests);
-        original.setFieldInferenceMetadata(inferenceFields);
+        original.setFieldInferenceMetadata(fieldInferenceMetadata);
         filter.apply(task, TransportShardBulkAction.ACTION_NAME, original, actionListener, actionFilterChain);
         awaitLatch(chainExecuted, 10, TimeUnit.SECONDS);
     }
@@ -257,42 +263,40 @@ private static ShardBulkInferenceActionFilter createFilter(ThreadPool threadPool
     private static BulkItemRequest[] randomBulkItemRequest(
         int id,
         Map modelMap,
-        Map<String, Set<String>> inferenceFieldMap
+        FieldInferenceMetadata fieldInferenceMetadata
    ) {
         Map<String, Object> docMap = new LinkedHashMap<>();
         Map<String, Object> inferenceResultsMap = new LinkedHashMap<>();
-        for (var entry : inferenceFieldMap.entrySet()) {
-            String inferenceId = entry.getKey();
-            var model = modelMap.get(inferenceId);
-            for (var field : entry.getValue()) {
-                String text = randomAlphaOfLengthBetween(10, 100);
-                docMap.put(field, text);
-                if (model == null) {
-                    // ignore results, the doc should fail with a resource not found exception
-                    continue;
-                }
-                int numChunks = randomIntBetween(1, 5);
-                List<String> chunks = new ArrayList<>();
-                for (int i = 0; i < numChunks; i++) {
-                    chunks.add(randomAlphaOfLengthBetween(5, 10));
-                }
-                TaskType taskType = model.getTaskType();
-                final ChunkedInferenceServiceResults results;
-                switch (taskType) {
-                    case TEXT_EMBEDDING:
-                        results = randomTextEmbeddings(chunks);
-                        break;
+        for (var entry : fieldInferenceMetadata.getFieldInferenceOptions().entrySet()) {
+            String field = entry.getKey();
+            var model = modelMap.get(entry.getValue().inferenceId());
+            String text = randomAlphaOfLengthBetween(10, 100);
+            docMap.put(field, text);
+            if (model == null) {
+                // ignore results, the doc should fail with a resource not found exception
+                continue;
+            }
+            int numChunks = randomIntBetween(1, 5);
+            List<String> chunks = new ArrayList<>();
+            for (int i = 0; i < numChunks; i++) {
+                chunks.add(randomAlphaOfLengthBetween(5, 10));
+            }
+            TaskType taskType = model.getTaskType();
+            final ChunkedInferenceServiceResults results;
+            switch (taskType) {
+                case TEXT_EMBEDDING:
+                    results = randomTextEmbeddings(chunks);
+                    break;
 
-                    case SPARSE_EMBEDDING:
-                        results = randomSparseEmbeddings(chunks);
-                        break;
+                case SPARSE_EMBEDDING:
+                    results = randomSparseEmbeddings(chunks);
+                    break;
 
-                    default:
-                        throw new AssertionError("Unknown task type " + taskType.name());
-                }
-                model.putResult(text, results);
-                InferenceResultFieldMapper.applyFieldInference(inferenceResultsMap, field, model, results);
+                default:
+                    throw new AssertionError("Unknown task type " + taskType.name());
             }
+            model.putResult(text, results);
+            InferenceResultFieldMapper.applyFieldInference(inferenceResultsMap, field, model, results);
         }
         Map<String, Object> expectedDocMap = new LinkedHashMap<>(docMap);
         expectedDocMap.put(InferenceResultFieldMapper.NAME, inferenceResultsMap);
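
Note (not part of the diff): FieldInferenceMetadata replaces the transient Map<String, Set<String>> that was keyed by inference ID with a holder keyed by field name, so each field carries its own inference ID. The class lives in org.elasticsearch.cluster.metadata and is introduced outside this change set; the sketch below only illustrates the shape implied by the call sites above (isEmpty(), getFieldInferenceOptions(), FieldInferenceOptions#inferenceId()). The sourceFields component name is an assumption; the tests above only ever pass Set.of() for that argument.

// Hypothetical sketch of the FieldInferenceMetadata shape assumed by this diff; not the actual class.
import java.util.Map;
import java.util.Set;

public class FieldInferenceMetadata {

    // Inference options keyed by field name (the old map was keyed by inference ID instead).
    private final Map<String, FieldInferenceOptions> fieldInferenceOptions;

    public FieldInferenceMetadata(Map<String, FieldInferenceOptions> fieldInferenceOptions) {
        this.fieldInferenceOptions = fieldInferenceOptions;
    }

    public Map<String, FieldInferenceOptions> getFieldInferenceOptions() {
        return fieldInferenceOptions;
    }

    public boolean isEmpty() {
        return fieldInferenceOptions.isEmpty();
    }

    // inferenceId is the inference endpoint to run for the field; sourceFields (name assumed)
    // is passed as Set.of() by the tests in this diff.
    public record FieldInferenceOptions(String inferenceId, Set<String> sourceFields) {}
}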